Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions openml/datasets/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,8 +364,10 @@ def __eq__(self, other: Any) -> bool:

# check that common keys and values are identical
ignore_fields = server_fields | cache_fields
self_keys = set(self.__dict__.keys()) - ignore_fields
other_keys = set(other.__dict__.keys()) - ignore_fields

self_keys = set(self.__dict__) - ignore_fields
other_keys = set(other.__dict__) - ignore_fields

return self_keys == other_keys and all(
self.__dict__[key] == other.__dict__[key] for key in self_keys
)
Expand Down Expand Up @@ -616,7 +618,7 @@ def _parse_data_from_pq(self, data_file: Path) -> tuple[list[str], list[bool], p
data = pd.read_parquet(data_file)
except Exception as e:
raise Exception(f"File: {data_file}") from e
categorical = [data[c].dtype.name == "category" for c in data.columns]
categorical = [isinstance(data[c].dtype, pd.CategoricalDtype) for c in data.columns]
attribute_names = list(data.columns)
return attribute_names, categorical, data

Expand Down
4 changes: 2 additions & 2 deletions openml/flows/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,7 +510,7 @@ def assert_flows_equal( # noqa: C901, PLR0912, PLR0913, PLR0915
]
ignored_by_python_api = ["binary_url", "binary_format", "binary_md5", "model", "_entity_id"]

for key in set(flow1.__dict__.keys()).union(flow2.__dict__.keys()):
for key in set(flow1.__dict__).union(flow2.__dict__):
if key in generated_by_the_server + ignored_by_python_api:
continue
attr1 = getattr(flow1, key, None)
Expand All @@ -519,7 +519,7 @@ def assert_flows_equal( # noqa: C901, PLR0912, PLR0913, PLR0915
if not (isinstance(attr1, dict) and isinstance(attr2, dict)):
raise TypeError("Cannot compare components because they are not dictionary.")

for name in set(attr1.keys()).union(attr2.keys()):
for name in set(attr1).union(attr2):
if name not in attr1:
raise ValueError(
f"Component {name} only available in argument2, but not in argument1.",
Expand Down