Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion docs/tutorials/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ These tutorials will guide you through the process of using |pruna| to optimize

Optimize your ``diffusion`` model with ``deepcache`` ``caching``.


.. toctree::
:hidden:
:maxdepth: 1
Expand Down
2 changes: 1 addition & 1 deletion src/pruna/algorithms/batching/ws2t.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@

from pruna.algorithms.batching import PrunaBatcher
from pruna.algorithms.compilation.c_translate import WhisperWrapper
from pruna.config.hyperparameters import Boolean
from pruna.config.smash_config import SmashConfigPrefixWrapper
from pruna.config.smash_space import Boolean
from pruna.engine.model_checks import is_speech_seq2seq_model, is_transformers_pipeline_with_speech_recognition
from pruna.logging.filter import SuppressOutput
from pruna.logging.logger import pruna_logger
Expand Down
2 changes: 1 addition & 1 deletion src/pruna/algorithms/compilation/torch_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@

from pruna.algorithms.compilation import PrunaCompiler
from pruna.algorithms.compilation.utils import CausalLMGenerator, JanusGenerator
from pruna.config.hyperparameters import Boolean
from pruna.config.smash_config import SmashConfig, SmashConfigPrefixWrapper
from pruna.config.smash_space import Boolean
from pruna.engine.model_checks import (
get_diffusers_transformer_models,
get_diffusers_unet_models,
Expand Down
22 changes: 22 additions & 0 deletions src/pruna/algorithms/pruna_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,28 @@ def get_hyperparameters(self) -> list:
"""
return []

def get_model_dependent_hyperparameter_defaults(
    self, model: Any, smash_config: SmashConfig | SmashConfigPrefixWrapper
) -> Any:
    """
    Get default values for unconstrained hyperparameters based on the model and configuration.

    Subclasses can override this method to provide default values for their unconstrained
    hyperparameters; this base implementation declares none and therefore returns ``None``.

    Parameters
    ----------
    model : Any
        The model to get the default hyperparameters from.
    smash_config : SmashConfig | SmashConfigPrefixWrapper
        The (possibly prefix-wrapped) SmashConfig object.

    Returns
    -------
    Any
        The default unconstrained hyperparameter values for the algorithm, or ``None``
        when the algorithm defines no such hyperparameters (the base-class default).
    """
    return None

def pre_smash_hook(self, model: Any, smash_config: SmashConfig) -> None:
"""
Perform any necessary actions before the smashing process begins.
Expand Down
2 changes: 1 addition & 1 deletion src/pruna/algorithms/pruning/torch_structured.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
from transformers.models.opt.modeling_opt import OPTForCausalLM as Opt

from pruna.algorithms.pruning import PrunaPruner
from pruna.config.hyperparameters import Boolean
from pruna.config.smash_config import SmashConfigPrefixWrapper
from pruna.config.smash_space import Boolean
from pruna.engine.save import SAVE_FUNCTIONS
from pruna.logging.logger import pruna_logger

Expand Down
2 changes: 1 addition & 1 deletion src/pruna/algorithms/quantization/gptq_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
from ConfigSpace import OrdinalHyperparameter

from pruna.algorithms.quantization import PrunaQuantizer
from pruna.config.hyperparameters import Boolean
from pruna.config.smash_config import SmashConfigPrefixWrapper
from pruna.config.smash_space import Boolean
from pruna.data.utils import recover_text_from_dataloader
from pruna.engine.model_checks import is_causal_lm, is_transformers_pipeline_with_causal_lm
from pruna.engine.utils import safe_memory_cleanup
Expand Down
2 changes: 1 addition & 1 deletion src/pruna/algorithms/quantization/hqq.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
from transformers import AutoModelForCausalLM

from pruna.algorithms.quantization import PrunaQuantizer
from pruna.config.hyperparameters import Boolean
from pruna.config.smash_config import SmashConfigPrefixWrapper
from pruna.config.smash_space import Boolean
from pruna.engine.model_checks import is_causal_lm, is_janus_llamagen_ar, is_transformers_pipeline_with_causal_lm
from pruna.engine.save import SAVE_FUNCTIONS
from pruna.engine.utils import ModelContext, move_to_device, safe_memory_cleanup
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig

from pruna.algorithms.quantization import PrunaQuantizer
from pruna.config.hyperparameters import Boolean
from pruna.config.smash_config import SmashConfigPrefixWrapper
from pruna.config.smash_space import Boolean
from pruna.engine.model_checks import (
get_diffusers_transformer_models,
get_diffusers_unet_models,
Expand Down
2 changes: 1 addition & 1 deletion src/pruna/algorithms/quantization/huggingface_llm_int8.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

from pruna.algorithms.quantization import PrunaQuantizer
from pruna.config.hyperparameters import Boolean
from pruna.config.smash_config import SmashConfigPrefixWrapper
from pruna.config.smash_space import Boolean
from pruna.engine.model_checks import is_causal_lm, is_transformers_pipeline_with_causal_lm
from pruna.engine.utils import get_device_map, move_to_device

Expand Down
94 changes: 76 additions & 18 deletions src/pruna/algorithms/quantization/quanto.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Any, Dict

import torch
from ConfigSpace import Constant, OrdinalHyperparameter

from pruna import SmashConfig
from pruna.algorithms.quantization import PrunaQuantizer
from pruna.config.hyperparameters import Boolean
from pruna.config.smash_config import SmashConfigPrefixWrapper
from pruna.config.smash_space import Boolean
from pruna.config.target_modules import TARGET_MODULES_TYPE, TargetModules, map_targeted_nn_roots
from pruna.data.utils import wrap_batch_for_model_call
from pruna.engine.save import SAVE_FUNCTIONS
from pruna.engine.utils import get_nn_modules
from pruna.logging.logger import pruna_logger


Expand Down Expand Up @@ -63,6 +68,11 @@ def get_hyperparameters(self) -> list:
Constant("act_bits", value=None),
Boolean("calibrate", default=True, meta=dict(desc="Whether to calibrate the model.")),
Constant(name="calibration_samples", value=64),
TargetModules(
name="target_modules",
default_value=None,
meta=dict(desc="Precise choices of which modules to quantize."),
),
]

def model_check_fn(self, model: Any) -> bool:
Expand All @@ -85,6 +95,33 @@ def model_check_fn(self, model: Any) -> bool:
return True
return hasattr(model, "transformer") and isinstance(model.transformer, torch.nn.Module)

def get_model_dependent_hyperparameter_defaults(
    self, model: Any, smash_config: SmashConfig | SmashConfigPrefixWrapper
) -> TARGET_MODULES_TYPE:
    """
    Determine the default ``target_modules`` for the given model.

    A pipeline exposing a ``unet`` attribute defaults to quantizing the UNet,
    one exposing a ``transformer`` attribute defaults to the transformer, and
    any other model is targeted in its entirety.

    Parameters
    ----------
    model : Any
        The model whose structure decides the default target modules.
    smash_config : SmashConfig | SmashConfigPrefixWrapper
        The SmashConfig object (not consulted by this implementation).

    Returns
    -------
    TARGET_MODULES_TYPE
        A mapping with ``include`` glob patterns and an empty ``exclude`` list.
    """
    # First matching attribute wins: "unet" takes precedence over "transformer".
    pattern = next(
        (f"{attr}*" for attr in ("unet", "transformer") if hasattr(model, attr)),
        "*",
    )
    return {"include": [pattern], "exclude": []}

def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
"""
Quantize the model with QUANTO.
Expand All @@ -102,12 +139,9 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
The quantized model.
"""
imported_modules = self.import_algorithm_packages()
Comment thread
gsprochette marked this conversation as resolved.
Outdated
if hasattr(model, "unet"):
working_model = model.unet
elif hasattr(model, "transformer"):
working_model = model.transformer
else:
working_model = model
target_modules = smash_config["target_modules"]
if target_modules is None:
target_modules = self.get_model_dependent_hyperparameter_defaults(model, smash_config)

weights = getattr(imported_modules["quanto"], smash_config["weight_bits"])
activations = (
Expand All @@ -116,18 +150,39 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
else None
)

try:
imported_modules["quantize"](working_model, weights=weights, activations=activations)
except Exception as e:
pruna_logger.error("Error during quantization: %s", e)
raise
def quantize_nn(attr_name: str | None, module: torch.nn.Module, subpaths: list[str]) -> Any:
"""
Apply Quanto quantization to a nn.Module.

Parameters
----------
attr_name : str
The name of the attribute in the model pointing to the nn.Module to quantize.
module : torch.nn.Module
The nn.Module to quantize.
subpaths : list[str]
The subpaths of the module to quantize.
"""
try:
imported_modules["quantize"](
module,
weights=weights,
activations=activations,
include=subpaths,
)
except Exception as e:
pruna_logger.error("Error during quantization: %s", e)
raise
return module

model = map_targeted_nn_roots(quantize_nn, model, target_modules)

if smash_config["calibrate"]:
if smash_config.tokenizer is not None and smash_config.data is not None:
try:
with imported_modules["Calibration"](streamline=True, debug=False):
calibrate(
working_model,
model,
smash_config.val_dataloader(),
model.device, # only e.g. CUDA here is not enough, we need also the correct device index
batch_size=smash_config.batch_size,
Expand All @@ -139,11 +194,14 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
else:
pruna_logger.error("Calibration requires a tokenizer and dataloader. Skipping calibration.")

try:
imported_modules["freeze"](working_model)
except Exception as e:
pruna_logger.error("Error while freezing the model: %s", e)
raise
for module in get_nn_modules(model).values():
try:
# optimum.quanto.freeze checks whether the module has been quantized by quanto
# so we can call it on all nn.Module without filtering
imported_modules["freeze"](module)
except Exception as e:
pruna_logger.error("Error while freezing the module: %s", e)
raise
return model

def import_algorithm_packages(self) -> Dict[str, Any]:
Expand Down
92 changes: 92 additions & 0 deletions src/pruna/config/hyperparameters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# Copyright 2025 - Pruna AI GmbH. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Any

from ConfigSpace import CategoricalHyperparameter, Constant
from typing_extensions import override


class Boolean(CategoricalHyperparameter):
    """
    A hyperparameter restricted to the two choices ``True`` and ``False``.

    Parameters
    ----------
    name : str
        The name of the hyperparameter.
    default : bool
        The default value of the hyperparameter.
    meta : Any
        The metadata for the hyperparameter.
    """

    def __init__(self, name: str, default: bool = False, meta: Any = dict()) -> None:
        # NOTE: since __new__ below returns a plain CategoricalHyperparameter
        # (not an instance of Boolean), Python never invokes this __init__;
        # it is kept so the constructor signature remains documented.
        super().__init__(name, choices=[True, False], default_value=default, meta=meta)

    def __new__(cls, name: str, default: bool = False, meta: Any = None) -> CategoricalHyperparameter:  # type: ignore
        """Construct the underlying two-choice categorical hyperparameter."""
        boolean_choices = [True, False]
        return CategoricalHyperparameter(name, choices=boolean_choices, default_value=default, meta=meta)


class UnconstrainedHyperparameter(Constant):
    """
    A hyperparameter without constraints: users may assign it any value.

    Parameters
    ----------
    name : str
        The name of the hyperparameter.
    default_value : Any
        The default value of the hyperparameter.
    meta : Any
        The metadata for the hyperparameter.
    """

    def __init__(self, name: str, default_value: Any = None, meta: Any = None) -> None:
        super().__init__(name, default_value, meta)

    @override
    def legal_value(self, value):  # numpydoc ignore=GL08
        """
        Check whether a value is legal for this hyperparameter.

        Because the hyperparameter is unconstrained, every value whose format
        ConfigSpace can represent is accepted, so this effectively always
        returns `True`.

        Parameters
        ----------
        value : Any
            The value to check.

        Returns
        -------
        bool or numpy.ndarray
            `True` if the value is legal, `False` otherwise. If `value` is an array,
            a boolean mask of legal values is returned.
        """
        # Rewrite the Constant's internal state so it treats `value` as its
        # single permitted value before asking the parent class.
        holds_sequence = isinstance(value, (list, tuple))
        self._contains_sequence_as_value = holds_sequence
        self._transformer.value = value
        # Delegate to the parent implementation so its remaining internal
        # bookkeeping still runs; after the update above it should accept `value`.
        return super().legal_value(value)
22 changes: 0 additions & 22 deletions src/pruna/config/smash_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,28 +40,6 @@
ALGORITHM_GROUPS = [FACTORIZER, PRUNER, QUANTIZER, KERNEL, CACHER, COMPILER, BATCHER]


class Boolean(CategoricalHyperparameter):
"""
Represents a boolean hyperparameter with choices True and False.

Parameters
----------
name : str
The name of the hyperparameter.
default : bool
The default value of the hyperparameter.
meta : Any
The metadata for the hyperparameter.
"""

def __init__(self, name: str, default: bool = False, meta: Any = dict()) -> None:
super().__init__(name, choices=[True, False], default_value=default, meta=meta)

def __new__(cls, name: str, default: bool = False, meta: Any = None) -> CategoricalHyperparameter: # type: ignore
"""Create a new boolean hyperparameter."""
return CategoricalHyperparameter(name, choices=[True, False], default_value=default, meta=meta)


class IsTrueCondition(EqualsCondition):
"""
Represents a condition that checks if a hyperparameter is set to True.
Expand Down
Loading