diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 3447fee1..4446cb6e 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -24,7 +24,7 @@ jobs: strategy: matrix: - python-version: ["3.10"] + python-version: ["3.11"] steps: - name: Checkout code @@ -67,7 +67,7 @@ jobs: strategy: matrix: - python-version: ["3.10"] + python-version: ["3.11"] env: HF_TOKEN: ${{ secrets.HF_TOKEN }} diff --git a/pyproject.toml b/pyproject.toml index 11287b0b..9cabd9cc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -103,7 +103,7 @@ dependencies = [ "transformers", "pytorch-lightning", "huggingface-hub[hf-xet]>=0.30.0", - "datasets>=0.34", + "datasets>=3.0", "numpy>=1.24.4", "numpydoc>=1.6.0", "diffusers>=0.21.4", diff --git a/src/pruna/algorithms/base/tags.py b/src/pruna/algorithms/base/tags.py index f2e37e6f..b18379d8 100644 --- a/src/pruna/algorithms/base/tags.py +++ b/src/pruna/algorithms/base/tags.py @@ -37,6 +37,9 @@ class AlgorithmTag(Enum): The type of the enum. start : int The start index for auto-numbering enum values. + boundary : enum.FlagBoundary or None + Boundary handling mode used by the Enum functional API for Flag and + IntFlag enums. """ QUANTIZER = ( diff --git a/src/pruna/engine/load.py b/src/pruna/engine/load.py index 7380cf71..b310b593 100644 --- a/src/pruna/engine/load.py +++ b/src/pruna/engine/load.py @@ -565,6 +565,9 @@ class LOAD_FUNCTIONS(Enum): # noqa: N801 The type of the enum. start : int The start index for auto-numbering enum values. + boundary : enum.FlagBoundary or None + Boundary handling mode used by the Enum functional API for Flag and + IntFlag enums. Examples -------- diff --git a/src/pruna/engine/save.py b/src/pruna/engine/save.py index 7173d355..0c741017 100644 --- a/src/pruna/engine/save.py +++ b/src/pruna/engine/save.py @@ -528,6 +528,9 @@ class SAVE_FUNCTIONS(Enum): # noqa: N801 The type of the enum. start : int The start index for auto-numbering enum values. + boundary : enum.FlagBoundary or None + Boundary handling mode used by the Enum functional API for Flag and + IntFlag enums. Examples -------- diff --git a/src/pruna/evaluation/metrics/metric_torch.py b/src/pruna/evaluation/metrics/metric_torch.py index e1dfb0e0..9b402872 100644 --- a/src/pruna/evaluation/metrics/metric_torch.py +++ b/src/pruna/evaluation/metrics/metric_torch.py @@ -168,6 +168,9 @@ class TorchMetrics(Enum): The type of the enum value. start : int The starting value for the enum. + boundary : enum.FlagBoundary or None + Boundary handling mode used by the Enum functional API for Flag and + IntFlag enums. """ fid = (partial(FrechetInceptionDistance), fid_update, "gt_y")