Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ repos:
hooks:
- id: ty
name: type checking using ty
entry: uvx ty check .
entry: uvx ty check src/pruna
language: system
types: [python]
pass_filenames: false
Expand Down
6 changes: 4 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,10 @@ invalid-return-type = "ignore" # mypy is more permissive with return types
invalid-parameter-default = "ignore" # mypy is more permissive with parameter defaults
no-matching-overload = "ignore" # mypy is more permissive with overloads
unresolved-reference = "ignore" # mypy is more permissive with references
possibly-unbound-import = "ignore"
possibly-missing-import = "ignore"
possibly-missing-attribute = "ignore"
missing-argument = "ignore"
unused-type-ignore-comment = "ignore"
Comment thread
gsprochette marked this conversation as resolved.

[tool.coverage.run]
source = ["src/pruna"]
Expand Down Expand Up @@ -190,7 +192,7 @@ dev = [
"pytest-rerunfailures",
"coverage",
"docutils",
"ty==0.0.1a21",
"ty==0.0.20",
"types-PyYAML",
"logbar",
"pytest-xdist>=3.8.0",
Expand Down
4 changes: 2 additions & 2 deletions src/pruna/algorithms/c_translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ def __call__(
x_tensor = x["input_ids"]
else:
x_tensor = x
token_list = [self.tokenizer.convert_ids_to_tokens(x_tensor[i]) for i in range(len(x_tensor))]
token_list = [self.tokenizer.convert_ids_to_tokens(x_tensor[i]) for i in range(len(x_tensor))] # type: ignore[not-subscriptable]
return self.generator.generate_batch(token_list, min_length=min_length, max_length=max_length, *args, **kwargs) # type: ignore[operator]


Expand Down Expand Up @@ -468,7 +468,7 @@ def __call__(
x_tensor = x["input_ids"]
else:
x_tensor = x
token_list = [self.tokenizer.convert_ids_to_tokens(x_tensor[i]) for i in range(len(x_tensor))]
token_list = [self.tokenizer.convert_ids_to_tokens(x_tensor[i]) for i in range(len(x_tensor))] # type: ignore[not-subscriptable]
return self.translator.translate_batch( # type: ignore[operator]
token_list,
min_decoding_length=min_decoding_length,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@
from abc import ABC, abstractmethod
from typing import Any

import torch

from pruna.config.smash_config import SmashConfigPrefixWrapper


Expand All @@ -45,16 +43,14 @@ def get_hyperparameters(cls, **override_defaults: Any) -> list:

@classmethod
@abstractmethod
def finetune(
cls, model: torch.nn.Module, smash_config: SmashConfigPrefixWrapper, seed: int, recoverer: str
) -> torch.nn.Module:
def finetune(cls, model: Any, smash_config: SmashConfigPrefixWrapper, seed: int, recoverer: str) -> Any:
"""
Apply the component to the model: activate parameters for Adapters, or finetune them for Finetuners.

Parameters
----------
model : torch.nn.Module
The model to apply the component to.
model : Any
The model or pipeline to apply the component to (e.g. torch.nn.Module or DiffusionPipeline).
smash_config : SmashConfigPrefixWrapper
The configuration for the component.
seed : int
Expand All @@ -64,7 +60,7 @@ def finetune(

Returns
-------
torch.nn.Module
The model with the component applied.
Any
The model or pipeline with the component applied.
"""
pass
Original file line number Diff line number Diff line change
Expand Up @@ -162,13 +162,13 @@ def get_hyperparameters(cls, **override_defaults) -> List:
]

@classmethod
def finetune(cls, pipeline: Any, smash_config: SmashConfigPrefixWrapper, seed: int, recoverer: str) -> Any:
def finetune(cls, model: Any, smash_config: SmashConfigPrefixWrapper, seed: int, recoverer: str) -> Any:
"""
Train the model's previously activated parameters on distillation data extracted from the original model.

Parameters
----------
pipeline : Any
model : Any
The pipeline containing components to finetune.
smash_config : SmashConfigPrefixWrapper
The configuration for the finetuner.
Expand All @@ -187,8 +187,8 @@ def finetune(cls, pipeline: Any, smash_config: SmashConfigPrefixWrapper, seed: i
f"DiffusionDistillation data module is required for distillation, but got {smash_config.data}."
)

dtype = get_dtype(pipeline)
device = get_device(pipeline)
dtype = get_dtype(model)
device = get_device(model)
try:
lora_r = smash_config["lora_r"]
except KeyError:
Expand All @@ -204,7 +204,7 @@ def finetune(cls, pipeline: Any, smash_config: SmashConfigPrefixWrapper, seed: i

# Finetune the model
trainable_distiller = DistillerTL(
pipeline,
model,
smash_config["training_batch_size"],
smash_config["gradient_accumulation_steps"],
smash_config["optimizer"],
Expand Down Expand Up @@ -247,7 +247,7 @@ def finetune(cls, pipeline: Any, smash_config: SmashConfigPrefixWrapper, seed: i
else:
precision = "32"

accelerator = get_device_type(pipeline)
accelerator = get_device_type(model)
if accelerator == "accelerator":
accelerator = "auto"
trainer = pl.Trainer(
Expand All @@ -270,7 +270,7 @@ def finetune(cls, pipeline: Any, smash_config: SmashConfigPrefixWrapper, seed: i

# Loading the best checkpoint is slow and currently creates conflicts with some quantization algorithms,
# e.g. diffusers_int8. Skipping calling DenoiserTL.load_from_checkpoint for now.
return pipeline.to(device)
return model.to(device)


class DistillerTL(pl.LightningModule):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def get_hyperparameters(cls, **override_defaults) -> List:
]

@classmethod
def finetune(cls, pipeline: Any, smash_config: SmashConfigPrefixWrapper, seed: int, recoverer: str) -> Any:
def finetune(cls, model: Any, smash_config: SmashConfigPrefixWrapper, seed: int, recoverer: str) -> Any:
"""
Finetune the model's previously activated parameters on data.

Expand All @@ -149,7 +149,7 @@ def finetune(cls, pipeline: Any, smash_config: SmashConfigPrefixWrapper, seed: i

Parameters
----------
pipeline : Any
model : Any
The pipeline containing components to finetune.
smash_config : SmashConfigPrefixWrapper
The configuration for the finetuner.
Expand All @@ -163,7 +163,7 @@ def finetune(cls, pipeline: Any, smash_config: SmashConfigPrefixWrapper, seed: i
Any
The finetuned pipeline.
"""
dtype = get_dtype(pipeline)
dtype = get_dtype(model)
device = smash_config.device if isinstance(smash_config.device, str) else smash_config.device.type

# split seed into two rng: generator for the dataloader and a seed for the training part
Expand Down Expand Up @@ -202,11 +202,11 @@ def finetune(cls, pipeline: Any, smash_config: SmashConfigPrefixWrapper, seed: i
optimizer_name = "AdamW"

# Check resolution mismatch
utils.check_resolution_mismatch(pipeline, train_dataloader)
utils.check_resolution_mismatch(model, train_dataloader)

# Finetune the model
trainable_denoiser = DenoiserTL(
pipeline,
model,
optimizer_name,
smash_config["learning_rate"],
smash_config["weight_decay"],
Expand Down Expand Up @@ -263,7 +263,7 @@ def finetune(cls, pipeline: Any, smash_config: SmashConfigPrefixWrapper, seed: i

# Loading the best checkpoint is slow and currently creates conflicts with some quantization algorithms,
# e.g. diffusers_int8. Skipping calling DenoiserTL.load_from_checkpoint for now.
return pipeline.to(device)
return model.to(device)


class DenoiserTL(pl.LightningModule):
Expand Down
2 changes: 1 addition & 1 deletion src/pruna/algorithms/huggingface_diffusers_int8.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def quantize_working_model(attr_name: str | None, working_model: nn.Module, subp
subpaths : list[str]
The subpaths of the working model to quantize.
"""
if not hasattr(working_model, "save_pretrained") or not callable(working_model.save_pretrained):
if not hasattr(working_model, "save_pretrained") or not callable(getattr(working_model, "save_pretrained")):
raise ValueError(
"diffusers-int8 was applied to a module which didn't have a callable save_pretrained method."
)
Expand Down
8 changes: 6 additions & 2 deletions src/pruna/algorithms/ring_attn/utils/ring_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,14 @@ class LocalFunc(torch.autograd.Function):
"""

@staticmethod
def forward(cls, *args, **kwargs):
def forward(ctx, *args, **kwargs):
"""
Forward pass for the ring attention implementation.

Parameters
----------
ctx : torch.autograd.function.FunctionCtx
The context of the forward pass.
*args : Any
The arguments to the forward pass.
**kwargs : Any
Expand All @@ -67,12 +69,14 @@ def forward(cls, *args, **kwargs):
return ring_attention._scaled_dot_product_ring_flash_attention(*args, **kwargs)[:2]

@staticmethod
def backward(cls, *args, **kwargs):
def backward(ctx, *args, **kwargs):
"""
Backward pass for ring attention implementation of flash attention.

Parameters
----------
ctx : torch.autograd.function.FunctionCtx
The context of the backward pass.
*args : Any
The arguments to the backward pass.
**kwargs : Any
Expand Down
2 changes: 1 addition & 1 deletion src/pruna/algorithms/torch_dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
else:
modules_to_quantize = {torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.Linear}

quantized_model = torch.quantization.quantize_dynamic(
quantized_model = torch.quantization.quantize_dynamic( # type: ignore[deprecated]
model,
modules_to_quantize,
dtype=getattr(torch, smash_config["weight_bits"]),
Expand Down
2 changes: 1 addition & 1 deletion src/pruna/config/smash_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ def load_from_json(self, path: str | Path) -> None:
setattr(self, name, config_dict.pop(name))

# Keep only values that still exist in the space, drop stale keys
supported_hparam_names = {hp.name for hp in SMASH_SPACE.get_hyperparameters()}
supported_hparam_names = {hp.name for hp in list(SMASH_SPACE.values())}
saved_values = {k: v for k, v in config_dict.items() if k in supported_hparam_names}

# Seed with the defaults, then overlay the saved values
Expand Down
6 changes: 3 additions & 3 deletions src/pruna/data/diffuser_distillation_data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def _prepare_one_sample(self, filename: str, caption: str, subdir_name: str) ->
torch.save(sample, filepath)


class DiffusionDistillationDataset(Dataset):
class DiffusionDistillationDataset(Dataset[Tuple[str, torch.Tensor, torch.Tensor, int]]):
"""
Dataset for distilling a diffusion pipeline, containing captions, latent inputs, latent outputs and seeds.

Expand All @@ -243,9 +243,9 @@ def __len__(self) -> int:
"""Return the number of samples in the dataset."""
return len(self.filenames)

def __getitem__(self, idx: int) -> Tuple[str, torch.Tensor, torch.Tensor, int]:
def __getitem__(self, index: int) -> Tuple[str, torch.Tensor, torch.Tensor, int]:
"""Get an item from the dataset."""
filepath = self.path / self.filenames[idx]
filepath = self.path / self.filenames[index]
# This is the most generic way to load the data, but may cause a bottleneck because of continuous disk access
# Loading the whole dataset into memory is often possible given the typically small size of distillation datasets
# This can be explored if this is identified as causing a latency bottleneck
Expand Down
3 changes: 2 additions & 1 deletion tests/algorithms/testers/ring_distributer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import pytest
from _pytest.outcomes import Skipped

from pruna.algorithms.ring_attn.ring import RingAttn
from pruna.engine.pruna_model import PrunaModel
Expand Down Expand Up @@ -46,7 +47,7 @@ def execute_load(self):

def execute_evaluation(self, model, datamodule, device):
"""Skip evaluation for distributed models as it's not fully supported."""
pytest.skip("Evaluation not supported for distributed ring_attn models")
raise Skipped("Evaluation not supported for distributed ring_attn models")

@classmethod
def execute_save(cls, smashed_model: PrunaModel):
Expand Down
2 changes: 1 addition & 1 deletion tests/algorithms/testers/sage_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@ class TestSageAttn(AlgorithmTesterBase):
reject_models = ["opt_tiny_random"]
allow_pickle_files = False
algorithm_class = SageAttn
metrics = ["latency"]
metrics = ["latency"]
7 changes: 5 additions & 2 deletions tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import numpydoc_validation
import pytest
import torch
from _pytest.outcomes import Skipped
from accelerate.utils import compute_module_sizes, infer_auto_device_map
from docutils.core import publish_doctree
from docutils.nodes import literal_block, section, title
Expand Down Expand Up @@ -118,7 +119,9 @@ def run_full_integration(
try:
model, smash_config = model_fixture[0], model_fixture[1]
if device not in algorithm_tester.compatible_devices():
pytest.skip(f"Algorithm {algorithm_tester.get_algorithm_name()} is not compatible with {device}")
raise Skipped(
f"Algorithm {algorithm_tester.get_algorithm_name()} is not compatible with {device}"
)
algorithm_tester.prepare_smash_config(smash_config, device)
device_map = construct_device_map_manually(model) if device == "accelerate" else None
move_to_device(model, device=smash_config["device"], device_map=device_map)
Expand Down Expand Up @@ -344,4 +347,4 @@ def extract_code_blocks_from_node(node: Any, section_name: str) -> None:
section_title = section_title_node.astext().replace(" ", "_").lower()
extract_code_blocks_from_node(sec, section_title)

print(f"Code blocks extracted and written to {output_dir}")
print(f"Code blocks extracted and written to {output_dir}")