diff --git a/src/pruna/evaluation/metrics/metric_torch.py b/src/pruna/evaluation/metrics/metric_torch.py index 8cd8753e..3d844d30 100644 --- a/src/pruna/evaluation/metrics/metric_torch.py +++ b/src/pruna/evaluation/metrics/metric_torch.py @@ -26,6 +26,7 @@ from torchmetrics.image import ( FrechetInceptionDistance, LearnedPerceptualImagePatchSimilarity, + MultiScaleStructuralSimilarityIndexMeasure, PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure, ) @@ -122,14 +123,18 @@ def arniqa_update(metric: ARNIQA, preds: Any) -> None: metric.update(preds) -def ssim_update(metric: StructuralSimilarityIndexMeasure, preds: Any, target: Any) -> None: +def ssim_update( + metric: StructuralSimilarityIndexMeasure | MultiScaleStructuralSimilarityIndexMeasure, + preds: Any, + target: Any +) -> None: """ - Update handler for SSIM metric. + Update handler for SSIM or MS-SSIM metric. Parameters ---------- - metric : StructuralSimilarityIndexMeasure instance - The SSIM metric instance. + metric : StructuralSimilarityIndexMeasure | MultiScaleStructuralSimilarityIndexMeasure instance + The SSIM or MS-SSIM metric instance. preds : Any The generated images tensor. target : Any @@ -173,6 +178,7 @@ class TorchMetrics(Enum): recall = (partial(Recall), None, "y_gt") psnr = (partial(PeakSignalNoiseRatio), None, "pairwise_y_gt") ssim = (partial(StructuralSimilarityIndexMeasure), ssim_update, "pairwise_y_gt") + msssim = (partial(MultiScaleStructuralSimilarityIndexMeasure), ssim_update, "pairwise_y_gt") lpips = (partial(LearnedPerceptualImagePatchSimilarity), lpips_update, "pairwise_y_gt") arniqa = (partial(ARNIQA), arniqa_update, "y") clipiqa = (partial(CLIPImageQualityAssessment), None, "y") diff --git a/tests/evaluation/test_torch_metrics.py b/tests/evaluation/test_torch_metrics.py index 16fd4674..70fa4eaf 100644 --- a/tests/evaluation/test_torch_metrics.py +++ b/tests/evaluation/test_torch_metrics.py @@ -2,6 +2,10 @@ import pytest import torch +from torchmetrics.image import ( + StructuralSimilarityIndexMeasure, + MultiScaleStructuralSimilarityIndexMeasure +) from pruna.evaluation.metrics.metric_torch import TorchMetricWrapper, TorchMetrics @@ -100,3 +104,29 @@ def test_check_call_type(metric: str, call_type: str): assert metric.call_type == "y" else: assert not metric.call_type.startswith("pairwise") + +@pytest.mark.cpu +@pytest.mark.parametrize( + 'metric_name,metric_type', + ( + ('ssim', StructuralSimilarityIndexMeasure), + ('msssim', MultiScaleStructuralSimilarityIndexMeasure) + ) +) +def test_ssim_generalization_metric_type(metric_name, metric_type): + wrapper = TorchMetricWrapper(metric_name=metric_name) + assert isinstance(wrapper.metric, metric_type) + +@pytest.mark.cpu +@pytest.mark.parametrize( + 'metric_name,invalid_param_args', + ( + pytest.param('ssim', {'betas': [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]}), + pytest.param('ssim', {'normalize': 'relu'}), + pytest.param('msssim', {'return_full_image': True}), + pytest.param('msssim', {'return_contrast_sensitivity': True}), + ) +) +def test_ssim_generalization_invalid_param_type(metric_name, invalid_param_args): + with pytest.raises(ValueError): + TorchMetricWrapper(metric_name, **invalid_param_args)