diff --git a/src/pruna/engine/load.py b/src/pruna/engine/load.py index 8b33b548..04f3c071 100644 --- a/src/pruna/engine/load.py +++ b/src/pruna/engine/load.py @@ -329,10 +329,10 @@ def load_pickled(path: str | Path, smash_config: SmashConfig, **kwargs) -> Any: model = torch.load( Path(path) / PICKLED_FILE_NAME, weights_only=False, - map_location=target_device, + map_location="cpu", **filter_load_kwargs(torch.load, kwargs), ) - # for cpu/cuda this will just continue as its on the correct device, for accelerate it will now distribute / cast + # move to target device, for accelerate it will now distribute / cast move_to_device(model, target_device, device_map=smash_config.device_map) return model