Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 111 additions & 45 deletions src/pruna/algorithms/quantization/huggingface_diffusers_int8.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,32 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import tempfile
from typing import Any, Dict
from typing import Any, Dict, cast

import diffusers
import torch.nn as nn
from ConfigSpace import CategoricalHyperparameter, Constant, OrdinalHyperparameter
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig

from pruna.algorithms.quantization import PrunaQuantizer
from pruna.config.hyperparameters import Boolean
from pruna.config.smash_config import SmashConfigPrefixWrapper
from pruna.config.smash_config import SmashConfig, SmashConfigPrefixWrapper
from pruna.config.target_modules import (
TARGET_MODULES_TYPE,
TargetModules,
get_skipped_submodules,
is_leaf_module,
map_targeted_nn_roots,
)
from pruna.engine.model_checks import (
get_diffusers_transformer_models,
get_diffusers_unet_models,
)
from pruna.engine.utils import determine_dtype, get_device_map, move_to_device
from pruna.logging.logger import pruna_logger


class DiffusersInt8Quantizer(PrunaQuantizer):
Expand Down Expand Up @@ -76,6 +86,15 @@ def get_hyperparameters(self) -> list:
default_value="fp4",
meta=dict(desc="Quantization type to use."),
),
TargetModules(
name="target_modules",
default_value=None,
meta=dict(
desc="Precise choices of which modules to quantize, "
"e.g. {include: ['transformer.*']} to quantize only the transformer in a diffusion pipeline. "
f"See the {TargetModules.documentation_name_with_link} documentation for more details."
),
),
]

def model_check_fn(self, model: Any) -> bool:
Expand All @@ -102,6 +121,33 @@ def model_check_fn(self, model: Any) -> bool:

return hasattr(model, "unet") and isinstance(model.unet, tuple(transformer_and_unet_models))

def get_model_dependent_hyperparameter_defaults(
    self, model: Any, smash_config: SmashConfig | SmashConfigPrefixWrapper
) -> TARGET_MODULES_TYPE:
    """
    Build the default ``target_modules`` value for this model.

    The include/exclude patterns are rooted at the denoiser component of a
    diffusion pipeline: ``transformer.`` when the pipeline exposes a
    transformer attribute, ``unet.`` when it exposes a unet, and the model
    root otherwise. The ``lm_head`` submodule is excluded by default.

    Parameters
    ----------
    model : Any
        The model for which the default target modules are derived.
    smash_config : SmashConfig | SmashConfigPrefixWrapper
        The configuration object (not consulted here).

    Returns
    -------
    TARGET_MODULES_TYPE
        A dict of ``include`` and ``exclude`` glob patterns.
    """
    if hasattr(model, "transformer"):
        root = "transformer."
    elif hasattr(model, "unet"):
        root = "unet."
    else:
        root = ""
    return {"include": [f"{root}*"], "exclude": [f"{root}lm_head"]}

def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
"""
Quantize the model.
Expand All @@ -118,51 +164,71 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
Any
The quantized model.
"""
with tempfile.TemporaryDirectory(prefix=str(smash_config["cache_dir"])) as temp_dir:
# save the latent model (to be quantized) in a temp directory
if hasattr(model, "transformer"):
working_model = model.transformer
device_map = get_device_map(model, subset_key="transformer")

elif hasattr(model, "unet"):
working_model = model.unet
device_map = get_device_map(model, subset_key="unet")
else:
working_model = model
device_map = get_device_map(model)

move_to_device(working_model, "cpu")
working_model.save_pretrained(temp_dir)
latent_class = getattr(diffusers, type(working_model).__name__)
compute_dtype = determine_dtype(working_model)

bnb_config = DiffusersBitsAndBytesConfig(
load_in_8bit=smash_config["weight_bits"] == 8,
load_in_4bit=smash_config["weight_bits"] == 4,
llm_int8_threshold=float(smash_config["threshold"]),
llm_int8_skip_modules=["lm_head"],
llm_int8_enable_fp32_cpu_offload=smash_config["enable_fp32_cpu_offload"],
llm_int8_has_fp16_weight=smash_config["has_fp16_weight"],
bnb_4bit_compute_dtype=compute_dtype,
bnb_4bit_quant_type=smash_config["quant_type"],
bnb_4bit_use_double_quant=smash_config["double_quant"],
target_modules = smash_config["target_modules"]
if target_modules is None:
target_modules = self.get_model_dependent_hyperparameter_defaults(model, smash_config)
target_modules = cast(TARGET_MODULES_TYPE, target_modules)

def quantize_working_model(attr_name: str | None, working_model: nn.Module, subpaths: list[str]) -> Any:
"""
Quantize a working model with bitsandbytes.

Parameters
----------
attr_name : str | None
The name of the attribute in the model pointing to the working model to quantize.
working_model : torch.nn.Module
The working model to quantize, i.e. a nn.Module component of the model.
subpaths : list[str]
The subpaths of the working model to quantize.
"""
if not hasattr(working_model, "save_pretrained") or not callable(working_model.save_pretrained):
raise ValueError(
"diffusers-int8 was applied to a module which didn't have a callable save_pretrained method."
)

# only include leaf modules since the bnb quantizer skips all submodules
# within a skipped module. Only Linear and Conv1d layers can be quantized anyway.
skipped_modules = get_skipped_submodules(working_model, subpaths, filter_fn=is_leaf_module)
pruna_logger.debug(
f"Skipping {self.algorithm_name} quantization for the following "
f"leaf modules within {attr_name or 'the model'} : {skipped_modules}"
)

Comment thread
gsprochette marked this conversation as resolved.
# re-load the latent model (with the quantization config)
smashed_latent = latent_class.from_pretrained(
temp_dir,
quantization_config=bnb_config,
torch_dtype=compute_dtype,
device_map=device_map,
)
# replace the latent model in the pipeline
if hasattr(model, "transformer"):
model.transformer = smashed_latent
elif hasattr(model, "unet"):
model.unet = smashed_latent
else:
model = smashed_latent
return model
with tempfile.TemporaryDirectory(prefix=str(smash_config["cache_dir"])) as temp_dir:
Comment thread
gsprochette marked this conversation as resolved.
# Only the full model contains the device map, so we get it using the attribute name
# attr_name can be None, then get_device_map defaults to the whole model, which is the expected behavior
device_map = get_device_map(model, subset_key=attr_name)

# save the latent model (to be quantized) in a temp directory
move_to_device(working_model, "cpu")
working_model.save_pretrained(temp_dir)
working_class = getattr(diffusers, type(working_model).__name__)
compute_dtype = determine_dtype(working_model)

bnb_config = DiffusersBitsAndBytesConfig(
load_in_8bit=smash_config["weight_bits"] == 8,
load_in_4bit=smash_config["weight_bits"] == 4,
llm_int8_threshold=float(smash_config["threshold"]),
llm_int8_skip_modules=skipped_modules,
llm_int8_enable_fp32_cpu_offload=smash_config["enable_fp32_cpu_offload"],
llm_int8_has_fp16_weight=smash_config["has_fp16_weight"],
bnb_4bit_compute_dtype=compute_dtype,
bnb_4bit_quant_type=smash_config["quant_type"],
bnb_4bit_use_double_quant=smash_config["double_quant"],
)

# re-load the latent model (with the quantization config)
quantized_working_model = working_class.from_pretrained(
temp_dir,
quantization_config=bnb_config,
torch_dtype=compute_dtype,
device_map=device_map,
)
return quantized_working_model

quantized_model = map_targeted_nn_roots(quantize_working_model, model, target_modules)
return quantized_model

def import_algorithm_packages(self) -> Dict[str, Any]:
"""
Expand Down
132 changes: 103 additions & 29 deletions src/pruna/algorithms/quantization/huggingface_llm_int8.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,29 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import tempfile
from typing import Any, Dict
from typing import Any, Dict, cast

import torch
from ConfigSpace import CategoricalHyperparameter, Constant, OrdinalHyperparameter
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from transformers.modeling_utils import PreTrainedModel

from pruna.algorithms.quantization import PrunaQuantizer
from pruna.config.hyperparameters import Boolean
from pruna.config.smash_config import SmashConfigPrefixWrapper
from pruna.config.smash_config import SmashConfig, SmashConfigPrefixWrapper
from pruna.config.target_modules import (
TARGET_MODULES_TYPE,
TargetModules,
get_skipped_submodules,
is_leaf_module,
map_targeted_nn_roots,
)
from pruna.engine.model_checks import is_causal_lm, is_transformers_pipeline_with_causal_lm
from pruna.engine.utils import get_device_map, move_to_device
from pruna.logging.logger import pruna_logger


class LLMInt8Quantizer(PrunaQuantizer):
Expand Down Expand Up @@ -70,6 +81,15 @@ def get_hyperparameters(self) -> list:
default_value="fp4",
meta=dict(desc="Quantization type to use."),
),
TargetModules(
name="target_modules",
default_value=None,
meta=dict(
desc="Precise choices of which modules to quantize, "
"e.g. {include: ['transformer.*']} to quantize only the transformer in a diffusion pipeline. "
f"See the {TargetModules.documentation_name_with_link} documentation for more details."
),
),
]

def model_check_fn(self, model: Any) -> bool:
Expand All @@ -88,6 +108,27 @@ def model_check_fn(self, model: Any) -> bool:
"""
return is_causal_lm(model) or is_transformers_pipeline_with_causal_lm(model)

def get_model_dependent_hyperparameter_defaults(
    self, model: Any, smash_config: SmashConfig | SmashConfigPrefixWrapper
) -> TARGET_MODULES_TYPE:
    """
    Build the default ``target_modules`` value for this model.

    When the model is a transformers pipeline wrapping a causal LM, the
    patterns are rooted at its ``model`` attribute; otherwise they apply to
    the model root. The ``lm_head`` submodule is excluded by default.

    Parameters
    ----------
    model : Any
        The model for which the default target modules are derived.
    smash_config : SmashConfig | SmashConfigPrefixWrapper
        The configuration object (not consulted here).

    Returns
    -------
    TARGET_MODULES_TYPE
        A dict of ``include`` and ``exclude`` glob patterns.
    """
    if is_transformers_pipeline_with_causal_lm(model):
        root = "model."
    else:
        root = ""
    return {"include": [f"{root}*"], "exclude": [f"{root}lm_head"]}

def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
"""
Quantize the model.
Expand All @@ -104,35 +145,68 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
Any
The quantized model.
"""
if is_transformers_pipeline_with_causal_lm(model):
return self._apply_to_model_within_transformers_pipeline(model, smash_config)
with tempfile.TemporaryDirectory(prefix=str(smash_config["cache_dir"])) as temp_dir:
# cast original model to CPU to free memory for smashed model
device_map = get_device_map(model)
move_to_device(model, "cpu")
model.save_pretrained(temp_dir)

bnb_config = BitsAndBytesConfig(
load_in_8bit=smash_config["weight_bits"] == 8,
load_in_4bit=smash_config["weight_bits"] == 4,
llm_int8_threshold=float(smash_config["threshold"]),
llm_int8_skip_modules=["lm_head"],
llm_int8_enable_fp32_cpu_offload=smash_config["enable_fp32_cpu_offload"],
llm_int8_has_fp16_weight=smash_config["has_fp16_weight"],
bnb_4bit_compute_dtype=getattr(torch, smash_config["compute_dtype"]),
bnb_4bit_quant_type=smash_config["quant_type"],
bnb_4bit_use_double_quant=smash_config["double_quant"],
)

smashed_model = AutoModelForCausalLM.from_pretrained(
temp_dir,
quantization_config=bnb_config,
trust_remote_code=True,
torch_dtype=smash_config["compute_dtype"], # storage type of the non-int8 params
device_map=device_map,
target_modules = smash_config["target_modules"]
if target_modules is None:
target_modules = self.get_model_dependent_hyperparameter_defaults(model, smash_config)
Comment thread
gsprochette marked this conversation as resolved.
target_modules = cast(TARGET_MODULES_TYPE, target_modules)

def quantize_causal_lm(attr_name: str | None, causal_lm: torch.nn.Module, subpaths: list[str]) -> Any:
"""
Quantize a causal language model with bitsandbytes.

Parameters
----------
attr_name : str | None
The name of the attribute in the model pointing to the causal language model to quantize.
causal_lm : torch.nn.Module
The causal language model to quantize.
subpaths : list[str]
The subpaths of the causal language model to quantize.
"""
# this can only be applied to a causal lm because we use AutoModelForCausalLM to load the model again
if not is_causal_lm(causal_lm):
raise ValueError(
"llm-int8 was applied to a model (or part of a model) which is not a causal language model."
)
causal_lm = cast(PreTrainedModel, causal_lm)

# get the skipped modules, only include leaf modules since the bnb quantizer skips all submodules
# within a skipped module. Only Linear and Conv1d layers can be quantized anyway.
skipped_modules = get_skipped_submodules(causal_lm, subpaths, filter_fn=is_leaf_module)
pruna_logger.debug(
f"Skipping {self.algorithm_name} quantization for the following "
f"leaf modules within {attr_name or 'the model'} : {skipped_modules}"
)

return smashed_model
with tempfile.TemporaryDirectory(prefix=str(smash_config["cache_dir"])) as temp_dir:
Comment thread
gsprochette marked this conversation as resolved.
# cast original model to CPU to free memory for smashed model
device_map = get_device_map(causal_lm)
Comment thread
gsprochette marked this conversation as resolved.
move_to_device(causal_lm, "cpu")
causal_lm.save_pretrained(temp_dir)

bnb_config = BitsAndBytesConfig(
load_in_8bit=smash_config["weight_bits"] == 8,
load_in_4bit=smash_config["weight_bits"] == 4,
llm_int8_threshold=float(smash_config["threshold"]),
llm_int8_skip_modules=skipped_modules,
llm_int8_enable_fp32_cpu_offload=smash_config["enable_fp32_cpu_offload"],
llm_int8_has_fp16_weight=smash_config["has_fp16_weight"],
bnb_4bit_compute_dtype=getattr(torch, smash_config["compute_dtype"]),
bnb_4bit_quant_type=smash_config["quant_type"],
bnb_4bit_use_double_quant=smash_config["double_quant"],
)

quantized_causal_lm = AutoModelForCausalLM.from_pretrained(
temp_dir,
quantization_config=bnb_config,
trust_remote_code=True,
torch_dtype=smash_config["compute_dtype"], # storage type of the non-int8 params
device_map=device_map,
)
return quantized_causal_lm

quantized_model = map_targeted_nn_roots(quantize_causal_lm, model, target_modules)
return quantized_model

def import_algorithm_packages(self) -> Dict[str, Any]:
"""
Expand Down
Loading