"""Structural Similarity Index Measure (SSIM) metrics.
SSIM is a perceptual metric that quantifies image quality degradation caused by
processing such as data compression or by losses in data transmission :cite:`wang2004image`.
MS-SSIM extends this concept to multiple scales :cite:`wang2003multiscale`.
"""
# inspect is used to filter constructor kwargs against the torchmetrics signatures
import inspect
from typing import Any, Optional, Tuple
import torch
import torchmetrics
from torch import Tensor
from ..base import BaseMetric
from ..registry import MetricRegistry
[docs]
@MetricRegistry.register_metric("ssim")
class StructuralSimilarityIndexMeasure(BaseMetric):
    """Structural Similarity Index Measure (SSIM) Module.

    SSIM measures the perceptual difference between two similar images, where a
    value of 1 means perfect similarity. The metric considers luminance, contrast,
    and structure to better match human visual perception
    :cite:`wang2004image` :cite:`brunet2011mathematical`.

    The computation is delegated to
    ``torchmetrics.image.StructuralSimilarityIndexMeasure``, which is always
    constructed with ``reduction=None``; the ``reduction`` given to this class is
    applied afterwards in :meth:`forward`.
    """

    def __init__(
        self,
        data_range: float = 1.0,
        kernel_size: int = 11,
        sigma: float = 1.5,
        reduction: Optional[str] = None,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Initialize the SSIM module.

        Args:
            data_range (float): Range of the input data (typically 1.0 or 255).
            kernel_size (int): Size of the Gaussian kernel.
            sigma (float): Standard deviation of the Gaussian kernel.
            reduction (Optional[str]): Post-processing reduction applied in
                :meth:`forward`: ``"mean"``, ``"sum"``, or anything else for no
                reduction. The underlying torchmetrics module always runs with
                ``reduction=None``.
            *args: Variable length argument list passed to the base class.
            **kwargs: Arbitrary keyword arguments passed to the base class; those
                whose names also appear in the torchmetrics constructor signature
                are additionally forwarded to torchmetrics.
        """
        super().__init__(*args, **kwargs)
        self.reduction = reduction

        # Look the signature up once instead of once per keyword argument.
        accepted = inspect.signature(
            torchmetrics.image.StructuralSimilarityIndexMeasure.__init__
        ).parameters
        torchmetrics_kwargs = {k: v for k, v in kwargs.items() if k in accepted}

        self.ssim = torchmetrics.image.StructuralSimilarityIndexMeasure(
            data_range=data_range,
            kernel_size=kernel_size,
            sigma=sigma,
            reduction=None,
            **torchmetrics_kwargs,
        )

    def forward(self, x: Tensor, y: Tensor, *args: Any, **kwargs: Any) -> Tensor:
        """Calculate SSIM between predicted and target images.

        Args:
            x (Tensor): Predicted images.
            y (Tensor): Target images.
            *args: Variable length argument list (unused; kept for interface consistency).
            **kwargs: Arbitrary keyword arguments (unused; kept for interface consistency).

        Returns:
            Tensor: SSIM values with at least one dimension when no reduction is
            configured, a scalar when ``reduction`` is ``"mean"`` or ``"sum"``, or
            an empty 1-D tensor when either input has no elements.
        """
        # A zero-element input cannot be scored. Return an empty 1-D tensor on the
        # input's device/dtype instead of reshaping to x.shape[0], which fails
        # whenever the batch dimension is non-zero but another dimension is empty.
        if x.numel() == 0 or y.numel() == 0:
            return torch.empty(0, device=x.device, dtype=x.dtype)

        # NOTE(review): calling the torchmetrics module may also accumulate state
        # inside it across calls — confirm whether a periodic reset is needed.
        values = self.ssim(x, y)

        if self.reduction == "mean":
            return values.mean()
        if self.reduction == "sum":
            return values.sum()

        # Unreduced output always carries at least one dimension.
        if values.dim() == 0:
            return values.unsqueeze(0)
        return values

    def compute_with_stats(self, x: Tensor, y: Tensor, *args: Any, **kwargs: Any) -> Tuple[Tensor, Tensor]:
        """Compute SSIM with mean and standard deviation.

        Args:
            x (Tensor): Predicted images.
            y (Tensor): Target images.
            *args: Variable length argument list (unused; kept for interface consistency).
            **kwargs: Arbitrary keyword arguments (unused; kept for interface consistency).

        Returns:
            Tuple[Tensor, Tensor]: Mean and standard deviation of the values
            returned by :meth:`forward` (so a configured reduction is applied
            first). The standard deviation is zero when there is at most one value.
        """
        values = self.forward(x, y)

        # std() of fewer than two values is NaN; report zero spread instead,
        # created on the same device/dtype as the values.
        if values.numel() <= 1:
            return values.mean(), values.new_zeros(())
        return values.mean(), values.std()
[docs]
@MetricRegistry.register_metric("ms_ssim")
class MultiScaleSSIM(BaseMetric):
    """Multi-Scale Structural Similarity Index Measure (MS-SSIM) Module.

    This module calculates the MS-SSIM between two images. MS-SSIM is an extension
    of the SSIM metric that considers multiple scales to better capture perceptual
    similarity :cite:`wang2003multiscale`. It has been shown to correlate better
    with human perception than single-scale methods :cite:`wang2004image`.

    The computation is delegated to
    ``torchmetrics.image.MultiScaleStructuralSimilarityIndexMeasure`` (always built
    with ``reduction=None``). Running sums are additionally kept in the
    ``sum_values``, ``sum_sq`` and ``count`` buffers so that :meth:`update` /
    :meth:`compute` can report statistics across batches.
    """

    def __init__(
        self,
        kernel_size: int = 11,
        data_range: float = 1.0,
        reduction: Optional[str] = None,
        weights: Optional[torch.Tensor] = None,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Initialize the MultiScaleSSIM module.

        Args:
            kernel_size (int): The size of the Gaussian kernel.
            data_range (float): The range of the input data (typically 1.0 or 255).
            reduction (Optional[str]): Post-processing reduction applied in
                :meth:`forward`: ``"mean"``, ``"sum"``, or anything else for no
                reduction.
            weights (Optional[torch.Tensor]): Per-scale weights, converted to a
                tuple and passed to torchmetrics as ``betas``. When omitted, the
                five-scale tuple ``(0.0448, 0.2856, 0.3001, 0.2363, 0.1333)`` is used.
            *args: Variable length argument list passed to the base class.
            **kwargs: Arbitrary keyword arguments passed to the base class; those
                whose names also appear in the torchmetrics constructor signature
                are additionally forwarded to torchmetrics.
        """
        super().__init__(*args, **kwargs)
        self.reduction = reduction

        if weights is not None:
            betas = tuple(weights.tolist())
        else:
            betas = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333)

        # Look the signature up once instead of once per keyword argument.
        accepted = inspect.signature(
            torchmetrics.image.MultiScaleStructuralSimilarityIndexMeasure.__init__
        ).parameters
        torchmetrics_kwargs = {k: v for k, v in kwargs.items() if k in accepted}

        # The underlying metric never reduces; reduction is applied in forward().
        self.ms_ssim = torchmetrics.image.MultiScaleStructuralSimilarityIndexMeasure(
            data_range=data_range,
            kernel_size=kernel_size,
            reduction=None,
            betas=betas,
            **torchmetrics_kwargs,
        )

        # Running statistics, kept as buffers for backwards compatibility with
        # existing tests.
        self.register_buffer("sum_values", torch.tensor(0.0))
        self.register_buffer("sum_sq", torch.tensor(0.0))
        self.register_buffer("count", torch.tensor(0))

    def forward(self, x: torch.Tensor, y: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Calculate MS-SSIM between predicted and target images.

        Args:
            x (torch.Tensor): Predicted images.
            y (torch.Tensor): Target images.
            *args: Variable length argument list (unused; kept for interface consistency).
            **kwargs: Arbitrary keyword arguments (unused; kept for interface consistency).

        Returns:
            torch.Tensor: MS-SSIM values with at least one dimension when no
            reduction is configured, a scalar when ``reduction`` is ``"mean"`` or
            ``"sum"``, or an empty 1-D tensor when either input has no elements.
        """
        # A zero-element input cannot be scored. Return an empty 1-D tensor on the
        # input's device/dtype instead of reshaping to x.shape[0], which fails
        # whenever the batch dimension is non-zero but another dimension is empty.
        if x.numel() == 0 or y.numel() == 0:
            return torch.empty(0, device=x.device, dtype=x.dtype)

        values = self.ms_ssim(x, y)

        if self.reduction == "mean":
            return values.mean()
        if self.reduction == "sum":
            return values.sum()

        # Unreduced output always carries at least one dimension.
        if values.dim() == 0:
            return values.unsqueeze(0)
        return values

    def update(self, preds: torch.Tensor, targets: torch.Tensor, *args: Any, **kwargs: Any) -> None:
        """Update internal state with batch of samples.

        Args:
            preds (torch.Tensor): Predicted images.
            targets (torch.Tensor): Target images.
            *args: Variable length argument list passed to forward.
            **kwargs: Arbitrary keyword arguments passed to forward.
        """
        # Empty batches contribute nothing to the running statistics.
        if preds.numel() == 0 or targets.numel() == 0:
            return

        # Detach so the accumulators never hold on to an autograd graph.
        values = self.forward(preds, targets, *args, **kwargs).detach()
        if values.numel() == 0:
            return

        self.sum_values += values.sum()
        self.sum_sq += (values**2).sum()
        self.count += values.numel()

    def compute(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute accumulated MS-SSIM statistics.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Mean and (population) standard
            deviation of every value seen by :meth:`update` since the last
            :meth:`reset`; both are zero when nothing has been accumulated.
        """
        if self.count == 0:
            # Zeros on the same device/dtype as the accumulators.
            return self.sum_values.new_zeros(()), self.sum_values.new_zeros(())

        mean = self.sum_values / self.count
        # E[x^2] - E[x]^2 can dip slightly below zero through rounding, which
        # would make sqrt() return NaN; clamp before taking the root.
        variance = (self.sum_sq / self.count - mean**2).clamp(min=0)
        return mean, torch.sqrt(variance)

    def reset(self) -> None:
        """Reset accumulated statistics and the underlying torchmetrics module."""
        self.ms_ssim.reset()
        self.sum_values.zero_()
        self.sum_sq.zero_()
        self.count.zero_()

    def compute_with_stats(self, x: torch.Tensor, y: torch.Tensor, *args: Any, **kwargs: Any) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute MS-SSIM with mean and standard deviation.

        Args:
            x (torch.Tensor): Predicted images.
            y (torch.Tensor): Target images.
            *args: Variable length argument list (unused; kept for interface consistency).
            **kwargs: Arbitrary keyword arguments (unused; kept for interface consistency).

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Mean and standard deviation of the
            values returned by :meth:`forward` (so a configured reduction is
            applied first). The standard deviation is zero when there is at most
            one value.
        """
        values = self.forward(x, y)

        # std() of fewer than two values is NaN; report zero spread instead,
        # created on the same device/dtype as the values.
        if values.numel() <= 1:
            return values.mean(), values.new_zeros(())
        return values.mean(), values.std()

    @property
    def data_range(self) -> float:
        """Get the data range used by the underlying torchmetrics implementation."""
        return self.ms_ssim.data_range
# Backward-compatible short alias: ``SSIM`` is the same class object as
# ``StructuralSimilarityIndexMeasure``.
SSIM = StructuralSimilarityIndexMeasure