kaira.metrics.image.MultiScaleSSIM

Inheritance diagram for MultiScaleSSIM

class kaira.metrics.image.MultiScaleSSIM(kernel_size: int = 11, data_range: float = 1.0, reduction: str | None = None, weights: Tensor | None = None, *args: Any, **kwargs: Any)

Bases: BaseMetric

Multi-Scale Structural Similarity Index Measure (MS-SSIM) Module.

This module computes the MS-SSIM between two images. MS-SSIM extends SSIM by evaluating similarity at several scales, obtained by repeatedly low-pass filtering and downsampling the images, which better captures perceptual similarity [Wang et al., 2003]. It has been shown to correlate better with human perception than single-scale methods [Wang et al., 2004].
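For reference, the multi-scale formulation in Wang et al. (2003) is shown below. Here l, c and s are the luminance, contrast and structure comparison terms, scale j = 1 is the original resolution, each further scale is low-pass filtered and downsampled by a factor of 2, and M is the number of scales. The luminance term enters only at the coarsest scale M. How the weights argument of this class maps onto the exponents (for example, alpha_j = beta_j = gamma_j = w_j) is an assumption, not something this page states.

```latex
\mathrm{MS\text{-}SSIM}(x, y) = \left[ l_M(x, y) \right]^{\alpha_M}
  \prod_{j=1}^{M} \left[ c_j(x, y) \right]^{\beta_j} \left[ s_j(x, y) \right]^{\gamma_j}
```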

Methods

__init__

Initialize the MultiScaleSSIM module.

compute

Compute accumulated MS-SSIM statistics.

compute_with_stats

Compute MS-SSIM with mean and standard deviation.

forward

Calculate MS-SSIM between predicted and target images.

reset

Reset accumulated statistics.

update

Update internal state with batch of samples.

Attributes

data_range

Get the data range used by the underlying torchmetrics implementation.

__init__(kernel_size: int = 11, data_range: float = 1.0, reduction: str | None = None, weights: Tensor | None = None, *args: Any, **kwargs: Any) → None

Initialize the MultiScaleSSIM module.

Parameters:
  • kernel_size (int) – Size of the Gaussian window used to compute local image statistics.

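  • data_range (float) – Dynamic range of the input data, i.e. the difference between the maximum and minimum possible pixel values (typically 1.0 for images scaled to [0, 1] or 255 for 8-bit images).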

  • reduction (Optional[str]) – Reduction applied over the batch: ‘mean’, ‘sum’, or None to return one value per sample.

  • weights (Optional[torch.Tensor]) – Weights for different scales. Default is equal weighting.

  • *args – Variable length argument list passed to the base class.

  • **kwargs – Arbitrary keyword arguments passed to the base class.
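The weights act as exponents in a product over scales, not as coefficients in a sum. The plain-Python sketch below illustrates that combination step under stated assumptions: `combine_scales`, `EQUAL_WEIGHTS` and `WANG_2003_WEIGHTS` are hypothetical names, the per-scale terms are assumed to be already computed and positive, and this is not kaira's implementation. For comparison, it lists the five-scale weights reported by Wang et al. (2003) next to the equal weighting documented above.

```python
from typing import Sequence


def combine_scales(
    cs_per_scale: Sequence[float],
    luminance_coarsest: float,
    weights: Sequence[float],
) -> float:
    """Hypothetical helper: weighted product of per-scale similarity terms.

    cs_per_scale[j] is the contrast-structure term c_j * s_j at scale j
    (index 0 = full resolution). The luminance term enters only at the
    coarsest scale, as in Wang et al. (2003). Inputs are assumed positive,
    since fractional powers of negative floats are not real-valued.
    """
    if len(cs_per_scale) != len(weights):
        raise ValueError("expected one weight per scale")
    if luminance_coarsest <= 0 or any(v <= 0 for v in cs_per_scale):
        raise ValueError("this sketch only handles positive similarity terms")
    result = 1.0
    # Finer scales: contrast-structure term only, raised to its weight.
    for cs, w in zip(cs_per_scale[:-1], weights[:-1]):
        result *= cs ** w
    # Coarsest scale: full SSIM (luminance * contrast-structure) to its weight.
    result *= (luminance_coarsest * cs_per_scale[-1]) ** weights[-1]
    return result


# Five-scale weights reported by Wang et al. (2003), for comparison with
# the equal weighting this page documents as the default.
WANG_2003_WEIGHTS = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333)
EQUAL_WEIGHTS = (0.2,) * 5

print(combine_scales([0.90, 0.92, 0.95, 0.97, 0.99], 0.99, EQUAL_WEIGHTS))
print(combine_scales([0.90, 0.92, 0.95, 0.97, 0.99], 0.99, WANG_2003_WEIGHTS))
```

Because each weight is an exponent, a scale with a weight close to zero contributes a factor close to 1 and barely affects the result. With the Wang et al. weights, the finest scale therefore matters much less than the middle scales.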

forward(x: Tensor, y: Tensor, *args: Any, **kwargs: Any) → Tensor

Calculate MS-SSIM between predicted and target images.

Parameters:
  • x (torch.Tensor) – Predicted images

  • y (torch.Tensor) – Target images

  • *args – Variable length argument list (currently unused).

  • **kwargs – Arbitrary keyword arguments (currently unused).

Returns:

MS-SSIM value for each sample, or a single value if reduction is ‘mean’ or ‘sum’.

Return type:

torch.Tensor
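The effect of the reduction argument on the output of forward can be sketched in plain Python as follows. `apply_reduction` is a hypothetical helper that mirrors only the documented options ('mean', 'sum', None). It operates on a list of per-sample scores instead of a tensor and is not kaira's code.

```python
from typing import List, Optional, Union


def apply_reduction(values: List[float], reduction: Optional[str]) -> Union[float, List[float]]:
    """Hypothetical helper mirroring the documented reduction options."""
    if reduction is None:
        return list(values)  # one MS-SSIM value per sample
    if reduction == "mean":
        return sum(values) / len(values)
    if reduction == "sum":
        return sum(values)
    raise ValueError(f"unsupported reduction: {reduction!r}")


per_sample = [0.9, 0.7]
print(apply_reduction(per_sample, None))
print(apply_reduction(per_sample, "mean"))
print(apply_reduction(per_sample, "sum"))
```

Keeping reduction=None is useful when per-image scores are needed later, for example to inspect the spread of quality across a batch.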

update(preds: Tensor, targets: Tensor, *args: Any, **kwargs: Any) → None

Update internal state with batch of samples.

Parameters:
  • preds (torch.Tensor) – Predicted images

  • targets (torch.Tensor) – Target images

  • *args – Variable length argument list passed to forward.

  • **kwargs – Arbitrary keyword arguments passed to forward.

compute() → Tuple[Tensor, Tensor]

Compute accumulated MS-SSIM statistics.

Returns:

Mean and standard deviation of all MS-SSIM values accumulated since the last reset.

Return type:

Tuple[torch.Tensor, torch.Tensor]

reset() → None

Reset accumulated statistics.
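The update / compute / reset cycle can be pictured as running sums over per-sample scores. The sketch below is a plain-Python illustration under stated assumptions. `RunningScoreStats` is a hypothetical class, it takes already-computed scores instead of images, and it uses the population standard deviation. Whether kaira uses the population or the Bessel-corrected sample form is not stated on this page.

```python
import math
from typing import Iterable, Tuple


class RunningScoreStats:
    """Hypothetical stand-in for accumulating per-sample MS-SSIM scores."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        # Clear all accumulated state, as reset() does for the metric.
        self.count = 0
        self.total = 0.0
        self.total_sq = 0.0

    def update(self, batch_scores: Iterable[float]) -> None:
        # Fold one batch of per-sample scores into the running sums.
        for s in batch_scores:
            self.count += 1
            self.total += s
            self.total_sq += s * s

    def compute(self) -> Tuple[float, float]:
        if self.count == 0:
            raise RuntimeError("no samples accumulated; call update() first")
        mean = self.total / self.count
        # Population variance (assumption); clamp tiny negatives from rounding.
        var = max(self.total_sq / self.count - mean * mean, 0.0)
        return mean, math.sqrt(var)


stats = RunningScoreStats()
stats.update([0.8, 0.9])  # first batch
stats.update([1.0])       # second batch
print(stats.compute())
stats.reset()             # start a new evaluation run
```

Storing only the count, sum and sum of squares keeps memory constant regardless of how many batches are accumulated.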

compute_with_stats(x: Tensor, y: Tensor, *args: Any, **kwargs: Any) → Tuple[Tensor, Tensor]

Compute MS-SSIM with mean and standard deviation.

Parameters:
  • x (torch.Tensor) – Predicted images

  • y (torch.Tensor) – Target images

  • *args – Variable length argument list (currently unused).

  • **kwargs – Arbitrary keyword arguments (currently unused).

Returns:

Mean and standard deviation of the MS-SSIM values computed for the given images.

Return type:

Tuple[torch.Tensor, torch.Tensor]

property data_range: float

Get the data range used by the underlying torchmetrics implementation.
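The data range matters because SSIM-style metrics derive their stabilizing constants from it: C1 = (K1 * L)^2 and C2 = (K2 * L)^2, where L is the data range. K1 = 0.01 and K2 = 0.03 are the values used in Wang et al. (2004). The helper name `ssim_constants` below is hypothetical, and whether kaira exposes K1 and K2 is not stated on this page. The sketch shows why data_range must match the scale of the inputs.

```python
from typing import Tuple


def ssim_constants(data_range: float, k1: float = 0.01, k2: float = 0.03) -> Tuple[float, float]:
    """Hypothetical helper: SSIM stabilizing constants C1 and C2 for a data range L."""
    c1 = (k1 * data_range) ** 2
    c2 = (k2 * data_range) ** 2
    return c1, c2


print(ssim_constants(1.0))    # images scaled to [0, 1]
print(ssim_constants(255.0))  # 8-bit images
```

If 8-bit images are passed with data_range=1.0, both constants are 255^2 = 65025 times smaller than intended relative to the pixel values, which changes the metric's behaviour in flat, low-variance regions.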