diff --git a/tests/evaluation/test_memory_metrics.py b/tests/evaluation/test_memory_metrics.py index 7e561cc1..7383b463 100644 --- a/tests/evaluation/test_memory_metrics.py +++ b/tests/evaluation/test_memory_metrics.py @@ -18,6 +18,7 @@ def test_disk_memory_metric(model_fixture: tuple[Any, SmashConfig], device: str) model, smash_config = model_fixture disk_memory_metric = DiskMemoryMetric() pruna_model = PrunaModel(model, smash_config=smash_config) + pruna_model.move_to_device("cuda") disk_memory_results = disk_memory_metric.compute(pruna_model, smash_config.test_dataloader()) assert disk_memory_results.result > 0 @@ -34,6 +35,7 @@ def test_inference_memory_metric(model_fixture: tuple[Any, SmashConfig], device: model, smash_config = model_fixture inference_memory_metric = InferenceMemoryMetric() pruna_model = PrunaModel(model, smash_config=smash_config) + pruna_model.move_to_device("cuda") inference_memory_results = inference_memory_metric.compute(pruna_model, smash_config.test_dataloader()) assert inference_memory_results.result > 0 @@ -50,5 +52,6 @@ def test_training_memory_metric(model_fixture: tuple[Any, SmashConfig], device: model, smash_config = model_fixture training_memory_metric = TrainingMemoryMetric() pruna_model = PrunaModel(model, smash_config=smash_config) + pruna_model.move_to_device("cuda") training_memory_results = training_memory_metric.compute(pruna_model, smash_config.test_dataloader()) assert training_memory_results.result > 0 \ No newline at end of file