Source code for kaira.metrics.image.lpips

"""Learned Perceptual Image Patch Similarity (LPIPS) metric.

LPIPS is a learned perceptual metric that leverages deep features and better correlates
with human perception than traditional metrics :cite:`zhang2018unreasonable`.
"""

# inspect is used to filter kwargs down to those accepted by torchmetrics' LPIPS constructor
import inspect
from typing import Any, Literal, Tuple

import torch
import torchmetrics
from torch import Tensor
from torchmetrics.functional.image.lpips import _lpips_compute, _lpips_update

from ..base import BaseMetric
from ..registry import MetricRegistry


@MetricRegistry.register_metric("lpips")
class LearnedPerceptualImagePatchSimilarity(BaseMetric):
    """Learned Perceptual Image Patch Similarity (LPIPS) Module.

    LPIPS measures the perceptual similarity between images using deep features. Lower values
    indicate greater perceptual similarity. Unlike traditional metrics like PSNR and SSIM, LPIPS
    uses human perceptual judgments to calibrate a deep feature-based metric
    :cite:`zhang2018unreasonable`.

    Two usage modes are supported:

    * ``forward(x, y)`` returns the batch-level value computed by the wrapped torchmetrics metric.
    * ``update(x, y)`` accumulates per-sample scores; ``compute()`` then returns the mean and
      (population) standard deviation over everything seen since the last ``reset()``.
    """

    def __init__(
        self,
        net_type: Literal["vgg", "alex", "squeeze"] = "alex",
        normalize: bool = False,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Initialize the LPIPS module.

        Args:
            net_type (str): The backbone network to use ('vgg', 'alex', or 'squeeze').
            normalize (bool): Whether to normalize the input images to [-1,1] range.
                If True, the input images should be in the range [0,1].
                If False, the input images should be in the range [-1,1].
            *args: Variable length argument list passed to the base class.
            **kwargs: Arbitrary keyword arguments passed to the base class; those that are also
                explicit parameters of torchmetrics' LPIPS constructor (e.g. ``reduction``) are
                forwarded to it as well.
        """
        # BaseMetric is responsible for the metric's name, so none is passed here.
        super().__init__(*args, **kwargs)
        self.net_type = net_type
        self.normalize = normalize

        # Only forward the kwargs that torchmetrics' constructor declares, so BaseMetric-specific
        # options do not leak into torchmetrics.
        tm_lpips_cls = torchmetrics.image.lpip.LearnedPerceptualImagePatchSimilarity
        accepted = inspect.signature(tm_lpips_cls.__init__).parameters
        torchmetrics_kwargs = {k: v for k, v in kwargs.items() if k in accepted}
        self.lpips = tm_lpips_cls(net_type, normalize=normalize, **torchmetrics_kwargs)

        # Running statistics used by update()/compute(): sum of scores, sum of squared scores,
        # and number of samples seen.
        self.register_buffer("sum_scores", torch.tensor(0.0))
        self.register_buffer("sum_sq", torch.tensor(0.0))
        self.register_buffer("total", torch.tensor(0))

    # Parameters are named x / y (not img1 / img2) to match the BaseMetric interface.
    def forward(self, x: Tensor, y: Tensor, *args: Any, **kwargs: Any) -> Tensor:
        """Calculate LPIPS between two batches of images.

        Args:
            x (Tensor): First batch of images.
            y (Tensor): Second batch of images.
            *args: Variable length argument list (currently unused).
            **kwargs: Arbitrary keyword arguments (currently unused).

        Returns:
            Tensor: The batch-level LPIPS value produced by the wrapped torchmetrics metric
            (reduced over the batch according to its ``reduction``, the mean by default).
            A 0-d result is returned as a 1-element tensor.
        """
        # *args / **kwargs are accepted only for interface consistency with BaseMetric.
        result = self.lpips(x, y)
        return result.unsqueeze(0) if result.dim() == 0 else result

    # Parameters are named x / y (not img1 / img2) to match the BaseMetric interface.
    def update(self, x: Tensor, y: Tensor, *args: Any, **kwargs: Any) -> None:
        """Update the internal state with a batch of samples.

        Args:
            x (Tensor): First batch of images.
            y (Tensor): Second batch of images.
            *args: Variable length argument list (currently unused).
            **kwargs: Arbitrary keyword arguments (currently unused).
        """
        # *args / **kwargs are accepted only for interface consistency with BaseMetric.
        loss, total = _lpips_update(x, y, net=self.lpips.net, normalize=self.normalize)
        # Detach so the accumulated buffers never retain an autograd graph across batches
        # (which would grow memory without bound when inputs require grad).
        loss = loss.detach()
        self.sum_scores += loss.sum()
        self.sum_sq += (loss**2).sum()
        self.total += total

    def compute(self) -> Tuple[Tensor, Tensor]:
        """Compute the accumulated LPIPS statistics.

        Returns:
            Tuple[Tensor, Tensor]: Mean and (population) standard deviation of the LPIPS values
            accumulated by ``update``. Both are NaN if no samples have been accumulated.
        """
        mean = _lpips_compute(self.sum_scores, self.total, "mean")
        # E[x^2] - E[x]^2 can come out marginally negative through floating-point rounding
        # (e.g. a single sample or identical scores); clamp so sqrt does not return NaN.
        # clamp_min propagates NaN, so the empty case still yields NaN.
        variance = (self.sum_sq / self.total - mean**2).clamp_min(0.0)
        std = torch.sqrt(variance)
        return mean, std

    def reset(self) -> None:
        """Reset accumulated statistics, including the wrapped torchmetrics metric's state."""
        self.sum_scores.zero_()
        self.sum_sq.zero_()
        self.total.zero_()
        # forward() goes through the torchmetrics Metric, which keeps its own running state.
        self.lpips.reset()
# Alias for backward compatibility
LPIPS = LearnedPerceptualImagePatchSimilarity