kaira.metrics.image.MultiScaleSSIM

- class kaira.metrics.image.MultiScaleSSIM(kernel_size: int = 11, data_range: float = 1.0, reduction: str | None = None, weights: Tensor | None = None, *args: Any, **kwargs: Any)
Bases: BaseMetric
Multi-Scale Structural Similarity Index Measure (MS-SSIM) module.
This module calculates the MS-SSIM between two images. MS-SSIM is an extension of the SSIM metric that considers multiple scales to better capture perceptual similarity [Wang et al., 2003]. It has been shown to correlate better with human perception than single-scale methods [Wang et al., 2004].
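The multi-scale combination can be sketched in plain Python using the product form from [Wang et al., 2003]. Contrast-structure terms from the finer scales and the full SSIM at the coarsest scale are each raised to a per-scale weight, and the results are multiplied together. `combine_scales` is a hypothetical helper that takes precomputed per-scale floats rather than tensors. Its equal-weight default follows the `weights` description on this page and may not match the module's internals.

```python
import math
from typing import Optional, Sequence


def combine_scales(
    cs_values: Sequence[float],
    ssim_last: float,
    weights: Optional[Sequence[float]] = None,
) -> float:
    """Weighted geometric combination of per-scale similarity terms (sketch).

    cs_values: contrast-structure terms for scales 1..M-1 (hypothetical inputs)
    ssim_last: full SSIM (luminance * contrast * structure) at the coarsest scale M
    weights:   one exponent per scale; equal weighting if None (assumption)
    """
    terms = list(cs_values) + [ssim_last]
    if weights is None:
        weights = [1.0 / len(terms)] * len(terms)
    if len(weights) != len(terms):
        raise ValueError("need exactly one weight per scale")
    # Fractional powers of non-positive numbers are not real, so reject them here.
    if any(t <= 0 for t in terms):
        raise ValueError("per-scale terms must be positive")
    result = 1.0
    for term, w in zip(terms, weights):
        result *= term ** w  # each scale contributes term^weight
    return result
```

With all per-scale terms equal to 1 (identical images), the combined score is 1. Lower values at any scale pull the product down in proportion to that scale's weight.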
Methods
__init__ – Initialize the MultiScaleSSIM module.
compute – Compute accumulated MS-SSIM statistics.
compute_with_stats – Compute MS-SSIM with mean and standard deviation.
forward – Calculate MS-SSIM between predicted and target images.
reset – Reset accumulated statistics.
update – Update internal state with a batch of samples.
Attributes
Get the data range used by the underlying torchmetrics implementation.
- __init__(kernel_size: int = 11, data_range: float = 1.0, reduction: str | None = None, weights: Tensor | None = None, *args: Any, **kwargs: Any) → None
Initialize the MultiScaleSSIM module.
- Parameters:
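kernel_size (int) – The size of the Gaussian kernel.
data_range (float) – The dynamic range of the input data, i.e. the difference between the maximum and minimum possible pixel values (typically 1.0 for normalized images or 255 for 8-bit images).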
reduction (Optional[str]) – Reduction method (‘mean’, ‘sum’, or None)
weights (Optional[torch.Tensor]) – Weights for different scales. Default is equal weighting.
*args – Variable length argument list passed to the base class.
**kwargs – Arbitrary keyword arguments passed to the base class.
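The roles of `kernel_size` and `data_range` can be sketched in plain Python. `kernel_size` sets the length of the Gaussian window used for local statistics. `data_range` (written L in the SSIM literature) scales the stabilizing constants C1 = (K1·L)² and C2 = (K2·L)². The values sigma = 1.5, K1 = 0.01 and K2 = 0.03 are the conventional choices from the SSIM literature and are assumptions here, not settings confirmed for this module. Both helper names are hypothetical.

```python
import math
from typing import List, Tuple


def gaussian_window(kernel_size: int = 11, sigma: float = 1.5) -> List[float]:
    """1-D Gaussian window of length kernel_size, normalized to sum to 1.

    sigma=1.5 is an assumed conventional SSIM value, not read from kaira.
    """
    center = (kernel_size - 1) / 2
    raw = [math.exp(-((i - center) ** 2) / (2 * sigma ** 2)) for i in range(kernel_size)]
    total = sum(raw)
    return [v / total for v in raw]


def stability_constants(
    data_range: float = 1.0, k1: float = 0.01, k2: float = 0.03
) -> Tuple[float, float]:
    """Return (C1, C2) = ((k1*L)^2, (k2*L)^2), with L = data_range (assumed constants)."""
    return (k1 * data_range) ** 2, (k2 * data_range) ** 2
```

A mismatched `data_range` (for example 1.0 for 8-bit images in [0, 255]) shrinks C1 and C2 far below their intended size relative to the pixel values, so the constants no longer stabilize the ratios as designed.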
- forward(x: Tensor, y: Tensor, *args: Any, **kwargs: Any) → Tensor
Calculate MS-SSIM between predicted and target images.
- Parameters:
x (torch.Tensor) – Predicted images
y (torch.Tensor) – Target images
*args – Variable length argument list (currently unused).
**kwargs – Arbitrary keyword arguments (currently unused).
- Returns:
MS-SSIM values for each sample, or a single value reduced according to the reduction parameter.
- Return type:
torch.Tensor
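The effect of `reduction` on the output of `forward` can be sketched with plain Python lists standing in for tensors. `reduce_scores` is a hypothetical helper. Raising an error for unknown reduction names is an assumption rather than documented behavior.

```python
from typing import List, Optional, Union


def reduce_scores(
    scores: List[float], reduction: Optional[str] = None
) -> Union[float, List[float]]:
    """Apply a reduction to per-sample MS-SSIM scores (sketch)."""
    if reduction is None:
        return list(scores)  # keep one value per sample
    if reduction == "mean":
        return sum(scores) / len(scores)
    if reduction == "sum":
        return sum(scores)
    raise ValueError(f"unsupported reduction: {reduction!r}")
```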
- update(preds: Tensor, targets: Tensor, *args: Any, **kwargs: Any) → None
Update internal state with batch of samples.
- Parameters:
preds (torch.Tensor) – Predicted images
targets (torch.Tensor) – Target images
*args – Variable length argument list passed to forward.
**kwargs – Arbitrary keyword arguments passed to forward.
- compute() → Tuple[Tensor, Tensor]
Compute accumulated MS-SSIM statistics.
- Returns:
Mean and standard deviation of the accumulated MS-SSIM values
- Return type:
Tuple[torch.Tensor, torch.Tensor]
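The update / compute / reset cycle can be sketched as a streaming accumulator over per-sample scores. `RunningStats` is a hypothetical class, not kaira's implementation. It keeps a count, a running sum and a running sum of squares. It reports the sample (unbiased) standard deviation, which is an assumption chosen to match the default of `torch.std`. Returning NaN for a single sample follows the same assumption.

```python
import math
from typing import Iterable, Tuple


class RunningStats:
    """Streaming mean/std of per-sample scores (hypothetical sketch)."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        # Clear all accumulated state, as reset() does for the metric.
        self.count = 0
        self.total = 0.0
        self.total_sq = 0.0

    def update(self, scores: Iterable[float]) -> None:
        # Fold one batch of per-sample scores into the running sums.
        for s in scores:
            self.count += 1
            self.total += s
            self.total_sq += s * s

    def compute(self) -> Tuple[float, float]:
        if self.count == 0:
            raise ValueError("no samples accumulated")
        mean = self.total / self.count
        if self.count < 2:
            return mean, float("nan")  # sample std is undefined for one sample
        var = (self.total_sq - self.count * mean * mean) / (self.count - 1)
        return mean, math.sqrt(max(var, 0.0))  # clamp tiny negative rounding error
```

Usage: call `update` once per batch, call `compute` at the end of an evaluation pass, then `reset` before the next pass. The sum-of-squares form is simple but can lose precision over very long runs. Welford's algorithm is the usual alternative when that matters.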
- compute_with_stats(x: Tensor, y: Tensor, *args: Any, **kwargs: Any) → Tuple[Tensor, Tensor]
Compute MS-SSIM with mean and standard deviation.
- Parameters:
x (torch.Tensor) – Predicted images
y (torch.Tensor) – Target images
*args – Variable length argument list (currently unused).
**kwargs – Arbitrary keyword arguments (currently unused).
- Returns:
Mean and standard deviation of MS-SSIM values
- Return type:
Tuple[torch.Tensor, torch.Tensor]
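Unlike `update` and `compute`, `compute_with_stats` summarizes a single call without touching accumulated state. The per-batch summary can be sketched with the standard library. `batch_mean_std` is a hypothetical helper. Using the sample standard deviation (`statistics.stdev`) is an assumption, chosen to mirror the `torch.std` default.

```python
import statistics
from typing import Sequence, Tuple


def batch_mean_std(per_sample_scores: Sequence[float]) -> Tuple[float, float]:
    """Mean and sample standard deviation of one batch's MS-SSIM values (sketch)."""
    if len(per_sample_scores) < 2:
        raise ValueError("need at least two samples for a sample standard deviation")
    return statistics.mean(per_sample_scores), statistics.stdev(per_sample_scores)
```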