kaira.metrics.image.LearnedPerceptualImagePatchSimilarity

Inheritance diagram for LearnedPerceptualImagePatchSimilarity

class kaira.metrics.image.LearnedPerceptualImagePatchSimilarity(net_type: Literal['vgg', 'alex', 'squeeze'] = 'alex', normalize: bool = False, *args: Any, **kwargs: Any)

Bases: BaseMetric

Learned Perceptual Image Patch Similarity (LPIPS) Module.

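LPIPS measures the perceptual similarity between two images as a distance between deep features extracted by a pretrained backbone network. Lower values indicate greater perceptual similarity. Unlike traditional metrics such as PSNR and SSIM, which compare pixel values or local image statistics directly, LPIPS weights the feature differences with parameters calibrated against human perceptual judgments [Zhang et al., 2018].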
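The distance described by Zhang et al. (2018) can be sketched in plain Python: the features from each backbone layer are unit-normalized along the channel dimension at every spatial position, the channel-wise differences are scaled by per-layer weights and squared, the result is averaged over spatial positions, and the per-layer terms are summed. This is an illustrative sketch only. The toy feature maps and weights stand in for real backbone activations and learned calibration weights, and the helpers `unit_normalize` and `lpips_distance` are hypothetical names, not part of Kaira.

```python
import math
from typing import List, Sequence

# A feature map is stored as nested lists indexed [channel][row][column].
FeatureMap = List[List[List[float]]]


def unit_normalize(fmap: FeatureMap, eps: float = 1e-10) -> FeatureMap:
    """Scale the channel vector at every spatial position to unit L2 norm."""
    channels, height, width = len(fmap), len(fmap[0]), len(fmap[0][0])
    out = [[[0.0] * width for _ in range(height)] for _ in range(channels)]
    for i in range(height):
        for j in range(width):
            norm = math.sqrt(sum(fmap[k][i][j] ** 2 for k in range(channels)))
            for k in range(channels):
                out[k][i][j] = fmap[k][i][j] / (norm + eps)
    return out


def lpips_distance(
    feats_x: Sequence[FeatureMap],
    feats_y: Sequence[FeatureMap],
    layer_weights: Sequence[Sequence[float]],
) -> float:
    """Hypothetical LPIPS-style distance over per-layer feature maps."""
    total = 0.0
    for fx, fy, w in zip(feats_x, feats_y, layer_weights):
        nx, ny = unit_normalize(fx), unit_normalize(fy)
        channels, height, width = len(nx), len(nx[0]), len(nx[0][0])
        layer_sum = 0.0
        for i in range(height):
            for j in range(width):
                # Squared L2 norm of the weighted channel-wise difference.
                layer_sum += sum(
                    (w[k] * (nx[k][i][j] - ny[k][i][j])) ** 2 for k in range(channels)
                )
        total += layer_sum / (height * width)  # spatial average for this layer
    return total


# Toy single-layer example: two 2-channel, 1x1 feature maps.
fx = [[[1.0]], [[0.0]]]
fy = [[[0.0]], [[1.0]]]
print(lpips_distance([fx], [fx], [[1.0, 1.0]]))  # identical features -> 0.0
print(lpips_distance([fx], [fy], [[1.0, 1.0]]))  # orthogonal features -> about 2.0
```

Because each feature vector is unit-normalized before comparison, the sketch is sensitive to the direction of the activations rather than their magnitude.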

Methods

__init__

Initialize the LPIPS module.

compute

Compute the accumulated LPIPS statistics.

compute_with_stats

Compute metric with mean and standard deviation.

forward

Calculate LPIPS between two images.

reset

Reset accumulated statistics.

update

Update the internal state with a batch of samples.

__init__(net_type: Literal['vgg', 'alex', 'squeeze'] = 'alex', normalize: bool = False, *args: Any, **kwargs: Any) → None

Initialize the LPIPS module.

Parameters:
  • net_type (str) – The backbone network used to extract features: ‘vgg’, ‘alex’, or ‘squeeze’. Defaults to ‘alex’.

  • normalize (bool) – Whether to rescale the input images from [0, 1] to [-1, 1] before computing LPIPS. If True, the input images must be in the range [0, 1]; if False (the default), they must already be in the range [-1, 1].

  • *args – Variable length argument list passed to the base class and torchmetrics.

  • **kwargs – Arbitrary keyword arguments passed to the base class and torchmetrics.
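The input-range contract of `normalize` can be illustrated with a small pure-Python helper. `prepare_inputs` is a hypothetical function written for this example, not part of Kaira. It assumes that `normalize=True` corresponds to the linear map x ↦ 2x − 1 from [0, 1] to [-1, 1], and it also checks `net_type` against the three backbone names listed in the signature.

```python
from typing import List, Sequence

VALID_NET_TYPES = ("vgg", "alex", "squeeze")  # backbones listed in the signature


def prepare_inputs(
    pixels: Sequence[float], normalize: bool = False, net_type: str = "alex"
) -> List[float]:
    """Hypothetical helper mirroring the documented input ranges."""
    if net_type not in VALID_NET_TYPES:
        raise ValueError(f"net_type must be one of {VALID_NET_TYPES}, got {net_type!r}")
    low = 0.0 if normalize else -1.0
    if any(p < low or p > 1.0 for p in pixels):
        raise ValueError(f"inputs must lie in [{low}, 1.0] when normalize={normalize}")
    if normalize:
        # Assumed linear rescaling from [0, 1] to [-1, 1].
        return [2.0 * p - 1.0 for p in pixels]
    return list(pixels)


print(prepare_inputs([0.0, 0.5, 1.0], normalize=True))   # [-1.0, 0.0, 1.0]
print(prepare_inputs([-1.0, 0.0, 1.0], normalize=False))  # [-1.0, 0.0, 1.0]
```

Passing [0, 1] images with `normalize=False` may not raise an error, since [0, 1] lies inside [-1, 1], but the images then occupy only the upper half of the expected range and the resulting scores will not reflect the intended comparison.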

forward(x: Tensor, y: Tensor, *args: Any, **kwargs: Any) → Tensor

Calculate LPIPS between two images.

Parameters:
  • x (Tensor) – First batch of images

  • y (Tensor) – Second batch of images

  • *args – Variable length argument list (currently unused).

  • **kwargs – Arbitrary keyword arguments (currently unused).

Returns:

LPIPS values for each sample

Return type:

Tensor

update(x: Tensor, y: Tensor, *args: Any, **kwargs: Any) → None

Update the internal state with a batch of samples.

Parameters:
  • x (Tensor) – First batch of images

  • y (Tensor) – Second batch of images

  • *args – Variable length argument list (currently unused).

  • **kwargs – Arbitrary keyword arguments (currently unused).

compute() → Tuple[Tensor, Tensor]

Compute the accumulated LPIPS statistics.

Returns:

Mean and standard deviation of the LPIPS values accumulated by update() since the last reset()

Return type:

Tuple[Tensor, Tensor]

reset() → None

Reset accumulated statistics.
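The update()/compute()/reset() cycle above follows the usual streaming-metric pattern. The sketch below is a hypothetical stand-in, not Kaira's implementation: `RunningMeanStd` accumulates plain Python floats (for example, per-sample LPIPS values) with Welford's online algorithm and, as an assumption, reports the sample standard deviation (dividing by n − 1).

```python
import math
from typing import Iterable, Tuple


class RunningMeanStd:
    """Hypothetical accumulator mirroring update(), compute() and reset()."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        # Discard everything seen so far.
        self.count = 0
        self.mean = 0.0
        self._m2 = 0.0  # running sum of squared deviations from the mean

    def update(self, values: Iterable[float]) -> None:
        # Welford's online update, one value at a time.
        for v in values:
            self.count += 1
            delta = v - self.mean
            self.mean += delta / self.count
            self._m2 += delta * (v - self.mean)

    def compute(self) -> Tuple[float, float]:
        if self.count == 0:
            raise RuntimeError("compute() called before any update()")
        if self.count < 2:
            # Assumption: the sample std is undefined for a single value.
            return self.mean, float("nan")
        return self.mean, math.sqrt(self._m2 / (self.count - 1))


stats = RunningMeanStd()
stats.update([0.1, 0.3])  # first batch of per-sample scores
stats.update([0.2])       # second batch
mean, std = stats.compute()
print(round(mean, 6), round(std, 6))  # 0.2 0.1
stats.reset()
```

Welford's update avoids the cancellation error of the naive sum-of-squares formula, and calling `reset()` between evaluation runs keeps statistics from one run from leaking into the next.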

compute_with_stats(x: Tensor, y: Tensor, *args: Any, **kwargs: Any) → Tuple[Tensor, Tensor]

Compute the metric on x and y and return its mean and standard deviation.

Parameters:
  • x (torch.Tensor) – The first input tensor (typically predictions)

  • y (torch.Tensor) – The second input tensor (typically targets)

  • *args – Variable length argument list.

  • **kwargs – Arbitrary keyword arguments.

Returns:

Mean and standard deviation of the metric

Return type:

Tuple[torch.Tensor, torch.Tensor]
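Unlike the streaming update()/compute() pair, compute_with_stats takes the inputs directly and returns both statistics in one call. The pure-Python sketch below illustrates that shape under stated assumptions: `compute_with_stats_sketch` and its `distance` callback are hypothetical, absolute difference stands in for LPIPS, the sample standard deviation (`statistics.stdev`) is assumed, and the sketch keeps no accumulated state.

```python
import statistics
from typing import Callable, Sequence, Tuple, TypeVar

T = TypeVar("T")


def compute_with_stats_sketch(
    xs: Sequence[T],
    ys: Sequence[T],
    distance: Callable[[T, T], float],
) -> Tuple[float, float]:
    """Hypothetical one-shot mean/std of per-pair distances; keeps no state."""
    if len(xs) != len(ys):
        raise ValueError("xs and ys must contain the same number of samples")
    values = [distance(x, y) for x, y in zip(xs, ys)]
    if len(values) < 2:
        raise ValueError("at least two samples are needed for a sample std")
    return statistics.fmean(values), statistics.stdev(values)


# Absolute difference stands in for a real perceptual distance.
mean, std = compute_with_stats_sketch(
    [0.0, 0.0, 0.0], [0.1, 0.2, 0.3], lambda a, b: abs(a - b)
)
print(round(mean, 6), round(std, 6))  # 0.2 0.1
```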