diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8daac431..c05df366 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -47,7 +47,7 @@ repos: hooks: - id: ty name: type checking using ty - entry: uvx ty check . + entry: uvx ty check src/pruna language: system types: [python] pass_filenames: false diff --git a/pyproject.toml b/pyproject.toml index c199603c..d8ae1007 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,8 +33,10 @@ invalid-return-type = "ignore" # mypy is more permissive with return types invalid-parameter-default = "ignore" # mypy is more permissive with parameter defaults no-matching-overload = "ignore" # mypy is more permissive with overloads unresolved-reference = "ignore" # mypy is more permissive with references -possibly-unbound-import = "ignore" +possibly-missing-import = "ignore" +possibly-missing-attribute = "ignore" missing-argument = "ignore" +unused-type-ignore-comment = "ignore" [tool.coverage.run] source = ["src/pruna"] @@ -190,7 +192,7 @@ dev = [ "pytest-rerunfailures", "coverage", "docutils", - "ty==0.0.1a21", + "ty==0.0.20", "types-PyYAML", "logbar", "pytest-xdist>=3.8.0", diff --git a/src/pruna/algorithms/c_translate.py b/src/pruna/algorithms/c_translate.py index 51e876a5..0b9be863 100644 --- a/src/pruna/algorithms/c_translate.py +++ b/src/pruna/algorithms/c_translate.py @@ -393,7 +393,7 @@ def __call__( x_tensor = x["input_ids"] else: x_tensor = x - token_list = [self.tokenizer.convert_ids_to_tokens(x_tensor[i]) for i in range(len(x_tensor))] + token_list = [self.tokenizer.convert_ids_to_tokens(x_tensor[i]) for i in range(len(x_tensor))] # type: ignore[not-subscriptable] return self.generator.generate_batch(token_list, min_length=min_length, max_length=max_length, *args, **kwargs) # type: ignore[operator] @@ -468,7 +468,7 @@ def __call__( x_tensor = x["input_ids"] else: x_tensor = x - token_list = [self.tokenizer.convert_ids_to_tokens(x_tensor[i]) for i in range(len(x_tensor))] + token_list = [self.tokenizer.convert_ids_to_tokens(x_tensor[i]) for i in range(len(x_tensor))] # type: ignore[not-subscriptable] return self.translator.translate_batch( # type: ignore[operator] token_list, min_decoding_length=min_decoding_length, diff --git a/src/pruna/algorithms/global_utils/recovery/finetuners/__init__.py b/src/pruna/algorithms/global_utils/recovery/finetuners/__init__.py index 76028212..b4231678 100644 --- a/src/pruna/algorithms/global_utils/recovery/finetuners/__init__.py +++ b/src/pruna/algorithms/global_utils/recovery/finetuners/__init__.py @@ -17,8 +17,6 @@ from abc import ABC, abstractmethod from typing import Any -import torch - from pruna.config.smash_config import SmashConfigPrefixWrapper @@ -45,16 +43,14 @@ def get_hyperparameters(cls, **override_defaults: Any) -> list: @classmethod @abstractmethod - def finetune( - cls, model: torch.nn.Module, smash_config: SmashConfigPrefixWrapper, seed: int, recoverer: str - ) -> torch.nn.Module: + def finetune(cls, model: Any, smash_config: SmashConfigPrefixWrapper, seed: int, recoverer: str) -> Any: """ Apply the component to the model: activate parameters for Adapters, or finetune them for Finetuners. Parameters ---------- - model : torch.nn.Module - The model to apply the component to. + model : Any + The model or pipeline to apply the component to (e.g. torch.nn.Module or DiffusionPipeline). smash_config : SmashConfigPrefixWrapper The configuration for the component. seed : int @@ -64,7 +60,7 @@ def finetune( Returns ------- - torch.nn.Module - The model with the component applied. + Any + The model or pipeline with the component applied. """ pass diff --git a/src/pruna/algorithms/global_utils/recovery/finetuners/text_to_image_distiller.py b/src/pruna/algorithms/global_utils/recovery/finetuners/text_to_image_distiller.py index 8392fdb3..842b9421 100644 --- a/src/pruna/algorithms/global_utils/recovery/finetuners/text_to_image_distiller.py +++ b/src/pruna/algorithms/global_utils/recovery/finetuners/text_to_image_distiller.py @@ -162,13 +162,13 @@ def get_hyperparameters(cls, **override_defaults) -> List: ] @classmethod - def finetune(cls, pipeline: Any, smash_config: SmashConfigPrefixWrapper, seed: int, recoverer: str) -> Any: + def finetune(cls, model: Any, smash_config: SmashConfigPrefixWrapper, seed: int, recoverer: str) -> Any: """ Train the model previously activated parameters on distillation data extracted from the original model. Parameters ---------- - pipeline : Any + model : Any The pipeline containing components to finetune. smash_config : SmashConfigPrefixWrapper The configuration for the finetuner. @@ -187,8 +187,8 @@ def finetune(cls, pipeline: Any, smash_config: SmashConfigPrefixWrapper, seed: i f"DiffusionDistillation data module is required for distillation, but got {smash_config.data}." ) - dtype = get_dtype(pipeline) - device = get_device(pipeline) + dtype = get_dtype(model) + device = get_device(model) try: lora_r = smash_config["lora_r"] except KeyError: @@ -204,7 +204,7 @@ def finetune(cls, pipeline: Any, smash_config: SmashConfigPrefixWrapper, seed: i # Finetune the model trainable_distiller = DistillerTL( - pipeline, + model, smash_config["training_batch_size"], smash_config["gradient_accumulation_steps"], smash_config["optimizer"], @@ -247,7 +247,7 @@ def finetune(cls, pipeline: Any, smash_config: SmashConfigPrefixWrapper, seed: i else: precision = "32" - accelerator = get_device_type(pipeline) + accelerator = get_device_type(model) if accelerator == "accelerator": accelerator = "auto" trainer = pl.Trainer( @@ -270,7 +270,7 @@ def finetune(cls, pipeline: Any, smash_config: SmashConfigPrefixWrapper, seed: i # Loading the best checkpoint is slow and currently creates conflicts with some quantization algorithms, # e.g. diffusers_int8. Skipping calling DenoiserTL.load_from_checkpoint for now. - return pipeline.to(device) + return model.to(device) class DistillerTL(pl.LightningModule): diff --git a/src/pruna/algorithms/global_utils/recovery/finetuners/text_to_image_finetuner.py b/src/pruna/algorithms/global_utils/recovery/finetuners/text_to_image_finetuner.py index 935468b8..c793c4ae 100644 --- a/src/pruna/algorithms/global_utils/recovery/finetuners/text_to_image_finetuner.py +++ b/src/pruna/algorithms/global_utils/recovery/finetuners/text_to_image_finetuner.py @@ -138,7 +138,7 @@ def get_hyperparameters(cls, **override_defaults) -> List: ] @classmethod - def finetune(cls, pipeline: Any, smash_config: SmashConfigPrefixWrapper, seed: int, recoverer: str) -> Any: + def finetune(cls, model: Any, smash_config: SmashConfigPrefixWrapper, seed: int, recoverer: str) -> Any: """ Finetune the model's previously activated parameters on data. @@ -149,7 +149,7 @@ def finetune(cls, pipeline: Any, smash_config: SmashConfigPrefixWrapper, seed: i Parameters ---------- - pipeline : Any + model : Any The pipeline containing components to finetune. smash_config : SmashConfigPrefixWrapper The configuration for the finetuner. @@ -163,7 +163,7 @@ def finetune(cls, pipeline: Any, smash_config: SmashConfigPrefixWrapper, seed: i Any The finetuned pipeline. """ - dtype = get_dtype(pipeline) + dtype = get_dtype(model) device = smash_config.device if isinstance(smash_config.device, str) else smash_config.device.type # split seed into two rng: generator for the dataloader and a seed for the training part @@ -202,11 +202,11 @@ def finetune(cls, pipeline: Any, smash_config: SmashConfigPrefixWrapper, seed: i optimizer_name = "AdamW" # Check resolution mismatch - utils.check_resolution_mismatch(pipeline, train_dataloader) + utils.check_resolution_mismatch(model, train_dataloader) # Finetune the model trainable_denoiser = DenoiserTL( - pipeline, + model, optimizer_name, smash_config["learning_rate"], smash_config["weight_decay"], @@ -263,7 +263,7 @@ def finetune(cls, pipeline: Any, smash_config: SmashConfigPrefixWrapper, seed: i # Loading the best checkpoint is slow and currently creates conflicts with some quantization algorithms, # e.g. diffusers_int8. Skipping calling DenoiserTL.load_from_checkpoint for now. - return pipeline.to(device) + return model.to(device) class DenoiserTL(pl.LightningModule): diff --git a/src/pruna/algorithms/huggingface_diffusers_int8.py b/src/pruna/algorithms/huggingface_diffusers_int8.py index b766dfca..0a506cdf 100644 --- a/src/pruna/algorithms/huggingface_diffusers_int8.py +++ b/src/pruna/algorithms/huggingface_diffusers_int8.py @@ -180,7 +180,7 @@ def quantize_working_model(attr_name: str | None, working_model: nn.Module, subp subpaths : list[str] The subpaths of the working model to quantize. """ - if not hasattr(working_model, "save_pretrained") or not callable(working_model.save_pretrained): + if not hasattr(working_model, "save_pretrained") or not callable(getattr(working_model, "save_pretrained")): raise ValueError( "diffusers-int8 was applied to a module which didn't have a callable save_pretrained method." ) diff --git a/src/pruna/algorithms/ring_attn/utils/ring_utils.py b/src/pruna/algorithms/ring_attn/utils/ring_utils.py index e07112cb..6c33eaed 100644 --- a/src/pruna/algorithms/ring_attn/utils/ring_utils.py +++ b/src/pruna/algorithms/ring_attn/utils/ring_utils.py @@ -45,12 +45,14 @@ class LocalFunc(torch.autograd.Function): """ @staticmethod - def forward(cls, *args, **kwargs): + def forward(ctx, *args, **kwargs): """ Forward pass for the ring attention implementation. Parameters ---------- + ctx : torch.autograd.Context + The context of the forward pass. *args : Any The arguments to the forward pass. **kwargs : Any @@ -67,12 +69,14 @@ def forward(cls, *args, **kwargs): return ring_attention._scaled_dot_product_ring_flash_attention(*args, **kwargs)[:2] @staticmethod - def backward(cls, *args, **kwargs): + def backward(ctx, *args, **kwargs): """ Backward pass for ring attention implementation of flash attention. Parameters ---------- + ctx : torch.autograd.Context + The context of the backward pass. *args : Any The arguments to the backward pass. **kwargs : Any diff --git a/src/pruna/algorithms/torch_dynamic.py b/src/pruna/algorithms/torch_dynamic.py index e3b1ac3f..ce9d1f76 100644 --- a/src/pruna/algorithms/torch_dynamic.py +++ b/src/pruna/algorithms/torch_dynamic.py @@ -107,7 +107,7 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: else: modules_to_quantize = {torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.Linear} - quantized_model = torch.quantization.quantize_dynamic( + quantized_model = torch.quantization.quantize_dynamic( # type: ignore[deprecated] model, modules_to_quantize, dtype=getattr(torch, smash_config["weight_bits"]), diff --git a/src/pruna/config/smash_config.py b/src/pruna/config/smash_config.py index ac269049..2ba9d4f2 100644 --- a/src/pruna/config/smash_config.py +++ b/src/pruna/config/smash_config.py @@ -280,7 +280,7 @@ def load_from_json(self, path: str | Path) -> None: setattr(self, name, config_dict.pop(name)) # Keep only values that still exist in the space, drop stale keys - supported_hparam_names = {hp.name for hp in SMASH_SPACE.get_hyperparameters()} + supported_hparam_names = {hp.name for hp in list(SMASH_SPACE.values())} saved_values = {k: v for k, v in config_dict.items() if k in supported_hparam_names} # Seed with the defaults, then overlay the saved values diff --git a/src/pruna/data/diffuser_distillation_data_module.py b/src/pruna/data/diffuser_distillation_data_module.py index 1134b58e..51c46da1 100644 --- a/src/pruna/data/diffuser_distillation_data_module.py +++ b/src/pruna/data/diffuser_distillation_data_module.py @@ -220,7 +220,7 @@ def _prepare_one_sample(self, filename: str, caption: str, subdir_name: str) -> torch.save(sample, filepath) -class DiffusionDistillationDataset(Dataset): +class DiffusionDistillationDataset(Dataset[Tuple[str, torch.Tensor, torch.Tensor, int]]): """ Dataset for distilling a diffusion pipeline, containing captions, latent inputs, latent outputs and seeds. @@ -243,9 +243,9 @@ def __len__(self) -> int: """Return the number of samples in the dataset.""" return len(self.filenames) - def __getitem__(self, idx: int) -> Tuple[str, torch.Tensor, torch.Tensor, int]: + def __getitem__(self, index: int) -> Tuple[str, torch.Tensor, torch.Tensor, int]: """Get an item from the dataset.""" - filepath = self.path / self.filenames[idx] + filepath = self.path / self.filenames[index] # This is the most generic way to load the data, but may cause a bottleneck because of continuous disk access # Loading the whole dataset into memory is often possible given the typically small size of distillation datasets # This can be explored if this is identified as a causing a latency bottleneck diff --git a/tests/algorithms/testers/ring_distributer.py b/tests/algorithms/testers/ring_distributer.py index 2934bdb1..3044646b 100644 --- a/tests/algorithms/testers/ring_distributer.py +++ b/tests/algorithms/testers/ring_distributer.py @@ -13,6 +13,7 @@ # limitations under the License. import pytest +from _pytest.outcomes import Skipped from pruna.algorithms.ring_attn.ring import RingAttn from pruna.engine.pruna_model import PrunaModel @@ -46,7 +47,7 @@ def execute_load(self): def execute_evaluation(self, model, datamodule, device): """Skip evaluation for distributed models as it's not fully supported.""" - pytest.skip("Evaluation not supported for distributed ring_attn models") + raise Skipped("Evaluation not supported for distributed ring_attn models") @classmethod def execute_save(cls, smashed_model: PrunaModel): diff --git a/tests/algorithms/testers/sage_attn.py b/tests/algorithms/testers/sage_attn.py index 8fdabd30..eff7d68c 100644 --- a/tests/algorithms/testers/sage_attn.py +++ b/tests/algorithms/testers/sage_attn.py @@ -13,4 +13,4 @@ class TestSageAttn(AlgorithmTesterBase): reject_models = ["opt_tiny_random"] allow_pickle_files = False algorithm_class = SageAttn - metrics = ["latency"] \ No newline at end of file + metrics = ["latency"] diff --git a/tests/common.py b/tests/common.py index 9c6f10a5..f7f7daaf 100644 --- a/tests/common.py +++ b/tests/common.py @@ -9,6 +9,7 @@ import numpydoc_validation import pytest import torch +from _pytest.outcomes import Skipped from accelerate.utils import compute_module_sizes, infer_auto_device_map from docutils.core import publish_doctree from docutils.nodes import literal_block, section, title @@ -118,7 +119,9 @@ def run_full_integration( try: model, smash_config = model_fixture[0], model_fixture[1] if device not in algorithm_tester.compatible_devices(): - pytest.skip(f"Algorithm {algorithm_tester.get_algorithm_name()} is not compatible with {device}") + raise Skipped( + f"Algorithm {algorithm_tester.get_algorithm_name()} is not compatible with {device}" + ) algorithm_tester.prepare_smash_config(smash_config, device) device_map = construct_device_map_manually(model) if device == "accelerate" else None move_to_device(model, device=smash_config["device"], device_map=device_map) @@ -344,4 +347,4 @@ def extract_code_blocks_from_node(node: Any, section_name: str) -> None: section_title = section_title_node.astext().replace(" ", "_").lower() extract_code_blocks_from_node(sec, section_title) - print(f"Code blocks extracted and written to {output_dir}") \ No newline at end of file + print(f"Code blocks extracted and written to {output_dir}")