diff --git a/docs/tutorials/index.rst b/docs/tutorials/index.rst index 238309c6..d4a364f1 100644 --- a/docs/tutorials/index.rst +++ b/docs/tutorials/index.rst @@ -70,7 +70,6 @@ These tutorials will guide you through the process of using |pruna| to optimize Optimize your ``diffusion`` model with ``deepcache`` ``caching``. - .. toctree:: :hidden: :maxdepth: 1 diff --git a/src/pruna/algorithms/batching/ws2t.py b/src/pruna/algorithms/batching/ws2t.py index ba7fe6d7..0b654acc 100644 --- a/src/pruna/algorithms/batching/ws2t.py +++ b/src/pruna/algorithms/batching/ws2t.py @@ -27,8 +27,8 @@ from pruna.algorithms.batching import PrunaBatcher from pruna.algorithms.compilation.c_translate import WhisperWrapper +from pruna.config.hyperparameters import Boolean from pruna.config.smash_config import SmashConfigPrefixWrapper -from pruna.config.smash_space import Boolean from pruna.engine.model_checks import is_speech_seq2seq_model, is_transformers_pipeline_with_speech_recognition from pruna.logging.filter import SuppressOutput from pruna.logging.logger import pruna_logger diff --git a/src/pruna/algorithms/compilation/torch_compile.py b/src/pruna/algorithms/compilation/torch_compile.py index 46afd070..4f68e98c 100644 --- a/src/pruna/algorithms/compilation/torch_compile.py +++ b/src/pruna/algorithms/compilation/torch_compile.py @@ -21,8 +21,8 @@ from pruna.algorithms.compilation import PrunaCompiler from pruna.algorithms.compilation.utils import CausalLMGenerator, JanusGenerator +from pruna.config.hyperparameters import Boolean from pruna.config.smash_config import SmashConfig, SmashConfigPrefixWrapper -from pruna.config.smash_space import Boolean from pruna.engine.model_checks import ( get_diffusers_transformer_models, get_diffusers_unet_models, diff --git a/src/pruna/algorithms/pruna_base.py b/src/pruna/algorithms/pruna_base.py index b7add300..42d84000 100644 --- a/src/pruna/algorithms/pruna_base.py +++ b/src/pruna/algorithms/pruna_base.py @@ -179,6 +179,28 @@ def 
get_hyperparameters(self) -> list: """ return [] + def get_model_dependent_hyperparameter_defaults( + self, model: Any, smash_config: SmashConfig | SmashConfigPrefixWrapper + ) -> Any: + """ + Get default values for unconstrained hyperparameters based on the model and configuration. + + Subclasses can override this method to provide default values for their unconstrained hyperparameters. + + Parameters + ---------- + model : Any + The model to get the default hyperparameters from. + smash_config : SmashConfig + The SmashConfig object. + + Returns + ------- + Any + The default unconstrained hyperparameters values for the algorithm. + """ + return None + def pre_smash_hook(self, model: Any, smash_config: SmashConfig) -> None: """ Perform any necessary actions before the smashing process begins. diff --git a/src/pruna/algorithms/pruning/torch_structured.py b/src/pruna/algorithms/pruning/torch_structured.py index f23f1f72..ba55eba1 100644 --- a/src/pruna/algorithms/pruning/torch_structured.py +++ b/src/pruna/algorithms/pruning/torch_structured.py @@ -24,8 +24,8 @@ from transformers.models.opt.modeling_opt import OPTForCausalLM as Opt from pruna.algorithms.pruning import PrunaPruner +from pruna.config.hyperparameters import Boolean from pruna.config.smash_config import SmashConfigPrefixWrapper -from pruna.config.smash_space import Boolean from pruna.engine.save import SAVE_FUNCTIONS from pruna.logging.logger import pruna_logger diff --git a/src/pruna/algorithms/quantization/gptq_model.py b/src/pruna/algorithms/quantization/gptq_model.py index ed0d6a09..d33b14f9 100644 --- a/src/pruna/algorithms/quantization/gptq_model.py +++ b/src/pruna/algorithms/quantization/gptq_model.py @@ -18,8 +18,8 @@ from ConfigSpace import OrdinalHyperparameter from pruna.algorithms.quantization import PrunaQuantizer +from pruna.config.hyperparameters import Boolean from pruna.config.smash_config import SmashConfigPrefixWrapper -from pruna.config.smash_space import Boolean from pruna.data.utils 
import recover_text_from_dataloader from pruna.engine.model_checks import is_causal_lm, is_transformers_pipeline_with_causal_lm from pruna.engine.utils import safe_memory_cleanup diff --git a/src/pruna/algorithms/quantization/hqq.py b/src/pruna/algorithms/quantization/hqq.py index 61bacd54..e488b7ca 100644 --- a/src/pruna/algorithms/quantization/hqq.py +++ b/src/pruna/algorithms/quantization/hqq.py @@ -21,8 +21,8 @@ from transformers import AutoModelForCausalLM from pruna.algorithms.quantization import PrunaQuantizer +from pruna.config.hyperparameters import Boolean from pruna.config.smash_config import SmashConfigPrefixWrapper -from pruna.config.smash_space import Boolean from pruna.engine.model_checks import is_causal_lm, is_janus_llamagen_ar, is_transformers_pipeline_with_causal_lm from pruna.engine.save import SAVE_FUNCTIONS from pruna.engine.utils import ModelContext, move_to_device, safe_memory_cleanup diff --git a/src/pruna/algorithms/quantization/huggingface_diffusers_int8.py b/src/pruna/algorithms/quantization/huggingface_diffusers_int8.py index e107802b..2ce7b417 100644 --- a/src/pruna/algorithms/quantization/huggingface_diffusers_int8.py +++ b/src/pruna/algorithms/quantization/huggingface_diffusers_int8.py @@ -20,8 +20,8 @@ from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig from pruna.algorithms.quantization import PrunaQuantizer +from pruna.config.hyperparameters import Boolean from pruna.config.smash_config import SmashConfigPrefixWrapper -from pruna.config.smash_space import Boolean from pruna.engine.model_checks import ( get_diffusers_transformer_models, get_diffusers_unet_models, diff --git a/src/pruna/algorithms/quantization/huggingface_llm_int8.py b/src/pruna/algorithms/quantization/huggingface_llm_int8.py index ffb7c60b..77ee69f4 100644 --- a/src/pruna/algorithms/quantization/huggingface_llm_int8.py +++ b/src/pruna/algorithms/quantization/huggingface_llm_int8.py @@ -20,8 +20,8 @@ from transformers import AutoModelForCausalLM, 
BitsAndBytesConfig from pruna.algorithms.quantization import PrunaQuantizer +from pruna.config.hyperparameters import Boolean from pruna.config.smash_config import SmashConfigPrefixWrapper -from pruna.config.smash_space import Boolean from pruna.engine.model_checks import is_causal_lm, is_transformers_pipeline_with_causal_lm from pruna.engine.utils import get_device_map, move_to_device diff --git a/src/pruna/algorithms/quantization/quanto.py b/src/pruna/algorithms/quantization/quanto.py index 19a7b75a..89d9dc0f 100644 --- a/src/pruna/algorithms/quantization/quanto.py +++ b/src/pruna/algorithms/quantization/quanto.py @@ -12,16 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from typing import Any, Dict import torch from ConfigSpace import Constant, OrdinalHyperparameter +from pruna import SmashConfig from pruna.algorithms.quantization import PrunaQuantizer +from pruna.config.hyperparameters import Boolean from pruna.config.smash_config import SmashConfigPrefixWrapper -from pruna.config.smash_space import Boolean +from pruna.config.target_modules import TARGET_MODULES_TYPE, TargetModules, map_targeted_nn_roots from pruna.data.utils import wrap_batch_for_model_call from pruna.engine.save import SAVE_FUNCTIONS +from pruna.engine.utils import get_nn_modules from pruna.logging.logger import pruna_logger @@ -63,6 +68,11 @@ def get_hyperparameters(self) -> list: Constant("act_bits", value=None), Boolean("calibrate", default=True, meta=dict(desc="Whether to calibrate the model.")), Constant(name="calibration_samples", value=64), + TargetModules( + name="target_modules", + default_value=None, + meta=dict(desc="Precise choices of which modules to quantize."), + ), ] def model_check_fn(self, model: Any) -> bool: @@ -85,6 +95,33 @@ def model_check_fn(self, model: Any) -> bool: return True return hasattr(model, "transformer") and isinstance(model.transformer, torch.nn.Module) + 
def get_model_dependent_hyperparameter_defaults( + self, model: Any, smash_config: SmashConfig | SmashConfigPrefixWrapper + ) -> TARGET_MODULES_TYPE: + """ + Get default values for the target_modules based on the model and configuration. + + Parameters + ---------- + model : Any + The model to get the default hyperparameters from. + smash_config : SmashConfig + The SmashConfig object. + + Returns + ------- + TARGET_MODULES_TYPE + The default target_modules for the algorithm. + """ + include: list[str] + if hasattr(model, "unet"): + include = ["unet*"] + elif hasattr(model, "transformer"): + include = ["transformer*"] + else: + include = ["*"] + return {"include": include, "exclude": []} + def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: """ Quantize the model with QUANTO. @@ -102,12 +139,9 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: The quantized model. """ imported_modules = self.import_algorithm_packages() - if hasattr(model, "unet"): - working_model = model.unet - elif hasattr(model, "transformer"): - working_model = model.transformer - else: - working_model = model + target_modules = smash_config["target_modules"] + if target_modules is None: + target_modules = self.get_model_dependent_hyperparameter_defaults(model, smash_config) weights = getattr(imported_modules["quanto"], smash_config["weight_bits"]) activations = ( @@ -116,18 +150,39 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: else None ) - try: - imported_modules["quantize"](working_model, weights=weights, activations=activations) - except Exception as e: - pruna_logger.error("Error during quantization: %s", e) - raise + def quantize_nn(attr_name: str | None, module: torch.nn.Module, subpaths: list[str]) -> Any: + """ + Apply Quanto quantization to a nn.Module. + + Parameters + ---------- + attr_name : str + The name of the attribute in the model pointing to the nn.Module to quantize. 
+ module : torch.nn.Module + The nn.Module to quantize. + subpaths : list[str] + The subpaths of the module to quantize. + """ + try: + imported_modules["quantize"]( + module, + weights=weights, + activations=activations, + include=subpaths, + ) + except Exception as e: + pruna_logger.error("Error during quantization: %s", e) + raise + return module + + model = map_targeted_nn_roots(quantize_nn, model, target_modules) if smash_config["calibrate"]: if smash_config.tokenizer is not None and smash_config.data is not None: try: with imported_modules["Calibration"](streamline=True, debug=False): calibrate( - working_model, + model, smash_config.val_dataloader(), model.device, # only e.g. CUDA here is not enough, we need also the correct device index batch_size=smash_config.batch_size, @@ -139,11 +194,14 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: else: pruna_logger.error("Calibration requires a tokenizer and dataloader. Skipping calibration.") - try: - imported_modules["freeze"](working_model) - except Exception as e: - pruna_logger.error("Error while freezing the model: %s", e) - raise + for module in get_nn_modules(model).values(): + try: + # optimum.quanto.freeze checks whether the module has been quantized by quanto + # so we can call it on all nn.Module without filtering + imported_modules["freeze"](module) + except Exception as e: + pruna_logger.error("Error while freezing the module: %s", e) + raise return model def import_algorithm_packages(self) -> Dict[str, Any]: diff --git a/src/pruna/config/hyperparameters.py b/src/pruna/config/hyperparameters.py new file mode 100644 index 00000000..d42ea506 --- /dev/null +++ b/src/pruna/config/hyperparameters.py @@ -0,0 +1,92 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
class Boolean(CategoricalHyperparameter):
    """
    Represents a boolean hyperparameter with choices True and False.

    Note that ``Boolean(...)`` does not actually create a ``Boolean`` instance:
    ``__new__`` returns a plain ``CategoricalHyperparameter`` restricted to the
    choices ``[True, False]``, so this class acts as a factory. Because the
    returned object is not an instance of ``Boolean``, Python never invokes
    ``Boolean.__init__`` — the previous ``__init__`` (which also carried a
    mutable ``meta=dict()`` default argument) was dead code and has been removed.

    Parameters
    ----------
    name : str
        The name of the hyperparameter.
    default : bool
        The default value of the hyperparameter.
    meta : Any
        The metadata for the hyperparameter.
    """

    def __new__(cls, name: str, default: bool = False, meta: Any = None) -> CategoricalHyperparameter:  # type: ignore
        """Create a categorical hyperparameter with the choices True and False."""
        return CategoricalHyperparameter(name, choices=[True, False], default_value=default, meta=meta)
class UnconstrainedHyperparameter(Constant):
    """
    A hyperparameter with no constraints: the user may assign it any value.

    Parameters
    ----------
    name : str
        The name of the hyperparameter.
    default_value : Any
        The default value of the hyperparameter.
    meta : Any
        The metadata for the hyperparameter.
    """

    def __init__(
        self,
        name: str,
        default_value: Any = None,
        meta: Any = None,
    ) -> None:
        super().__init__(name, default_value, meta)

    @override
    def legal_value(self, value):  # numpydoc ignore=GL08
        """
        Check if a value is legal for this hyperparameter.

        Since this hyperparameter is unconstrained, any value representable by
        ConfigSpace is accepted: the Constant's internals are first rewritten so
        that *value* becomes its single allowed value, after which the parent
        implementation necessarily reports it as legal.

        Parameters
        ----------
        value : Any
            The value to check.

        Returns
        -------
        bool or numpy.ndarray
            `True` if the value is legal, `False` otherwise. If `value` is an array,
            a boolean mask of legal values is returned.
        """
        # Rewire the Constant's internal state so `value` is its one accepted value.
        is_sequence = isinstance(value, (list, tuple))
        self._contains_sequence_as_value = is_sequence
        self._transformer.value = value
        # Still delegate to the parent (which should return True) so that its
        # internal bookkeeping is updated consistently.
        return super().legal_value(value)
class TargetModules(UnconstrainedHyperparameter):
    """
    Represents a target modules hyperparameter, used to select modules based on include and exclude patterns.

    Parameters
    ----------
    name : str
        The name of the hyperparameter.
    default_value : Optional[TARGET_MODULES_TYPE]
        The default value of the hyperparameter.
    meta : Any
        Meta data describing the hyperparameter.
    """

    def __init__(self, name: str, default_value: Optional[TARGET_MODULES_TYPE] = None, meta: Any = None) -> None:
        super().__init__(name, default_value, meta=meta)

    @override
    def legal_value(self, value: TARGET_MODULES_TYPE | None):  # type: ignore[override]  # numpydoc ignore=GL08
        """
        Check if a value is a valid target modules of type TARGET_MODULES_TYPE.

        Parameters
        ----------
        value : Any
            The value to check.

        Returns
        -------
        bool or numpy.ndarray
            `True` if the value is of type TARGET_MODULES_TYPE, `False` otherwise.

        Raises
        ------
        TypeError
            If the value is not a dict, or if any of its values is not a list of strings.
        ValueError
            If the dict uses keys other than 'include' and 'exclude'.
        """
        # ensure the value is a TARGET_MODULES_TYPE to make errors more explicit for the user
        if value is None:
            pass
        elif not isinstance(value, dict):
            raise TypeError(f"Target modules must be a dictionary with keys 'include' and/or 'exclude'. Got: {value}")
        elif any(key not in ["include", "exclude"] for key in value):
            raise ValueError(f"Target modules must only use keys 'include' and/or 'exclude'. Got: {list(value.keys())}")
        elif any(not isinstance(patterns, list) for patterns in value.values()):
            raise TypeError(
                f"Target modules must be a dictionary with lists of fnmatch patterns as values. Got: {value}"
            )
        else:
            include_patterns = value.get("include", [])
            exclude_patterns = value.get("exclude", [])
            all_patterns = include_patterns + exclude_patterns
            unrecognized_patterns = [pattern for pattern in all_patterns if not isinstance(pattern, str)]
            if unrecognized_patterns:
                raise TypeError(
                    "Target modules must be a dictionary with lists of "
                    "Unix shell-style wildcards (fnmatch-style) patterns as values. "
                    f"Could not recognize the following as fnmatch patterns: {unrecognized_patterns}."
                )

        # handle default values: modify the dict in place to have a match between the
        # value and default value. Both keys are filled independently so that e.g. an
        # empty dict normalizes to {"include": ["*"], "exclude": []} — the previous
        # if/elif chain left "exclude" unset whenever both keys were missing.
        if value is not None:
            # choosing a default value when value is None is left to the algorithm, based on the model
            value.setdefault("include", ["*"])
            value.setdefault("exclude", [])

        return super().legal_value(value)
def expand_list_of_targeted_paths(target_modules: TARGET_MODULES_TYPE, model: Any) -> List[str]:
    """
    Convert the target modules to a list of module paths.

    Parameters
    ----------
    target_modules : TARGET_MODULES_TYPE
        The target modules to convert to a list of module paths.
    model : Any
        The model to get the module paths from.

    Returns
    -------
    List[str]
        The list of module paths.

    Raises
    ------
    ValueError
        If no targeted subpath is found within the model.
    """
    include_patterns = target_modules.get("include", ["*"])
    exclude_patterns = target_modules.get("exclude", [])

    def is_targeted(full_path: str) -> bool:
        # A path is targeted when it matches at least one include pattern
        # and none of the exclude patterns.
        included = any(fnmatch.fnmatch(full_path, pattern) for pattern in include_patterns)
        excluded = any(fnmatch.fnmatch(full_path, pattern) for pattern in exclude_patterns)
        return included and not excluded

    targeted_paths: List[str] = []
    for root_name, root_module in get_nn_modules(model).items():
        for subpath, _ in root_module.named_modules():
            if not root_name:
                full_path = subpath
            elif subpath:
                full_path = f"{root_name}.{subpath}"
            else:
                # named_modules() yields the root module itself under the empty subpath
                full_path = root_name
            if is_targeted(full_path):
                targeted_paths.append(full_path)

    if not targeted_paths:
        raise ValueError(f"No targeted subpath found within the model from target_modules {target_modules}")
    return targeted_paths
+ """ + target_modules_paths = expand_list_of_targeted_paths(target_modules, model) + + modules_with_subpaths: Dict[str | None, Tuple[torch.nn.Module, List[str]]] = {} + for root_name, module in get_nn_modules(model).items(): + prefix = f"{root_name}." if root_name else "" + + targeted_submodules = [path for path in target_modules_paths if path.startswith(prefix)] + targeted_submodules = [path.removeprefix(prefix) for path in targeted_submodules] + + # only register the module if it contains at least one targeted submodule + if targeted_submodules: + modules_with_subpaths[root_name] = (module, targeted_submodules) + + return modules_with_subpaths + + +def map_targeted_nn_roots( + apply_single_root_fn: Callable[[str | None, torch.nn.Module, List[str]], Any], + model: Any, + target_modules: TARGET_MODULES_TYPE, +) -> Any: + """ + Apply a function to the model, or to each of its targeted nn.Modules in the case of a Pipeline. + + Parameters + ---------- + apply_single_root_fn : Callable[[str | None, torch.nn.Module, List[str]], Any] + The function to apply to each root in the model. + It must take as input the attribute name of the root in the model, the nn.Module itself, and a list of + paths within the root, each pointing to a targeted submodule. It must return the modified root. + The roots are the model itself if it is a torch.nn.Module (attribute name is None in this case), + or its nn.Module attributes otherwise. + model : Any + The model to apply the function to. + target_modules : TARGET_MODULES_TYPE + The target modules to apply the function to. + + Returns + ------- + Any + The model after the function has been applied. 
+ """ + nn_roots_with_subpaths = expand_dict_of_roots_and_subpaths(target_modules, model) + for attr_name, (nn_root, subpaths) in nn_roots_with_subpaths.items(): + # modify the root with the provided function + applied_root = apply_single_root_fn(attr_name, nn_root, subpaths) + if applied_root is None: + raise ValueError("The 'apply_single_root_fn' function must return the modified root.") + + # replace the root with the modified one + if attr_name is None: + # by convention, this means the model itself is a torch.nn.Module, which we got as module + model = applied_root + else: + setattr(model, attr_name, applied_root) + return model diff --git a/tests/config/test_target_modules.py b/tests/config/test_target_modules.py new file mode 100644 index 00000000..955718e9 --- /dev/null +++ b/tests/config/test_target_modules.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +import pytest +import re + +from pruna import SmashConfig, smash +from pruna.config.target_modules import TARGET_MODULES_TYPE + + +@pytest.mark.cuda +@pytest.mark.parametrize( + "model_fixture, algorithm_group, algorithm, target_modules, expected_number_of_targeted_modules", + [ + ("flux_tiny_random", "quantizer", "quanto", None, 28), + ("flux_tiny_random", "quantizer", "quanto", {"include": ["transformer*"]}, 28), + ("flux_tiny_random", "quantizer", "quanto", {"include": ["transformer*"], "exclude": ["*norm*"]}, 24), + ], + indirect=["model_fixture"], +) +def test_target_modules( + model_fixture: tuple, algorithm_group: str, algorithm: str, target_modules: TARGET_MODULES_TYPE | None, expected_number_of_targeted_modules: int +) -> None: + model, smash_config = model_fixture + smash_config[algorithm_group] = algorithm + smash_config[f"{algorithm}_target_modules"] = target_modules + smashed_model = smash(model, smash_config) + + num_targeted_modules = sum( + 1 for module in smashed_model.get_nn_modules().values() + for submodule in module.modules() + if submodule.__class__.__name__ == "QLinear" + 
@pytest.mark.cpu
@pytest.mark.parametrize("target_modules", [
    {"include": ["test_pattern*", "other"]},
    {"include": ["this*"], "exclude": ["that"]},
])
def test_target_modules_format_accept(target_modules: dict[str, list[str]]):
    """Well-formed include/exclude dicts are accepted and stored unchanged."""
    config = SmashConfig()
    config["quantizer"] = "quanto"
    config["quanto_target_modules"] = target_modules
    stored = config["quanto_target_modules"]
    assert stored == target_modules


@pytest.mark.cpu
@pytest.mark.parametrize("target_modules, expected_error", [
    (["transformer*"], TypeError),  # not a dict
    ({"what_are_the_keywords": ["transformer"]}, ValueError),  # keys should be "include" or "exclude"
    ({"include": ["transformer*"], "exclude": 1}, TypeError),  # "exclude" value is not a list
    ({"include": ["transformer*"], "exclude": [1, "transformer*"]}, TypeError),  # lists can't contain anything but strings
])
def test_target_modules_format_reject(target_modules: TARGET_MODULES_TYPE, expected_error: type):
    """Malformed target_modules values raise an explicit, typed error on assignment."""
    config = SmashConfig()
    config["quantizer"] = "quanto"
    with pytest.raises(expected_error):
        config["quanto_target_modules"] = target_modules