From 208be7e243375670c57aa1b61d239768af313d9c Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Thu, 22 May 2025 21:08:15 +0200 Subject: [PATCH 01/11] feat: enhance device compatibility checks across evaluation metrics - Added `check_device_compatibility` utility to ensure proper device assignment in various metrics. - Updated device handling in `Task`, `CMMD`, `InferenceTimeStats`, and other metrics to utilize the new compatibility check. - Improved docstrings to clarify device parameter usage and fallback behavior. - Streamlined device assignment in model evaluation and metric calculations for better robustness. --- src/pruna/engine/utils.py | 10 +++--- src/pruna/evaluation/evaluation_agent.py | 4 +-- src/pruna/evaluation/metrics/metric_cmmd.py | 10 +++--- .../evaluation/metrics/metric_elapsed_time.py | 31 +++++++++++-------- src/pruna/evaluation/metrics/metric_energy.py | 22 ++++++++----- src/pruna/evaluation/metrics/metric_memory.py | 4 +-- .../metrics/metric_model_architecture.py | 20 +++++++----- src/pruna/evaluation/metrics/metric_torch.py | 3 +- src/pruna/evaluation/task.py | 17 ++++++---- tests/engine/test_device.py | 4 +-- 10 files changed, 74 insertions(+), 51 deletions(-) diff --git a/src/pruna/engine/utils.py b/src/pruna/engine/utils.py index 2b145fba..ee02b43d 100644 --- a/src/pruna/engine/utils.py +++ b/src/pruna/engine/utils.py @@ -266,19 +266,19 @@ def check_device_compatibility(device: str | torch.device | None) -> str: if device is None: device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" - pruna_logger.info(f"No device specified. Using best available device: '{device}'") + pruna_logger.info(f"Using best available device: '{device}'") return device if device == "cpu": return "cpu" if device == "cuda" and not torch.cuda.is_available(): - pruna_logger.warning("'cuda' requested but not available. Falling back to 'cpu'") - return "cpu" + pruna_logger.warning("'cuda' requested but not available.") + return check_device_compatibility(device=None) if device == "mps" and not torch.backends.mps.is_available(): - pruna_logger.warning("'mps' requested but not available. Falling back to 'cpu'") - return "cpu" + pruna_logger.warning("'mps' requested but not available.") + return check_device_compatibility(device=None) return device diff --git a/src/pruna/evaluation/evaluation_agent.py b/src/pruna/evaluation/evaluation_agent.py index 462d3fdb..35edf258 100644 --- a/src/pruna/evaluation/evaluation_agent.py +++ b/src/pruna/evaluation/evaluation_agent.py @@ -19,7 +19,7 @@ from pruna.config.utils import is_empty_config from pruna.engine.pruna_model import PrunaModel -from pruna.engine.utils import safe_memory_cleanup +from pruna.engine.utils import check_device_compatibility, safe_memory_cleanup from pruna.evaluation.metrics.metric_stateful import StatefulMetric from pruna.evaluation.metrics.result import MetricResult from pruna.evaluation.metrics.utils import group_metrics_by_inheritance @@ -41,7 +41,7 @@ def __init__(self, task: Task) -> None: self.task = task self.first_model_results: List[MetricResult] = [] self.subsequent_model_results: List[MetricResult] = [] - self.device = self.task.device + self.device = check_device_compatibility(self.task.device) self.cache: List[Tensor] = [] self.evaluation_for_first_model: bool = True diff --git a/src/pruna/evaluation/metrics/metric_cmmd.py b/src/pruna/evaluation/metrics/metric_cmmd.py index 214a019a..3b8e7c3e 100644 --- a/src/pruna/evaluation/metrics/metric_cmmd.py +++ b/src/pruna/evaluation/metrics/metric_cmmd.py @@ -21,6 +21,7 @@ from huggingface_hub.utils import EntryNotFoundError from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection +from pruna.engine.utils import check_device_compatibility from pruna.evaluation.metrics.metric_stateful import StatefulMetric from pruna.evaluation.metrics.registry import MetricRegistry from pruna.evaluation.metrics.result import MetricResult @@ -41,8 +42,9 @@ class CMMD(StatefulMetric): ---------- *args : Any Additional arguments to pass to the StatefulMetric constructor. - device : str | torch.device - The device to run the CLIP model on to calculate the embeddings for the metric. + device : str | torch.device | None, optional + The device to be used for smashing, e.g., 'cuda' or 'cpu'. Default is None. + If None, the best available device will be used. clip_model_name : str The name of the CLIP model to use. call_type : str @@ -60,13 +62,13 @@ class CMMD(StatefulMetric): def __init__( self, *args, - device: str | torch.device = "cuda", + device: str | torch.device | None = None, clip_model_name: str = "openai/clip-vit-large-patch14-336", call_type: str = SINGLE, **kwargs, ) -> None: super().__init__(*args, **kwargs) - self.device = device + self.device = check_device_compatibility(device) try: model_info(clip_model_name) except EntryNotFoundError: diff --git a/src/pruna/evaluation/metrics/metric_elapsed_time.py b/src/pruna/evaluation/metrics/metric_elapsed_time.py index e923fa0a..32160c5b 100644 --- a/src/pruna/evaluation/metrics/metric_elapsed_time.py +++ b/src/pruna/evaluation/metrics/metric_elapsed_time.py @@ -22,6 +22,7 @@ from torch.utils.data import DataLoader from pruna.engine.pruna_model import PrunaModel +from pruna.engine.utils import check_device_compatibility from pruna.evaluation.metrics.metric_base import BaseMetric from pruna.evaluation.metrics.registry import MetricRegistry from pruna.evaluation.metrics.result import MetricResult @@ -54,8 +55,9 @@ class InferenceTimeStats(BaseMetric): The number of batches to evaluate the model. n_warmup_iterations : int, default=10 The number of warmup batches to evaluate the model. - device : str | torch.device, default="cuda" - The device to evaluate the model on. + device : str | torch.device | None, optional + The device to be used for smashing, e.g., 'cuda' or 'cpu'. Default is None. + If None, the best available device will be used. timing_type : str, default="sync" The type of timing to use. """ @@ -64,12 +66,12 @@ def __init__( self, n_iterations: int = 100, n_warmup_iterations: int = 10, - device: str | torch.device = "cuda", + device: str | torch.device | None = None, timing_type: str = "sync", ) -> None: self.n_iterations = n_iterations self.n_warmup_iterations = n_warmup_iterations - self.device = device + self.device = check_device_compatibility(device) self.timing_type = timing_type def _measure(self, model: PrunaModel, dataloader: DataLoader, iterations: int, measure_fn) -> None: @@ -118,12 +120,12 @@ def _time_inference(self, model: PrunaModel, x: Any) -> float: endevent_time = time.time() return (endevent_time - startevent_time) * 1000 # in ms elif self.timing_type == "sync": - startevent = torch.cuda.Event(enable_timing=True) - endevent = torch.cuda.Event(enable_timing=True) + startevent = getattr(torch, self.device).Event(enable_timing=True) + endevent = getattr(torch, self.device).Event(enable_timing=True) startevent.record() _ = model(x, **model.inference_handler.model_args) endevent.record() - torch.cuda.synchronize() + getattr(torch, self.device).synchronize() return startevent.elapsed_time(endevent) # in ms else: raise ValueError(f"Timing type {self.timing_type} not supported.") @@ -183,8 +185,9 @@ class LatencyMetric(InferenceTimeStats): The number of batches to evaluate the model. n_warmup_iterations : int, default=10 The number of warmup batches to evaluate the model. - device : str | torch.device, default="cuda" - The device to evaluate the model on. + device : str | torch.device | None, optional + The device to be used for smashing, e.g., 'cuda' or 'cpu'. Default is None. + If None, the best available device will be used. timing_type : str, default="sync" The type of timing to use. """ @@ -227,8 +230,9 @@ class ThroughputMetric(InferenceTimeStats): The number of batches to evaluate the model. n_warmup_iterations : int, default=10 The number of warmup batches to evaluate the model. - device : str | torch.device, default="cuda" - The device to evaluate the model on. + device : str | torch.device | None, optional + The device to be used for smashing, e.g., 'cuda' or 'cpu'. Default is None. + If None, the best available device will be used. timing_type : str, default="sync" The type of timing to use. """ @@ -271,8 +275,9 @@ class TotalTimeMetric(InferenceTimeStats): The number of batches to evaluate the model. n_warmup_iterations : int, default=10 The number of warmup batches to evaluate the model. - device : str | torch.device, default="cuda" - The device to evaluate the model on. + device : str | torch.device | None, optional + The device to be used for smashing, e.g., 'cuda' or 'cpu'. Default is None. + If None, the best available device will be used. timing_type : str, default="sync" The type of timing to use. """ diff --git a/src/pruna/evaluation/metrics/metric_energy.py b/src/pruna/evaluation/metrics/metric_energy.py index 9a3c1494..4217f769 100644 --- a/src/pruna/evaluation/metrics/metric_energy.py +++ b/src/pruna/evaluation/metrics/metric_energy.py @@ -22,6 +22,7 @@ from torch.utils.data import DataLoader from pruna.engine.pruna_model import PrunaModel +from pruna.engine.utils import check_device_compatibility from pruna.evaluation.metrics.metric_base import BaseMetric from pruna.evaluation.metrics.registry import MetricRegistry from pruna.evaluation.metrics.result import MetricResult @@ -53,16 +54,17 @@ class EnvironmentalImpactStats(BaseMetric): are not averaged and will therefore increase with this argument. n_warmup_iterations : int, default=10 The number of warmup batches to evaluate the model. - device : str | torch.device, default="cuda" - The device to evaluate the model on. + device : str | torch.device | None, optional + The device to be used for smashing, e.g., 'cuda' or 'cpu'. Default is None. + If None, the best available device will be used. """ def __init__( - self, n_iterations: int = 100, n_warmup_iterations: int = 10, device: str | torch.device = "cuda" + self, n_iterations: int = 100, n_warmup_iterations: int = 10, device: str | torch.device | None = None ) -> None: self.n_iterations = n_iterations self.n_warmup_iterations = n_warmup_iterations - self.device = device + self.device = check_device_compatibility(device) @torch.no_grad() def compute(self, model: PrunaModel, dataloader: DataLoader) -> Dict[str, Any] | MetricResult: @@ -115,6 +117,8 @@ def compute(self, model: PrunaModel, dataloader: DataLoader) -> Dict[str, Any] | # Make sure all the operations are finished before stopping the tracker if self.device == "cuda" or str(self.device).startswith("cuda"): torch.cuda.synchronize() + elif self.device == "mps": + torch.mps.synchronize() tracker.stop() emissions_data = self._collect_emissions_data(tracker) @@ -155,8 +159,9 @@ class EnergyConsumedMetric(EnvironmentalImpactStats): are not averaged and will therefore increase with this argument. n_warmup_iterations : int, default=10 The number of warmup batches to evaluate the model. - device : str | torch.device, default="cuda" - The device to evaluate the model on. + device : str | torch.device | None, optional + The device to be used for smashing, e.g., 'cuda' or 'cpu'. Default is None. + If None, the best available device will be used. """ higher_is_better: bool = False @@ -197,8 +202,9 @@ class CO2EmissionsMetric(EnvironmentalImpactStats): are not averaged and will therefore increase with this argument. n_warmup_iterations : int, default=10 The number of warmup batches to evaluate the model. - device : str | torch.device, default="cuda" - The device to evaluate the model on. + device : str | torch.device | None, optional + The device to be used for smashing, e.g., 'cuda' or 'cpu'. Default is None. + If None, the best available device will be used. """ higher_is_better: bool = False diff --git a/src/pruna/evaluation/metrics/metric_memory.py b/src/pruna/evaluation/metrics/metric_memory.py index b3d5fa87..e0aca482 100644 --- a/src/pruna/evaluation/metrics/metric_memory.py +++ b/src/pruna/evaluation/metrics/metric_memory.py @@ -23,7 +23,7 @@ from torch.utils.data import DataLoader from pruna.engine.pruna_model import PrunaModel -from pruna.engine.utils import safe_memory_cleanup, set_to_train +from pruna.engine.utils import check_device_compatibility, safe_memory_cleanup, set_to_train from pruna.evaluation.metrics.metric_base import BaseMetric from pruna.evaluation.metrics.registry import MetricRegistry from pruna.evaluation.metrics.result import MetricResult @@ -328,7 +328,7 @@ def _load_and_prepare_model(self, model_path: str, model_cls: Type[PrunaModel]) model = model_cls.from_pretrained( model_path, ) - model.move_to_device("cuda") + model.move_to_device(check_device_compatibility(None)) if self.mode in {DISK_MEMORY, INFERENCE_MEMORY}: model.set_to_eval() elif self.mode == TRAINING_MEMORY: diff --git a/src/pruna/evaluation/metrics/metric_model_architecture.py b/src/pruna/evaluation/metrics/metric_model_architecture.py index beedabdb..5786b617 100644 --- a/src/pruna/evaluation/metrics/metric_model_architecture.py +++ b/src/pruna/evaluation/metrics/metric_model_architecture.py @@ -24,6 +24,7 @@ from pruna.engine.call_sequence_tracker import CallSequenceTracker from pruna.engine.pruna_model import PrunaModel +from pruna.engine.utils import check_device_compatibility from pruna.evaluation.metrics.metric_base import BaseMetric from pruna.evaluation.metrics.registry import MetricRegistry from pruna.evaluation.metrics.result import MetricResult @@ -50,12 +51,13 @@ class ModelArchitectureStats(BaseMetric): Parameters ---------- - device : str | torch.device, default="cuda" - The device to evaluate the model on. + device : str | torch.device | None, optional + The device to be used for smashing, e.g., 'cuda' or 'cpu'. Default is None. + If None, the best available device will be used. """ - def __init__(self, device: str | torch.device = "cuda") -> None: - self.device = device + def __init__(self, device: str | torch.device | None = None) -> None: + self.device = check_device_compatibility(device) self.module_macs: Dict[str, Any] = {} self.module_params: Dict[str, Any] = {} self.call_tracker = CallSequenceTracker() @@ -194,8 +196,9 @@ class TotalMACsMetric(ModelArchitectureStats): Parameters ---------- - device : str | torch.device, default="cuda" - The device to evaluate the model on. + device : str | torch.device | None, optional + The device to be used for smashing, e.g., 'cuda' or 'cpu'. Default is None. + If None, the best available device will be used. """ metric_name: str = TOTAL_MACS @@ -231,8 +234,9 @@ class TotalParamsMetric(ModelArchitectureStats): Parameters ---------- - device : str | torch.device, default="cuda" - The device to evaluate the model on. + device : str | torch.device | None, optional + The device to be used for smashing, e.g., 'cuda' or 'cpu'. Default is None. + If None, the best available device will be used. """ metric_name: str = TOTAL_PARAMS diff --git a/src/pruna/evaluation/metrics/metric_torch.py b/src/pruna/evaluation/metrics/metric_torch.py index 8618f3e2..464f0d94 100644 --- a/src/pruna/evaluation/metrics/metric_torch.py +++ b/src/pruna/evaluation/metrics/metric_torch.py @@ -33,6 +33,7 @@ from torchmetrics.text import Perplexity from torchvision import transforms +from pruna.engine.utils import check_device_compatibility from pruna.evaluation.metrics.metric_stateful import StatefulMetric from pruna.evaluation.metrics.registry import MetricRegistry from pruna.evaluation.metrics.result import MetricResult @@ -233,7 +234,7 @@ def __init__(self, metric_name: str, call_type: str = "", **kwargs) -> None: super().__init__() try: if metric_name == "perplexity": - device = kwargs.pop("device", "cuda") + device = check_device_compatibility(device=kwargs.get("device")) self.metric = TorchMetrics[metric_name](**kwargs).to(device) else: self.metric = TorchMetrics[metric_name](**kwargs) diff --git a/src/pruna/evaluation/task.py b/src/pruna/evaluation/task.py index bee014e9..724c8c61 100644 --- a/src/pruna/evaluation/task.py +++ b/src/pruna/evaluation/task.py @@ -20,6 +20,7 @@ import torch from pruna.data.pruna_datamodule import PrunaDataModule +from pruna.engine.utils import check_device_compatibility from pruna.evaluation.metrics.metric_base import BaseMetric from pruna.evaluation.metrics.metric_cmmd import CMMD from pruna.evaluation.metrics.metric_elapsed_time import LATENCY, THROUGHPUT, TOTAL_TIME @@ -56,20 +57,22 @@ class Task: The user request. datamodule : PrunaDataModule The dataloader to use for the evaluation. - device : str | torch.device - The device to use for the evaluation. + device : str | torch.device | None, optional + The device to be used for smashing, e.g., 'cuda' or 'cpu'. Default is None. + If None, the best available device will be used. """ def __init__( self, request: str | List[str | BaseMetric | StatefulMetric], datamodule: PrunaDataModule, - device: str | torch.device = "cuda", + device: str | torch.device | None = None, ) -> None: - self.metrics = get_metrics(request) + device = check_device_compatibility(device) + self.metrics = get_metrics(request, device) self.datamodule = datamodule self.dataloader = datamodule.test_dataloader() - self.device = device + self.device = check_device_compatibility(device) def get_single_stateful_metrics(self) -> List[StatefulMetric]: """ @@ -116,7 +119,9 @@ def is_pairwise_evaluation(self) -> bool: return any(metric.is_pairwise() for metric in self.metrics if isinstance(metric, StatefulMetric)) -def get_metrics(request: str | List[str | BaseMetric | StatefulMetric]) -> List[BaseMetric | StatefulMetric]: +def get_metrics( + request: str | List[str | BaseMetric | StatefulMetric], device: str | torch.device | None = None +) -> List[BaseMetric | StatefulMetric]: """ Convert user requests into a list of metrics. diff --git a/tests/engine/test_device.py b/tests/engine/test_device.py index 82ddce8a..406a83aa 100644 --- a/tests/engine/test_device.py +++ b/tests/engine/test_device.py @@ -30,8 +30,8 @@ def test_device_none() -> None: @pytest.mark.parametrize( "device,expected", [ - ("mps", "cpu") if torch.cuda.is_available() else ("cuda", "cpu"), - (torch.device("mps"), "cpu") if torch.cuda.is_available() else (torch.device("cuda"), "cpu"), + ("mps", "cuda") if torch.cuda.is_available() else ("cuda", "mps"), + (torch.device("mps"), "cuda") if torch.cuda.is_available() else (torch.device("cuda"), "mps"), ], ) def test_device_available(device: str | torch.device, expected: str) -> None: From e9fc2e2612f2f1844495c561e5bd15fe5fc74d33 Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Thu, 22 May 2025 21:09:51 +0200 Subject: [PATCH 02/11] refactor: replace device compatibility checks with new utility function - Updated all instances of `check_device_compatibility` to `set_to_best_available_device` across various modules, including `SmashConfig`, evaluation metrics, and task handling. --- src/pruna/config/smash_config.py | 8 ++++---- src/pruna/engine/utils.py | 8 ++++---- src/pruna/evaluation/evaluation_agent.py | 4 ++-- src/pruna/evaluation/metrics/metric_cmmd.py | 4 ++-- src/pruna/evaluation/metrics/metric_elapsed_time.py | 4 ++-- src/pruna/evaluation/metrics/metric_energy.py | 4 ++-- src/pruna/evaluation/metrics/metric_memory.py | 4 ++-- src/pruna/evaluation/metrics/metric_model_architecture.py | 4 ++-- src/pruna/evaluation/metrics/metric_torch.py | 4 ++-- src/pruna/evaluation/task.py | 6 +++--- 10 files changed, 25 insertions(+), 25 deletions(-) diff --git a/src/pruna/config/smash_config.py b/src/pruna/config/smash_config.py index 864cb7fa..c6b00f95 100644 --- a/src/pruna/config/smash_config.py +++ b/src/pruna/config/smash_config.py @@ -32,7 +32,7 @@ from pruna.config.smash_space import ALGORITHM_GROUPS, SMASH_SPACE from pruna.data.pruna_datamodule import PrunaDataModule, TokenizerMissingError -from pruna.engine.utils import check_device_compatibility +from pruna.engine.utils import set_to_best_available_device from pruna.logging.logger import pruna_logger ADDITIONAL_ARGS = [ @@ -90,7 +90,7 @@ def __init__( self.batch_size = max_batch_size else: self.batch_size = batch_size - self.device = check_device_compatibility(device) + self.device = set_to_best_available_device(device) self.cache_dir_prefix = cache_dir_prefix if not os.path.exists(cache_dir_prefix): @@ -157,7 +157,7 @@ def load_from_json(self, path: str) -> None: # check device compatibility if "device" in config_dict: - config_dict["device"] = check_device_compatibility(config_dict["device"]) + config_dict["device"] = set_to_best_available_device(config_dict["device"]) # support deprecated load_fn if "load_fn" in config_dict: @@ -265,7 +265,7 @@ def load_dict(self, config_dict: dict) -> None: """ # check device compatibility if "device" in config_dict: - config_dict["device"] = check_device_compatibility(config_dict["device"]) + config_dict["device"] = set_to_best_available_device(config_dict["device"]) # since this function is only used for loading algorithm settings, we will ignore additional arguments filtered_config_dict = {k: v for k, v in config_dict.items() if k not in ADDITIONAL_ARGS} diff --git a/src/pruna/engine/utils.py b/src/pruna/engine/utils.py index ee02b43d..b400f203 100644 --- a/src/pruna/engine/utils.py +++ b/src/pruna/engine/utils.py @@ -244,9 +244,9 @@ def determine_dtype(pipeline: Any) -> torch.dtype: return torch.float32 -def check_device_compatibility(device: str | torch.device | None) -> str: +def set_to_best_available_device(device: str | torch.device | None) -> str: """ - Validate if the specified device is available on the current system. + Set the device to the best available device. Supports 'cuda', 'mps', 'cpu' and other PyTorch devices. If device is None, the best available device will be returned. @@ -274,11 +274,11 @@ def check_device_compatibility(device: str | torch.device | None) -> str: if device == "cuda" and not torch.cuda.is_available(): pruna_logger.warning("'cuda' requested but not available.") - return check_device_compatibility(device=None) + return set_to_best_available_device(device=None) if device == "mps" and not torch.backends.mps.is_available(): pruna_logger.warning("'mps' requested but not available.") - return check_device_compatibility(device=None) + return set_to_best_available_device(device=None) return device diff --git a/src/pruna/evaluation/evaluation_agent.py b/src/pruna/evaluation/evaluation_agent.py index 35edf258..d25c991b 100644 --- a/src/pruna/evaluation/evaluation_agent.py +++ b/src/pruna/evaluation/evaluation_agent.py @@ -19,7 +19,7 @@ from pruna.config.utils import is_empty_config from pruna.engine.pruna_model import PrunaModel -from pruna.engine.utils import check_device_compatibility, safe_memory_cleanup +from pruna.engine.utils import safe_memory_cleanup, set_to_best_available_device from pruna.evaluation.metrics.metric_stateful import StatefulMetric from pruna.evaluation.metrics.result import MetricResult from pruna.evaluation.metrics.utils import group_metrics_by_inheritance @@ -41,7 +41,7 @@ def __init__(self, task: Task) -> None: self.task = task self.first_model_results: List[MetricResult] = [] self.subsequent_model_results: List[MetricResult] = [] - self.device = check_device_compatibility(self.task.device) + self.device = set_to_best_available_device(self.task.device) self.cache: List[Tensor] = [] self.evaluation_for_first_model: bool = True diff --git a/src/pruna/evaluation/metrics/metric_cmmd.py b/src/pruna/evaluation/metrics/metric_cmmd.py index 3b8e7c3e..e366ca72 100644 --- a/src/pruna/evaluation/metrics/metric_cmmd.py +++ b/src/pruna/evaluation/metrics/metric_cmmd.py @@ -21,7 +21,7 @@ from huggingface_hub.utils import EntryNotFoundError from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection -from pruna.engine.utils import check_device_compatibility +from pruna.engine.utils import set_to_best_available_device from pruna.evaluation.metrics.metric_stateful import StatefulMetric from pruna.evaluation.metrics.registry import MetricRegistry from pruna.evaluation.metrics.result import MetricResult @@ -68,7 +68,7 @@ def __init__( **kwargs, ) -> None: super().__init__(*args, **kwargs) - self.device = check_device_compatibility(device) + self.device = set_to_best_available_device(device) try: model_info(clip_model_name) except EntryNotFoundError: diff --git a/src/pruna/evaluation/metrics/metric_elapsed_time.py b/src/pruna/evaluation/metrics/metric_elapsed_time.py index 32160c5b..7a394186 100644 --- a/src/pruna/evaluation/metrics/metric_elapsed_time.py +++ b/src/pruna/evaluation/metrics/metric_elapsed_time.py @@ -22,7 +22,7 @@ from torch.utils.data import DataLoader from pruna.engine.pruna_model import PrunaModel -from pruna.engine.utils import check_device_compatibility +from pruna.engine.utils import set_to_best_available_device from pruna.evaluation.metrics.metric_base import BaseMetric from pruna.evaluation.metrics.registry import MetricRegistry from pruna.evaluation.metrics.result import MetricResult @@ -71,7 +71,7 @@ def __init__( ) -> None: self.n_iterations = n_iterations self.n_warmup_iterations = n_warmup_iterations - self.device = check_device_compatibility(device) + self.device = set_to_best_available_device(device) self.timing_type = timing_type def _measure(self, model: PrunaModel, dataloader: DataLoader, iterations: int, measure_fn) -> None: diff --git a/src/pruna/evaluation/metrics/metric_energy.py b/src/pruna/evaluation/metrics/metric_energy.py index 4217f769..780bc24a 100644 --- a/src/pruna/evaluation/metrics/metric_energy.py +++ b/src/pruna/evaluation/metrics/metric_energy.py @@ -22,7 +22,7 @@ from torch.utils.data import DataLoader from pruna.engine.pruna_model import PrunaModel -from pruna.engine.utils import check_device_compatibility +from pruna.engine.utils import set_to_best_available_device from pruna.evaluation.metrics.metric_base import BaseMetric from pruna.evaluation.metrics.registry import MetricRegistry from pruna.evaluation.metrics.result import MetricResult @@ -64,7 +64,7 @@ def __init__( ) -> None: self.n_iterations = n_iterations self.n_warmup_iterations = n_warmup_iterations - self.device = check_device_compatibility(device) + self.device = set_to_best_available_device(device) @torch.no_grad() def compute(self, model: PrunaModel, dataloader: DataLoader) -> Dict[str, Any] | MetricResult: diff --git a/src/pruna/evaluation/metrics/metric_memory.py b/src/pruna/evaluation/metrics/metric_memory.py index e0aca482..882911b1 100644 --- a/src/pruna/evaluation/metrics/metric_memory.py +++ b/src/pruna/evaluation/metrics/metric_memory.py @@ -23,7 +23,7 @@ from torch.utils.data import DataLoader from pruna.engine.pruna_model import PrunaModel -from pruna.engine.utils import check_device_compatibility, safe_memory_cleanup, set_to_train +from pruna.engine.utils import safe_memory_cleanup, set_to_best_available_device, set_to_train from pruna.evaluation.metrics.metric_base import BaseMetric from pruna.evaluation.metrics.registry import MetricRegistry from pruna.evaluation.metrics.result import MetricResult @@ -328,7 +328,7 @@ def _load_and_prepare_model(self, model_path: str, model_cls: Type[PrunaModel]) model = model_cls.from_pretrained( model_path, ) - model.move_to_device(check_device_compatibility(None)) + model.move_to_device(set_to_best_available_device(None)) if self.mode in {DISK_MEMORY, INFERENCE_MEMORY}: model.set_to_eval() elif self.mode == TRAINING_MEMORY: diff --git a/src/pruna/evaluation/metrics/metric_model_architecture.py b/src/pruna/evaluation/metrics/metric_model_architecture.py index 5786b617..eb50fe74 100644 --- a/src/pruna/evaluation/metrics/metric_model_architecture.py +++ b/src/pruna/evaluation/metrics/metric_model_architecture.py @@ -24,7 +24,7 @@ from pruna.engine.call_sequence_tracker import CallSequenceTracker from pruna.engine.pruna_model import PrunaModel -from pruna.engine.utils import check_device_compatibility +from pruna.engine.utils import set_to_best_available_device from pruna.evaluation.metrics.metric_base import BaseMetric from pruna.evaluation.metrics.registry import MetricRegistry from pruna.evaluation.metrics.result import MetricResult @@ -57,7 +57,7 @@ class ModelArchitectureStats(BaseMetric): """ def __init__(self, device: str | torch.device | None = None) -> None: - self.device = check_device_compatibility(device) + self.device = set_to_best_available_device(device) self.module_macs: Dict[str, Any] = {} self.module_params: Dict[str, Any] = {} self.call_tracker = CallSequenceTracker() diff --git a/src/pruna/evaluation/metrics/metric_torch.py b/src/pruna/evaluation/metrics/metric_torch.py index 464f0d94..83a92e7a 100644 --- a/src/pruna/evaluation/metrics/metric_torch.py +++ b/src/pruna/evaluation/metrics/metric_torch.py @@ -33,7 +33,7 @@ from torchmetrics.text import Perplexity from torchvision import transforms -from pruna.engine.utils import check_device_compatibility +from pruna.engine.utils import set_to_best_available_device from pruna.evaluation.metrics.metric_stateful import StatefulMetric from pruna.evaluation.metrics.registry import MetricRegistry from pruna.evaluation.metrics.result import MetricResult @@ -234,7 +234,7 @@ def __init__(self, metric_name: str, call_type: str = "", **kwargs) -> None: super().__init__() try: if metric_name == "perplexity": - device = check_device_compatibility(device=kwargs.get("device")) + device = set_to_best_available_device(device=kwargs.get("device")) self.metric = TorchMetrics[metric_name](**kwargs).to(device) else: self.metric = TorchMetrics[metric_name](**kwargs) diff --git a/src/pruna/evaluation/task.py b/src/pruna/evaluation/task.py index 724c8c61..45971016 100644 --- a/src/pruna/evaluation/task.py +++ b/src/pruna/evaluation/task.py @@ -20,7 +20,7 @@ import torch from pruna.data.pruna_datamodule import PrunaDataModule -from pruna.engine.utils import check_device_compatibility +from pruna.engine.utils import set_to_best_available_device from pruna.evaluation.metrics.metric_base import BaseMetric from pruna.evaluation.metrics.metric_cmmd import CMMD from pruna.evaluation.metrics.metric_elapsed_time import LATENCY, THROUGHPUT, TOTAL_TIME @@ -68,11 +68,11 @@ def __init__( datamodule: PrunaDataModule, device: str | torch.device | None = None, ) -> None: - device = check_device_compatibility(device) + device = set_to_best_available_device(device) self.metrics = get_metrics(request, device) self.datamodule = datamodule self.dataloader = datamodule.test_dataloader() - self.device = check_device_compatibility(device) + self.device = set_to_best_available_device(device) def get_single_stateful_metrics(self) -> List[StatefulMetric]: """ From 05430d525189ea5456f0c1940ef937e788b0aeaf Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Thu, 22 May 2025 21:25:55 +0200 Subject: [PATCH 03/11] refactor: update inference method to use best available device - Modified the `run_inference` method in `PrunaModel` to accept a device parameter that defaults to None, utilizing the new utility function `set_to_best_available_device` for improved device management. - Cleaned up the import statements in `pruna_model.py` to include the new device utility. - Removed unnecessary blank line in the test file `test_cmmd.py` for better code cleanliness. --- src/pruna/engine/pruna_model.py | 11 +++++++---- tests/evaluation/test_cmmd.py | 1 - 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/pruna/engine/pruna_model.py b/src/pruna/engine/pruna_model.py index 78dcaecb..4d89a154 100644 --- a/src/pruna/engine/pruna_model.py +++ b/src/pruna/engine/pruna_model.py @@ -25,7 +25,7 @@ from pruna.engine.handler.handler_utils import register_inference_handler from pruna.engine.load import load_pruna_model, load_pruna_model_from_hub from pruna.engine.save import save_pruna_model, save_pruna_model_to_hub -from pruna.engine.utils import get_nn_modules, move_to_device, set_to_eval +from pruna.engine.utils import get_nn_modules, move_to_device, set_to_best_available_device, set_to_eval from pruna.logging.filter import apply_warning_filter from pruna.telemetry import increment_counter, track_usage @@ -74,7 +74,9 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: with torch.no_grad(): return self.model.__call__(*args, **kwargs) - def run_inference(self, batch: Tuple[List[str] | torch.Tensor, ...], device: torch.device | str) -> Any: + def run_inference( + self, batch: Tuple[List[str] | torch.Tensor, ...], device: torch.device | str | None = None + ) -> Any: """ Run inference on the model. @@ -82,14 +84,15 @@ def run_inference(self, batch: Tuple[List[str] | torch.Tensor, ...], device: tor ---------- batch : Tuple[List[str] | torch.Tensor, ...] The batch to run inference on. - device : torch.device | str - The device to run inference on. + device : torch.device | str | None + The device to run inference on. If None, the best available device will be used. Returns ------- Any The processed output. """ + device = set_to_best_available_device(device) batch = self.inference_handler.move_inputs_to_device(batch, device) # type: ignore prepared_inputs = self.inference_handler.prepare_inputs(batch) if prepared_inputs is not None: diff --git a/tests/evaluation/test_cmmd.py b/tests/evaluation/test_cmmd.py index 744984c5..2c74f8b0 100644 --- a/tests/evaluation/test_cmmd.py +++ b/tests/evaluation/test_cmmd.py @@ -9,7 +9,6 @@ from pruna.evaluation.metrics.metric_cmmd import CMMD from pruna.evaluation.task import Task - @pytest.mark.parametrize( "model_fixture, device, clip_model", [ From 236204834b1dfdd78587ae6b1a8bdb6eee959c89 Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Thu, 22 May 2025 21:41:15 +0200 Subject: [PATCH 04/11] refactor: enhance device handling in evaluation metrics - Updated the `set_to_best_available_device` function to raise a ValueError for unsupported devices, improving error handling. - Modified the `Task` class to directly use the provided device parameter instead of relying on the utility function for device assignment. - Enhanced the `get_metrics` and `_process_metric_names` functions to accept and utilize the device parameter, ensuring consistent device management across metric processing. - Improved docstrings to clarify the usage of the device parameter in various functions. --- src/pruna/engine/utils.py | 5 ++++- src/pruna/evaluation/metrics/registry.py | 3 ++- src/pruna/evaluation/task.py | 19 +++++++++++-------- 3 files changed, 17 insertions(+), 10 deletions(-) diff --git a/src/pruna/engine/utils.py b/src/pruna/engine/utils.py index b400f203..95a29616 100644 --- a/src/pruna/engine/utils.py +++ b/src/pruna/engine/utils.py @@ -272,6 +272,9 @@ def set_to_best_available_device(device: str | torch.device | None) -> str: if device == "cpu": return "cpu" + if device == "accelerate" and torch.cuda.is_available(): + return "accelerate" + if device == "cuda" and not torch.cuda.is_available(): pruna_logger.warning("'cuda' requested but not available.") return set_to_best_available_device(device=None) @@ -280,7 +283,7 @@ def set_to_best_available_device(device: str | torch.device | None) -> str: pruna_logger.warning("'mps' requested but not available.") return set_to_best_available_device(device=None) - return device + raise ValueError(f"Device not supported: '{device}'") class ModelContext: diff --git a/src/pruna/evaluation/metrics/registry.py b/src/pruna/evaluation/metrics/registry.py index 5ad88f13..df48bba2 100644 --- a/src/pruna/evaluation/metrics/registry.py +++ b/src/pruna/evaluation/metrics/registry.py @@ -17,6 +17,7 @@ from functools import partial from typing import Any, Callable, Dict, Iterable, List +from pruna.engine.load import filter_load_kwargs from pruna.evaluation.metrics.metric_base import BaseMetric from pruna.evaluation.metrics.metric_stateful import StatefulMetric from pruna.logging.logger import pruna_logger @@ -123,4 +124,4 @@ def get_metrics(cls, names: List[str], **kwargs) -> List[BaseMetric | StatefulMe ------- A list of metric instances. """ - return [cls.get_metric(name, **kwargs) for name in names] + return [cls.get_metric(name, **filter_load_kwargs(cls.get_metric, **kwargs)) for name in names] diff --git a/src/pruna/evaluation/task.py b/src/pruna/evaluation/task.py index 45971016..ad40c78f 100644 --- a/src/pruna/evaluation/task.py +++ b/src/pruna/evaluation/task.py @@ -72,7 +72,7 @@ def __init__( self.metrics = get_metrics(request, device) self.datamodule = datamodule self.dataloader = datamodule.test_dataloader() - self.device = set_to_best_available_device(device) + self.device = device def get_single_stateful_metrics(self) -> List[StatefulMetric]: """ @@ -129,6 +129,9 @@ def get_metrics( ---------- request : str | List[str] The user request. Right now, it only supports image generation quality. + device : str | torch.device | None, optional + The device to be used for smashing, e.g., 'cuda' or 'cpu'. Default is None. + If None, the best available device will be used. Returns ------- @@ -144,14 +147,14 @@ def get_metrics( """ if isinstance(request, List): if all(isinstance(item, BaseMetric | StatefulMetric) for item in request): - return _process_metric_instances(cast(List[BaseMetric | StatefulMetric], request)) + return _process_metric_instances(request=cast(List[BaseMetric | StatefulMetric], request)) elif all(isinstance(item, str) for item in request): - return _process_metric_names(cast(List[str], request)) + return _process_metric_names(request=cast(List[str], request), device=device) else: pruna_logger.error("List must contain either all strings or all [BaseMetric | StatefulMetric] instances.") raise ValueError("List must contain either all strings or all [BaseMetric | StatefulMetric] instances.") else: - return _process_single_request(request) + return _process_single_request(request, device) def _process_metric_instances(request: List[BaseMetric | StatefulMetric]) -> List[BaseMetric | StatefulMetric]: @@ -168,7 +171,7 @@ def _process_metric_instances(request: List[BaseMetric | StatefulMetric]) -> Lis return new_request_metrics -def _process_metric_names(request: List[str]) -> List[BaseMetric | StatefulMetric]: +def _process_metric_names(request: List[str], device: str | torch.device | None) -> List[BaseMetric | StatefulMetric]: pruna_logger.info(f"Creating metrics from names: {request}") new_requests: List[str] = [] for metric_name in request: @@ -184,16 +187,16 @@ def _process_metric_names(request: List[str]) -> List[BaseMetric | StatefulMetri new_requests.append(cast(str, new_metric)) else: new_requests.append(cast(str, metric_name)) - return MetricRegistry.get_metrics(new_requests) + return MetricRegistry.get_metrics(names=new_requests, device=device) -def _process_single_request(request: str) -> List[BaseMetric | StatefulMetric]: +def _process_single_request(request: str, device: str | torch.device | None) -> List[BaseMetric | StatefulMetric]: if request == "image_generation_quality": pruna_logger.info("An evaluation task for image generation quality is being created.") return [ TorchMetricWrapper("clip_score"), TorchMetricWrapper("clip_score", call_type="pairwise"), - CMMD(), + CMMD(device=device), ] else: pruna_logger.error(f"Metric {request} not found. Available requests: {AVAILABLE_REQUESTS}.") From 17c84ede3c6944386bcdbd6536d949f025ed4bd8 Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Thu, 22 May 2025 21:53:40 +0200 Subject: [PATCH 05/11] refactor: streamline event handling in InferenceTimeStats - Updated the event handling in the `InferenceTimeStats` class to improve device attribute access, enhancing code clarity and maintainability. - Replaced direct calls to `getattr(torch, self.device)` with a local variable for better performance and readability. --- src/pruna/evaluation/metrics/metric_elapsed_time.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/pruna/evaluation/metrics/metric_elapsed_time.py b/src/pruna/evaluation/metrics/metric_elapsed_time.py index 7a394186..d59cf3ce 100644 --- a/src/pruna/evaluation/metrics/metric_elapsed_time.py +++ b/src/pruna/evaluation/metrics/metric_elapsed_time.py @@ -120,12 +120,13 @@ def _time_inference(self, model: PrunaModel, x: Any) -> float: endevent_time = time.time() return (endevent_time - startevent_time) * 1000 # in ms elif self.timing_type == "sync": - startevent = getattr(torch, self.device).Event(enable_timing=True) - endevent = getattr(torch, self.device).Event(enable_timing=True) + torch_device_attr = getattr(torch, self.device) + startevent = torch_device_attr.Event(enable_timing=True) + endevent = torch_device_attr.Event(enable_timing=True) startevent.record() _ = model(x, **model.inference_handler.model_args) endevent.record() - getattr(torch, self.device).synchronize() + torch_device_attr.synchronize() return startevent.elapsed_time(endevent) # in ms else: raise ValueError(f"Timing type {self.timing_type} not supported.") From 698c1a34da7772cf83b5735857548762ccee8f5c Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Fri, 23 May 2025 08:25:42 +0200 Subject: [PATCH 06/11] refactor: update device parameter documentation across metrics - Revised docstrings in multiple classes and functions to clarify the usage of the device parameter, removing redundant phrasing related to "smashing." - Ensured consistency in the description of device handling across various evaluation metrics and configurations. --- src/pruna/config/smash_config.py | 2 +- src/pruna/evaluation/metrics/metric_cmmd.py | 2 +- src/pruna/evaluation/metrics/metric_elapsed_time.py | 8 ++++---- src/pruna/evaluation/metrics/metric_energy.py | 6 +++--- src/pruna/evaluation/metrics/metric_model_architecture.py | 6 +++--- src/pruna/evaluation/task.py | 4 ++-- 6 files changed, 14 insertions(+), 14 deletions(-) diff --git a/src/pruna/config/smash_config.py b/src/pruna/config/smash_config.py index c6b00f95..0e94e029 100644 --- a/src/pruna/config/smash_config.py +++ b/src/pruna/config/smash_config.py @@ -60,7 +60,7 @@ class SmashConfig: batch_size : int, optional The number of batches to process at once. Default is 1. device : str | torch.device | None, optional - The device to be used for smashing, e.g., 'cuda' or 'cpu'. Default is None. + The device to be used, e.g., 'cuda' or 'cpu'. Default is None. If None, the best available device will be used. cache_dir_prefix : str, optional The prefix for the cache directory. If None, a default cache directory will be created. diff --git a/src/pruna/evaluation/metrics/metric_cmmd.py b/src/pruna/evaluation/metrics/metric_cmmd.py index e366ca72..0c3a7cd2 100644 --- a/src/pruna/evaluation/metrics/metric_cmmd.py +++ b/src/pruna/evaluation/metrics/metric_cmmd.py @@ -43,7 +43,7 @@ class CMMD(StatefulMetric): *args : Any Additional arguments to pass to the StatefulMetric constructor. device : str | torch.device | None, optional - The device to be used for smashing, e.g., 'cuda' or 'cpu'. Default is None. + The device to be used, e.g., 'cuda' or 'cpu'. Default is None. If None, the best available device will be used. clip_model_name : str The name of the CLIP model to use. diff --git a/src/pruna/evaluation/metrics/metric_elapsed_time.py b/src/pruna/evaluation/metrics/metric_elapsed_time.py index d59cf3ce..d4f19f99 100644 --- a/src/pruna/evaluation/metrics/metric_elapsed_time.py +++ b/src/pruna/evaluation/metrics/metric_elapsed_time.py @@ -56,7 +56,7 @@ class InferenceTimeStats(BaseMetric): n_warmup_iterations : int, default=10 The number of warmup batches to evaluate the model. device : str | torch.device | None, optional - The device to be used for smashing, e.g., 'cuda' or 'cpu'. Default is None. + The device to be used, e.g., 'cuda' or 'cpu'. Default is None. If None, the best available device will be used. timing_type : str, default="sync" The type of timing to use. @@ -187,7 +187,7 @@ class LatencyMetric(InferenceTimeStats): n_warmup_iterations : int, default=10 The number of warmup batches to evaluate the model. device : str | torch.device | None, optional - The device to be used for smashing, e.g., 'cuda' or 'cpu'. Default is None. + The device to be used, e.g., 'cuda' or 'cpu'. Default is None. If None, the best available device will be used. timing_type : str, default="sync" The type of timing to use. @@ -232,7 +232,7 @@ class ThroughputMetric(InferenceTimeStats): n_warmup_iterations : int, default=10 The number of warmup batches to evaluate the model. device : str | torch.device | None, optional - The device to be used for smashing, e.g., 'cuda' or 'cpu'. Default is None. + The device to be used, e.g., 'cuda' or 'cpu'. Default is None. If None, the best available device will be used. timing_type : str, default="sync" The type of timing to use. @@ -277,7 +277,7 @@ class TotalTimeMetric(InferenceTimeStats): n_warmup_iterations : int, default=10 The number of warmup batches to evaluate the model. device : str | torch.device | None, optional - The device to be used for smashing, e.g., 'cuda' or 'cpu'. Default is None. + The device to be used, e.g., 'cuda' or 'cpu'. Default is None. If None, the best available device will be used. timing_type : str, default="sync" The type of timing to use. diff --git a/src/pruna/evaluation/metrics/metric_energy.py b/src/pruna/evaluation/metrics/metric_energy.py index 780bc24a..b9a19107 100644 --- a/src/pruna/evaluation/metrics/metric_energy.py +++ b/src/pruna/evaluation/metrics/metric_energy.py @@ -55,7 +55,7 @@ class EnvironmentalImpactStats(BaseMetric): n_warmup_iterations : int, default=10 The number of warmup batches to evaluate the model. device : str | torch.device | None, optional - The device to be used for smashing, e.g., 'cuda' or 'cpu'. Default is None. + The device to be used, e.g., 'cuda' or 'cpu'. Default is None. If None, the best available device will be used. """ @@ -160,7 +160,7 @@ class EnergyConsumedMetric(EnvironmentalImpactStats): n_warmup_iterations : int, default=10 The number of warmup batches to evaluate the model. device : str | torch.device | None, optional - The device to be used for smashing, e.g., 'cuda' or 'cpu'. Default is None. + The device to be used, e.g., 'cuda' or 'cpu'. Default is None. If None, the best available device will be used. """ @@ -203,7 +203,7 @@ class CO2EmissionsMetric(EnvironmentalImpactStats): n_warmup_iterations : int, default=10 The number of warmup batches to evaluate the model. device : str | torch.device | None, optional - The device to be used for smashing, e.g., 'cuda' or 'cpu'. Default is None. + The device to be used, e.g., 'cuda' or 'cpu'. Default is None. If None, the best available device will be used. """ diff --git a/src/pruna/evaluation/metrics/metric_model_architecture.py b/src/pruna/evaluation/metrics/metric_model_architecture.py index eb50fe74..8c3f7cef 100644 --- a/src/pruna/evaluation/metrics/metric_model_architecture.py +++ b/src/pruna/evaluation/metrics/metric_model_architecture.py @@ -52,7 +52,7 @@ class ModelArchitectureStats(BaseMetric): Parameters ---------- device : str | torch.device | None, optional - The device to be used for smashing, e.g., 'cuda' or 'cpu'. Default is None. + The device to be used, e.g., 'cuda' or 'cpu'. Default is None. If None, the best available device will be used. """ @@ -197,7 +197,7 @@ class TotalMACsMetric(ModelArchitectureStats): Parameters ---------- device : str | torch.device | None, optional - The device to be used for smashing, e.g., 'cuda' or 'cpu'. Default is None. + The device to be used, e.g., 'cuda' or 'cpu'. Default is None. If None, the best available device will be used. """ @@ -235,7 +235,7 @@ class TotalParamsMetric(ModelArchitectureStats): Parameters ---------- device : str | torch.device | None, optional - The device to be used for smashing, e.g., 'cuda' or 'cpu'. Default is None. + The device to be used, e.g., 'cuda' or 'cpu'. Default is None. If None, the best available device will be used. """ diff --git a/src/pruna/evaluation/task.py b/src/pruna/evaluation/task.py index ad40c78f..d4e8091e 100644 --- a/src/pruna/evaluation/task.py +++ b/src/pruna/evaluation/task.py @@ -58,7 +58,7 @@ class Task: datamodule : PrunaDataModule The dataloader to use for the evaluation. device : str | torch.device | None, optional - The device to be used for smashing, e.g., 'cuda' or 'cpu'. Default is None. + The device to be used, e.g., 'cuda' or 'cpu'. Default is None. If None, the best available device will be used. """ @@ -130,7 +130,7 @@ def get_metrics( request : str | List[str] The user request. Right now, it only supports image generation quality. device : str | torch.device | None, optional - The device to be used for smashing, e.g., 'cuda' or 'cpu'. Default is None. + The device to be used, e.g., 'cuda' or 'cpu'. Default is None. If None, the best available device will be used. Returns From cd4634503a187a4f8f001debbb70f5d7ff503e01 Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Fri, 23 May 2025 08:39:36 +0200 Subject: [PATCH 07/11] fix: improve error handling for unsupported devices in InferenceTimeStats - Added a try-except block around the device attribute access to raise a ValueError when an unsupported device is specified for sync timing, ensuring clearer error reporting and fallback behavior. --- src/pruna/evaluation/metrics/metric_elapsed_time.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/pruna/evaluation/metrics/metric_elapsed_time.py b/src/pruna/evaluation/metrics/metric_elapsed_time.py index d4f19f99..e6e63370 100644 --- a/src/pruna/evaluation/metrics/metric_elapsed_time.py +++ b/src/pruna/evaluation/metrics/metric_elapsed_time.py @@ -120,7 +120,10 @@ def _time_inference(self, model: PrunaModel, x: Any) -> float: endevent_time = time.time() return (endevent_time - startevent_time) * 1000 # in ms elif self.timing_type == "sync": - torch_device_attr = getattr(torch, self.device) + try: + torch_device_attr = getattr(torch, self.device) + except AttributeError: + raise ValueError(f"Device {self.device} not supported for sync timing. Using async timing instead.") startevent = torch_device_attr.Event(enable_timing=True) endevent = torch_device_attr.Event(enable_timing=True) startevent.record() From 630917d090605347a2e08cd51054c6774e2b2905 Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Wed, 28 May 2025 08:40:37 +0200 Subject: [PATCH 08/11] refactor: enhance device selection logic in set_to_best_available_device - Improved error handling for the 'accelerate' device to raise a ValueError when neither CUDA nor MPS is available. - Streamlined checks for 'cuda' and 'mps' devices, ensuring warnings are logged when the requested device is unavailable, and fallback behavior is maintained. --- src/pruna/engine/utils.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/src/pruna/engine/utils.py b/src/pruna/engine/utils.py index 95a29616..906eea67 100644 --- a/src/pruna/engine/utils.py +++ b/src/pruna/engine/utils.py @@ -272,16 +272,22 @@ def set_to_best_available_device(device: str | torch.device | None) -> str: if device == "cpu": return "cpu" - if device == "accelerate" and torch.cuda.is_available(): + if device == "accelerate": + if not torch.cuda.is_available() and not torch.backends.mps.is_available(): + raise ValueError("'accelerate' requested but neither CUDA nor MPS is available.") return "accelerate" - if device == "cuda" and not torch.cuda.is_available(): - pruna_logger.warning("'cuda' requested but not available.") - return set_to_best_available_device(device=None) - - if device == "mps" and not torch.backends.mps.is_available(): - pruna_logger.warning("'mps' requested but not available.") - return set_to_best_available_device(device=None) + if device == "cuda": + if not torch.cuda.is_available(): + pruna_logger.warning("'cuda' requested but not available.") + return set_to_best_available_device(device=None) + return "cuda" + + if device == "mps": + if not torch.backends.mps.is_available(): + pruna_logger.warning("'mps' requested but not available.") + return set_to_best_available_device(device=None) + return "mps" raise ValueError(f"Device not supported: '{device}'") From 125c7206224fbde547130686ba9dacaa513edaa7 Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Tue, 3 Jun 2025 14:33:37 +0200 Subject: [PATCH 09/11] refactor: enhance GPU memory statistics handling for MPS devices - Updated the GPUMemoryStats class to include support for MPS devices, using -1 as a placeholder. - Improved device index retrieval logic for both CUDA and MPS, ensuring consistent handling of device types. - Streamlined inference method calls to utilize the best available device, enhancing overall device management. --- src/pruna/evaluation/metrics/metric_memory.py | 32 ++++++++++++------- src/pruna/evaluation/metrics/registry.py | 3 +- 2 files changed, 22 insertions(+), 13 deletions(-) diff --git a/src/pruna/evaluation/metrics/metric_memory.py b/src/pruna/evaluation/metrics/metric_memory.py index 882911b1..cbefdbad 100644 --- a/src/pruna/evaluation/metrics/metric_memory.py +++ b/src/pruna/evaluation/metrics/metric_memory.py @@ -248,12 +248,15 @@ def _check_device_map(self, model: PrunaModel, attr_name: str) -> Optional[List[ indices = set() for device in device_map.values(): - if isinstance(device, str) and "cuda" in device: - try: - idx = int(device.split(":")[1]) - indices.add(idx) - except (IndexError, ValueError): - indices.add(0) + if isinstance(device, str): + if "cuda" in device: + try: + idx = int(device.split(":")[1]) + indices.add(idx) + except (IndexError, ValueError): + indices.add(0) + elif "mps" in device: + indices.add(-1) # Use -1 as a placeholder for MPS return sorted(indices) if indices else None @@ -302,10 +305,17 @@ def _check_model_device(self, model: PrunaModel) -> Optional[List[int]]: return None device = model.device - if hasattr(device, "index") and device.type == "cuda": - return [device.index] - elif isinstance(device, str) and "cuda" in device: - return [0] + if hasattr(device, "type"): + if device.type == "cuda": + return [device.index if device.index is not None else 0] + elif device.type == "mps": + return [-1] # Placeholder for MPS + + elif isinstance(device, str): + if "cuda" in device: + return [0] + elif "mps" in device: + return [-1] return None @@ -348,7 +358,7 @@ def _perform_forward_pass(self, model: PrunaModel, dataloader: DataLoader) -> No """ with torch.no_grad() if self.mode == INFERENCE_MEMORY else torch.enable_grad(): batch = next(iter(dataloader)) - model.run_inference(batch, "cuda") + model.run_inference(batch=batch, device=set_to_best_available_device(None)) @MetricRegistry.register(DISK_MEMORY) diff --git a/src/pruna/evaluation/metrics/registry.py b/src/pruna/evaluation/metrics/registry.py index df48bba2..5ad88f13 100644 --- a/src/pruna/evaluation/metrics/registry.py +++ b/src/pruna/evaluation/metrics/registry.py @@ -17,7 +17,6 @@ from functools import partial from typing import Any, Callable, Dict, Iterable, List -from pruna.engine.load import filter_load_kwargs from pruna.evaluation.metrics.metric_base import BaseMetric from pruna.evaluation.metrics.metric_stateful import StatefulMetric from pruna.logging.logger import pruna_logger @@ -124,4 +123,4 @@ def get_metrics(cls, names: List[str], **kwargs) -> List[BaseMetric | StatefulMe ------- A list of metric instances. """ - return [cls.get_metric(name, **filter_load_kwargs(cls.get_metric, **kwargs)) for name in names] + return [cls.get_metric(name, **kwargs) for name in names] From 59ec8f6a6602bce236f3eac0c071254c2e264e62 Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Tue, 3 Jun 2025 14:39:27 +0200 Subject: [PATCH 10/11] docs: update GPUMemoryStats parameter documentation for clarity - Revised the documentation for the `mode` parameter in the `GPUMemoryStats` class to specify the correct options as 'disk_memory', 'inference_memory', and 'training_memory', enhancing clarity for users. --- src/pruna/evaluation/metrics/metric_memory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pruna/evaluation/metrics/metric_memory.py b/src/pruna/evaluation/metrics/metric_memory.py index cbefdbad..eca82d34 100644 --- a/src/pruna/evaluation/metrics/metric_memory.py +++ b/src/pruna/evaluation/metrics/metric_memory.py @@ -102,7 +102,7 @@ class GPUMemoryStats(BaseMetric): Parameters ---------- mode : str - The mode for memory evaluation. Must be one of 'disk', 'inference', or 'training'. + The mode for memory evaluation. Must be one of 'disk_memory', 'inference_memory', or 'training_memory'. gpu_indices : Optional[List[int]] List of GPU indices to monitor. If None, all GPUs are assumed. """ From f6a92b3e950a0af438c622b25f5bde0161225e0e Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Tue, 3 Jun 2025 15:10:30 +0200 Subject: [PATCH 11/11] refactor: update input preparation method signatures for consistency - Modified the `prepare_inputs` method signatures across multiple handler classes to accept a more flexible input type, allowing for `List[str]`, `torch.Tensor`, or a tuple of these types. - Removed redundant type annotations to enhance clarity and maintainability in the codebase. - Adjusted the `CTranslateCompiler` class to remove an unnecessary line related to model configuration. --- src/pruna/algorithms/compilation/c_translate.py | 1 - src/pruna/engine/handler/handler_diffuser.py | 6 +++--- src/pruna/engine/handler/handler_inference.py | 2 +- src/pruna/engine/handler/handler_standard.py | 4 ++-- src/pruna/engine/handler/handler_transformer.py | 4 ++-- tests/algorithms/test_combinations.py | 1 + 6 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/pruna/algorithms/compilation/c_translate.py b/src/pruna/algorithms/compilation/c_translate.py index 506c8e94..5d527362 100644 --- a/src/pruna/algorithms/compilation/c_translate.py +++ b/src/pruna/algorithms/compilation/c_translate.py @@ -199,7 +199,6 @@ def load_model(self: Any, model_class: Any, model_name_or_path: str, **kwargs: A elif self.task_name == "whisper": optimized_model = imported_modules["Whisper"](temp_dir, device=smash_config["device"]) optimized_model = WhisperWrapper(optimized_model, temp_dir, smash_config.processor) - optimized_model.config = model.config else: raise ValueError("Task not supported") diff --git a/src/pruna/engine/handler/handler_diffuser.py b/src/pruna/engine/handler/handler_diffuser.py index 2a10aace..c1c15359 100644 --- a/src/pruna/engine/handler/handler_diffuser.py +++ b/src/pruna/engine/handler/handler_diffuser.py @@ -15,7 +15,7 @@ from __future__ import annotations import inspect -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple import torch from torchvision import transforms @@ -47,13 +47,13 @@ def __init__(self, call_signature: inspect.Signature, model_args: Optional[Dict[ default_args.update(model_args) self.model_args = default_args - def prepare_inputs(self, batch: Tuple[Any, ...]) -> Any: + def prepare_inputs(self, batch: List[str] | torch.Tensor | Tuple[List[str] | torch.Tensor, ...]) -> Any: """ Prepare the inputs for the model. Parameters ---------- - batch : Tuple[Any, ...] + batch : List[str] | torch.Tensor | Tuple[List[str] | torch.Tensor, ...] The batch to prepare the inputs for. Returns diff --git a/src/pruna/engine/handler/handler_inference.py b/src/pruna/engine/handler/handler_inference.py index ee0ab576..442f7636 100644 --- a/src/pruna/engine/handler/handler_inference.py +++ b/src/pruna/engine/handler/handler_inference.py @@ -33,7 +33,7 @@ def __init__(self) -> None: self.model_args: Dict[str, Any] = {} @abstractmethod - def prepare_inputs(self, batch: Tuple[Any, ...]) -> Any: + def prepare_inputs(self, batch: List[str] | torch.Tensor | Tuple[List[str] | torch.Tensor, ...]) -> Any: """ Prepare the inputs for the model. diff --git a/src/pruna/engine/handler/handler_standard.py b/src/pruna/engine/handler/handler_standard.py index d10da2ae..4f2b42c7 100644 --- a/src/pruna/engine/handler/handler_standard.py +++ b/src/pruna/engine/handler/handler_standard.py @@ -40,13 +40,13 @@ class StandardHandler(InferenceHandler): def __init__(self, model_args: Optional[Dict[str, Any]] = None) -> None: self.model_args = model_args if model_args else {} - def prepare_inputs(self, batch: Tuple[List[str] | torch.Tensor, ...]) -> Any: + def prepare_inputs(self, batch: List[str] | torch.Tensor | Tuple[List[str] | torch.Tensor, ...]) -> Any: """ Prepare the inputs for the model. Parameters ---------- - batch : Tuple[Any, ...] + batch : List[str] | torch.Tensor | Tuple[List[str] | torch.Tensor, ...] The batch to prepare the inputs for. Returns diff --git a/src/pruna/engine/handler/handler_transformer.py b/src/pruna/engine/handler/handler_transformer.py index a818c183..06bbd93b 100644 --- a/src/pruna/engine/handler/handler_transformer.py +++ b/src/pruna/engine/handler/handler_transformer.py @@ -38,13 +38,13 @@ class TransformerHandler(InferenceHandler): def __init__(self, model_args: Optional[Dict[str, Any]] = None) -> None: self.model_args = model_args if model_args else {} - def prepare_inputs(self, batch: Tuple[List[str] | torch.Tensor, ...]) -> Any: + def prepare_inputs(self, batch: List[str] | torch.Tensor | Tuple[List[str] | torch.Tensor, ...]) -> Any: """ Prepare the inputs for the model. Parameters ---------- - batch : Tuple[List[str] | torch.Tensor, ...] + batch : List[str] | torch.Tensor | Tuple[List[str] | torch.Tensor, ...] The batch to prepare the inputs for. Returns diff --git a/tests/algorithms/test_combinations.py b/tests/algorithms/test_combinations.py index 23b7c885..afb0895b 100644 --- a/tests/algorithms/test_combinations.py +++ b/tests/algorithms/test_combinations.py @@ -21,6 +21,7 @@ def allow_pickle_files(self) -> bool: """Allow pickle files.""" return self._allow_pickle_files + @property def compatible_devices(self) -> list[str]: """Return the compatible devices for the test.""" return ["cuda"]