Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/pruna/algorithms/compilation/c_translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,6 @@ def load_model(self: Any, model_class: Any, model_name_or_path: str, **kwargs: A
elif self.task_name == "whisper":
optimized_model = imported_modules["Whisper"](temp_dir, device=smash_config["device"])
optimized_model = WhisperWrapper(optimized_model, temp_dir, smash_config.processor)
optimized_model.config = model.config
else:
raise ValueError("Task not supported")

Expand Down
10 changes: 5 additions & 5 deletions src/pruna/config/smash_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@

from pruna.config.smash_space import ALGORITHM_GROUPS, SMASH_SPACE
from pruna.data.pruna_datamodule import PrunaDataModule, TokenizerMissingError
from pruna.engine.utils import check_device_compatibility
from pruna.engine.utils import set_to_best_available_device
from pruna.logging.logger import pruna_logger

ADDITIONAL_ARGS = [
Expand Down Expand Up @@ -60,7 +60,7 @@ class SmashConfig:
batch_size : int, optional
The number of batches to process at once. Default is 1.
device : str | torch.device | None, optional
The device to be used for smashing, e.g., 'cuda' or 'cpu'. Default is None.
The device to be used, e.g., 'cuda' or 'cpu'. Default is None.
If None, the best available device will be used.
cache_dir_prefix : str, optional
The prefix for the cache directory. If None, a default cache directory will be created.
Expand Down Expand Up @@ -90,7 +90,7 @@ def __init__(
self.batch_size = max_batch_size
else:
self.batch_size = batch_size
self.device = check_device_compatibility(device)
self.device = set_to_best_available_device(device)

self.cache_dir_prefix = cache_dir_prefix
if not os.path.exists(cache_dir_prefix):
Expand Down Expand Up @@ -157,7 +157,7 @@ def load_from_json(self, path: str) -> None:

# resolve the device to the best available option
if "device" in config_dict:
config_dict["device"] = check_device_compatibility(config_dict["device"])
config_dict["device"] = set_to_best_available_device(config_dict["device"])

# support deprecated load_fn
if "load_fn" in config_dict:
Expand Down Expand Up @@ -265,7 +265,7 @@ def load_dict(self, config_dict: dict) -> None:
"""
# resolve the device to the best available option
if "device" in config_dict:
config_dict["device"] = check_device_compatibility(config_dict["device"])
config_dict["device"] = set_to_best_available_device(config_dict["device"])

# since this function is only used for loading algorithm settings, we will ignore additional arguments
filtered_config_dict = {k: v for k, v in config_dict.items() if k not in ADDITIONAL_ARGS}
Expand Down
6 changes: 3 additions & 3 deletions src/pruna/engine/handler/handler_diffuser.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from __future__ import annotations

import inspect
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple

import torch
from torchvision import transforms
Expand Down Expand Up @@ -47,13 +47,13 @@ def __init__(self, call_signature: inspect.Signature, model_args: Optional[Dict[
default_args.update(model_args)
self.model_args = default_args

def prepare_inputs(self, batch: Tuple[Any, ...]) -> Any:
def prepare_inputs(self, batch: List[str] | torch.Tensor | Tuple[List[str] | torch.Tensor, ...]) -> Any:
"""
Prepare the inputs for the model.

Parameters
----------
batch : Tuple[Any, ...]
batch : List[str] | torch.Tensor | Tuple[List[str] | torch.Tensor, ...]
The batch to prepare the inputs for.

Returns
Expand Down
2 changes: 1 addition & 1 deletion src/pruna/engine/handler/handler_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(self) -> None:
self.model_args: Dict[str, Any] = {}

@abstractmethod
def prepare_inputs(self, batch: Tuple[Any, ...]) -> Any:
def prepare_inputs(self, batch: List[str] | torch.Tensor | Tuple[List[str] | torch.Tensor, ...]) -> Any:
"""
Prepare the inputs for the model.

Expand Down
4 changes: 2 additions & 2 deletions src/pruna/engine/handler/handler_standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,13 @@ class StandardHandler(InferenceHandler):
def __init__(self, model_args: Optional[Dict[str, Any]] = None) -> None:
self.model_args = model_args if model_args else {}

def prepare_inputs(self, batch: Tuple[List[str] | torch.Tensor, ...]) -> Any:
def prepare_inputs(self, batch: List[str] | torch.Tensor | Tuple[List[str] | torch.Tensor, ...]) -> Any:
"""
Prepare the inputs for the model.

Parameters
----------
batch : Tuple[Any, ...]
batch : List[str] | torch.Tensor | Tuple[List[str] | torch.Tensor, ...]
The batch to prepare the inputs for.

Returns
Expand Down
4 changes: 2 additions & 2 deletions src/pruna/engine/handler/handler_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,13 @@ class TransformerHandler(InferenceHandler):
def __init__(self, model_args: Optional[Dict[str, Any]] = None) -> None:
self.model_args = model_args if model_args else {}

def prepare_inputs(self, batch: Tuple[List[str] | torch.Tensor, ...]) -> Any:
def prepare_inputs(self, batch: List[str] | torch.Tensor | Tuple[List[str] | torch.Tensor, ...]) -> Any:
"""
Prepare the inputs for the model.

Parameters
----------
batch : Tuple[List[str] | torch.Tensor, ...]
batch : List[str] | torch.Tensor | Tuple[List[str] | torch.Tensor, ...]
The batch to prepare the inputs for.

Returns
Expand Down
11 changes: 7 additions & 4 deletions src/pruna/engine/pruna_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from pruna.engine.handler.handler_utils import register_inference_handler
from pruna.engine.load import load_pruna_model, load_pruna_model_from_hub
from pruna.engine.save import save_pruna_model, save_pruna_model_to_hub
from pruna.engine.utils import get_nn_modules, move_to_device, set_to_eval
from pruna.engine.utils import get_nn_modules, move_to_device, set_to_best_available_device, set_to_eval
from pruna.logging.filter import apply_warning_filter
from pruna.telemetry import increment_counter, track_usage

Expand Down Expand Up @@ -74,22 +74,25 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:
with torch.no_grad():
return self.model.__call__(*args, **kwargs)

def run_inference(self, batch: Tuple[List[str] | torch.Tensor, ...], device: torch.device | str) -> Any:
def run_inference(
self, batch: Tuple[List[str] | torch.Tensor, ...], device: torch.device | str | None = None
) -> Any:
"""
Run inference on the model.

Parameters
----------
batch : Tuple[List[str] | torch.Tensor, ...]
The batch to run inference on.
device : torch.device | str
The device to run inference on.
device : torch.device | str | None
The device to run inference on. If None, the best available device will be used.

Returns
-------
Any
The processed output.
"""
device = set_to_best_available_device(device)
batch = self.inference_handler.move_inputs_to_device(batch, device) # type: ignore
prepared_inputs = self.inference_handler.prepare_inputs(batch)
if prepared_inputs is not None:
Expand Down
29 changes: 19 additions & 10 deletions src/pruna/engine/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,9 +244,9 @@ def determine_dtype(pipeline: Any) -> torch.dtype:
return torch.float32


def check_device_compatibility(device: str | torch.device | None) -> str:
def set_to_best_available_device(device: str | torch.device | None) -> str:
"""
Validate if the specified device is available on the current system.
Resolve the requested device to the best available device on the current system.

Supports 'cuda', 'mps', 'cpu' and other PyTorch devices.
If device is None, the best available device will be returned.
Expand All @@ -266,21 +266,30 @@ def check_device_compatibility(device: str | torch.device | None) -> str:

if device is None:
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
pruna_logger.info(f"No device specified. Using best available device: '{device}'")
pruna_logger.info(f"Using best available device: '{device}'")
return device

if device == "cpu":
return "cpu"

if device == "cuda" and not torch.cuda.is_available():
pruna_logger.warning("'cuda' requested but not available. Falling back to 'cpu'")
return "cpu"
if device == "accelerate":
Comment thread
davidberenstein1957 marked this conversation as resolved.
Outdated
if not torch.cuda.is_available() and not torch.backends.mps.is_available():
raise ValueError("'accelerate' requested but neither CUDA nor MPS is available.")
return "accelerate"

if device == "mps" and not torch.backends.mps.is_available():
pruna_logger.warning("'mps' requested but not available. Falling back to 'cpu'")
return "cpu"
if device == "cuda":
if not torch.cuda.is_available():
pruna_logger.warning("'cuda' requested but not available.")
return set_to_best_available_device(device=None)
return "cuda"

if device == "mps":
if not torch.backends.mps.is_available():
pruna_logger.warning("'mps' requested but not available.")
return set_to_best_available_device(device=None)
return "mps"

return device
raise ValueError(f"Device not supported: '{device}'")


class ModelContext:
Expand Down
4 changes: 2 additions & 2 deletions src/pruna/evaluation/evaluation_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from pruna.config.utils import is_empty_config
from pruna.engine.pruna_model import PrunaModel
from pruna.engine.utils import safe_memory_cleanup
from pruna.engine.utils import safe_memory_cleanup, set_to_best_available_device
from pruna.evaluation.metrics.metric_stateful import StatefulMetric
from pruna.evaluation.metrics.result import MetricResult
from pruna.evaluation.metrics.utils import group_metrics_by_inheritance
Expand All @@ -41,7 +41,7 @@ def __init__(self, task: Task) -> None:
self.task = task
self.first_model_results: List[MetricResult] = []
self.subsequent_model_results: List[MetricResult] = []
self.device = self.task.device
self.device = set_to_best_available_device(self.task.device)
Comment thread
davidberenstein1957 marked this conversation as resolved.
Outdated
self.cache: List[Tensor] = []
self.evaluation_for_first_model: bool = True

Expand Down
10 changes: 6 additions & 4 deletions src/pruna/evaluation/metrics/metric_cmmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from huggingface_hub.utils import EntryNotFoundError
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

from pruna.engine.utils import set_to_best_available_device
from pruna.evaluation.metrics.metric_stateful import StatefulMetric
from pruna.evaluation.metrics.registry import MetricRegistry
from pruna.evaluation.metrics.result import MetricResult
Expand All @@ -41,8 +42,9 @@ class CMMD(StatefulMetric):
----------
*args : Any
Additional arguments to pass to the StatefulMetric constructor.
device : str | torch.device
The device to run the CLIP model on to calculate the embeddings for the metric.
device : str | torch.device | None, optional
The device to be used, e.g., 'cuda' or 'cpu'. Default is None.
If None, the best available device will be used.
clip_model_name : str
The name of the CLIP model to use.
call_type : str
Expand All @@ -60,13 +62,13 @@ class CMMD(StatefulMetric):
def __init__(
self,
*args,
device: str | torch.device = "cuda",
device: str | torch.device | None = None,
clip_model_name: str = "openai/clip-vit-large-patch14-336",
call_type: str = SINGLE,
**kwargs,
) -> None:
super().__init__(*args, **kwargs)
self.device = device
self.device = set_to_best_available_device(device)
try:
model_info(clip_model_name)
except EntryNotFoundError:
Expand Down
35 changes: 22 additions & 13 deletions src/pruna/evaluation/metrics/metric_elapsed_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from torch.utils.data import DataLoader

from pruna.engine.pruna_model import PrunaModel
from pruna.engine.utils import set_to_best_available_device
from pruna.evaluation.metrics.metric_base import BaseMetric
from pruna.evaluation.metrics.registry import MetricRegistry
from pruna.evaluation.metrics.result import MetricResult
Expand Down Expand Up @@ -54,8 +55,9 @@ class InferenceTimeStats(BaseMetric):
The number of batches to evaluate the model.
n_warmup_iterations : int, default=10
The number of warmup batches to evaluate the model.
device : str | torch.device, default="cuda"
The device to evaluate the model on.
device : str | torch.device | None, optional
The device to be used, e.g., 'cuda' or 'cpu'. Default is None.
If None, the best available device will be used.
timing_type : str, default="sync"
The type of timing to use.
"""
Expand All @@ -64,12 +66,12 @@ def __init__(
self,
n_iterations: int = 100,
n_warmup_iterations: int = 10,
device: str | torch.device = "cuda",
device: str | torch.device | None = None,
timing_type: str = "sync",
) -> None:
self.n_iterations = n_iterations
self.n_warmup_iterations = n_warmup_iterations
self.device = device
self.device = set_to_best_available_device(device)
self.timing_type = timing_type

def _measure(self, model: PrunaModel, dataloader: DataLoader, iterations: int, measure_fn) -> None:
Expand Down Expand Up @@ -118,12 +120,16 @@ def _time_inference(self, model: PrunaModel, x: Any) -> float:
endevent_time = time.time()
return (endevent_time - startevent_time) * 1000 # in ms
elif self.timing_type == "sync":
startevent = torch.cuda.Event(enable_timing=True)
endevent = torch.cuda.Event(enable_timing=True)
try:
torch_device_attr = getattr(torch, self.device)
except AttributeError:
raise ValueError(f"Device {self.device} not supported for sync timing. Use timing_type='async' instead.")
startevent = torch_device_attr.Event(enable_timing=True)
endevent = torch_device_attr.Event(enable_timing=True)
startevent.record()
_ = model(x, **model.inference_handler.model_args)
endevent.record()
torch.cuda.synchronize()
torch_device_attr.synchronize()
return startevent.elapsed_time(endevent) # in ms
else:
raise ValueError(f"Timing type {self.timing_type} not supported.")
Expand Down Expand Up @@ -183,8 +189,9 @@ class LatencyMetric(InferenceTimeStats):
The number of batches to evaluate the model.
n_warmup_iterations : int, default=10
The number of warmup batches to evaluate the model.
device : str | torch.device, default="cuda"
The device to evaluate the model on.
device : str | torch.device | None, optional
The device to be used, e.g., 'cuda' or 'cpu'. Default is None.
If None, the best available device will be used.
timing_type : str, default="sync"
The type of timing to use.
"""
Expand Down Expand Up @@ -227,8 +234,9 @@ class ThroughputMetric(InferenceTimeStats):
The number of batches to evaluate the model.
n_warmup_iterations : int, default=10
The number of warmup batches to evaluate the model.
device : str | torch.device, default="cuda"
The device to evaluate the model on.
device : str | torch.device | None, optional
The device to be used, e.g., 'cuda' or 'cpu'. Default is None.
If None, the best available device will be used.
timing_type : str, default="sync"
The type of timing to use.
"""
Expand Down Expand Up @@ -271,8 +279,9 @@ class TotalTimeMetric(InferenceTimeStats):
The number of batches to evaluate the model.
n_warmup_iterations : int, default=10
The number of warmup batches to evaluate the model.
device : str | torch.device, default="cuda"
The device to evaluate the model on.
device : str | torch.device | None, optional
The device to be used, e.g., 'cuda' or 'cpu'. Default is None.
If None, the best available device will be used.
timing_type : str, default="sync"
The type of timing to use.
"""
Expand Down
Loading