84 changes: 71 additions & 13 deletions src/pruna/algorithms/hqq_diffusers.py
@@ -15,6 +15,7 @@
 from __future__ import annotations
 
 from collections.abc import Iterable
+from contextlib import contextmanager
 from pathlib import Path
 from typing import Any, Dict, Type
 
@@ -202,23 +203,26 @@ def quantize_component(attr_name: str | None, module: torch.nn.Module, subpaths:
     torch.nn.Module
         The quantized component.
     """
-    module.layers = find_module_layers_type(module, nn.Linear)
-    ignored_leaf_modules = get_skipped_submodules(module, subpaths, filter_fn=is_leaf_module)
-    warn_model_specific_errors(module, subpaths)
+    # needs to be computed on original attribute names, so before protecting the layers attribute
+    ignored_leaf_modules = get_skipped_submodules(module, subpaths, filter_fn=is_leaf_module)
+
+    warn_model_specific_errors(module, subpaths)
+    with protect_layers(module, ignored_leaf_modules):
+        module.layers = find_module_layers_type(module, nn.Linear)
 
-    auto_hqq_hf_diffusers_model = construct_base_class(
-        imported_modules, extra_ignore_modules=ignored_leaf_modules
-    )
+        auto_hqq_hf_diffusers_model = construct_base_class(
+            imported_modules, extra_ignore_modules=ignored_leaf_modules
+        )
 
-    compute_dtype = module.dtype
+        compute_dtype = module.dtype
 
-    auto_hqq_hf_diffusers_model.quantize_model(
-        module,
-        quant_config=config,
-        compute_dtype=compute_dtype,
-        device=smash_config["device"],
-    )
+        auto_hqq_hf_diffusers_model.quantize_model(
+            module,
+            quant_config=config,
+            compute_dtype=compute_dtype,
+            device=smash_config["device"],
+        )
 
     # skipped layers are not casted to device and compute dtype so we need to do it manually
     for name, submodule in module.named_modules():
@@ -463,3 +467,57 @@ def find_module_layers_type(model: Any, layer_type: type, exclude_module_names:
         if isinstance(module, layer_type) and name not in exclude_module_names:
            layers.append(module)
     return layers
+
+
+@contextmanager
+def protect_layers(module: torch.nn.Module, path_list: list[str]):
+    """
+    Temporarily rename the 'layers' attribute to '_hqq_original_layers' for the duration of the context.
+
+    Parameters
+    ----------
+    module : torch.nn.Module
+        The module whose 'layers' attribute needs to be safely overwritten.
+    path_list : list[str]
+        A list of paths in the module, possibly referencing the 'layers' attribute, which must be renamed.
+        The list is modified in place and restored when exiting the context manager.
+
+    Yields
+    ------
+    None
+        This context manager does not yield a value and is intended to be
+        used for its side effects only (temporary attribute renaming).
+    """
+    has_layers = hasattr(module, "layers")
+    orig_layers = getattr(module, "layers", None)
+
+    try:
+        if has_layers:
+            # Avoid overwriting if already renamed
+            if not hasattr(module, "_hqq_original_layers"):
+                setattr(module, "_hqq_original_layers", orig_layers)
+            delattr(module, "layers")
+
+        # Replace names in the path list with the protected names
+        for i, path in enumerate(path_list):
+            path_list[i] = _rename_attribute(path, "layers", "_hqq_original_layers")
+        yield
+    finally:
+        if has_layers:
+            # Restore the original layers attribute
+            setattr(module, "layers", getattr(module, "_hqq_original_layers"))
+            delattr(module, "_hqq_original_layers")
+
+        # Restore the original names in the path list
+        for i, path in enumerate(path_list):
+            path_list[i] = _rename_attribute(path, "_hqq_original_layers", "layers")
+
+
+def _rename_attribute(path: str, old: str, new: str) -> str:
+    """Replace the old attribute name with the new one in the path."""
+    if path == old:
+        return new
+    elif path.startswith(f"{old}."):
+        return path.replace(f"{old}.", f"{new}.", 1)
+    else:
+        return path
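
Below is a minimal usage sketch of the new protect_layers context manager and the _rename_attribute helper, illustrating the intended rename-and-restore semantics. The import path pruna.algorithms.hqq_diffusers and the ToyBlock module are assumptions for illustration only; this snippet is not part of the change.

# Usage sketch only; ToyBlock and the import path below are hypothetical.
import torch.nn as nn

from pruna.algorithms.hqq_diffusers import _rename_attribute, protect_layers


class ToyBlock(nn.Module):
    """Hypothetical module whose `layers` attribute would otherwise be clobbered during quantization."""

    def __init__(self) -> None:
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(4, 4), nn.Linear(4, 4)])


block = ToyBlock()
skipped = ["layers.0"]  # paths that reference the original `layers` attribute

with protect_layers(block, skipped):
    # Inside the context the attribute is stashed under `_hqq_original_layers`
    # and the paths in the list are rewritten to keep pointing at the stashed modules.
    assert not hasattr(block, "layers")
    assert skipped == ["_hqq_original_layers.0"]
    # It is now safe to overwrite `layers`, as quantize_component does.
    block.layers = [m for m in block.modules() if isinstance(m, nn.Linear)]

# On exit the original attribute and the original path names are restored.
assert isinstance(block.layers, nn.ModuleList)
assert skipped == ["layers.0"]

# _rename_attribute only rewrites the leading attribute segment of a dotted path.
assert _rename_attribute("layers.3.proj", "layers", "_hqq_original_layers") == "_hqq_original_layers.3.proj"
assert _rename_attribute("mid_block.layers", "layers", "x") == "mid_block.layers"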