Source code for kaira.losses.image

"""Losses module for Kaira.

This module contains various loss functions for training communication systems, including MSE loss,
LPIPS loss, and SSIM loss. These loss functions are widely used in image processing and
computer vision tasks :cite:`wang2009mean` :cite:`zhang2018unreasonable`.
"""

from typing import Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models  # type: ignore

from kaira.metrics.image import (
    LearnedPerceptualImagePatchSimilarity,
    MultiScaleSSIM,
    StructuralSimilarityIndexMeasure,
)

from .base import BaseLoss
from .registry import LossRegistry


@LossRegistry.register_loss()
class MSELoss(BaseLoss):
    """Mean squared error (MSE) loss.

    Wraps :class:`torch.nn.MSELoss`, constructed with its default arguments, and registers it
    with ``LossRegistry``. MSE is the most widely used objective for regression and image
    restoration :cite:`wang2009mean`.
    """

    def __init__(self):
        """Build the underlying ``nn.MSELoss`` criterion."""
        super().__init__()
        self.mse = nn.MSELoss()

    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Evaluate the MSE between ``x`` and ``target``.

        Args:
            x (torch.Tensor): Predicted tensor.
            target (torch.Tensor): Reference tensor.

        Returns:
            torch.Tensor: Output of the wrapped ``nn.MSELoss`` for the pair.
        """
        criterion = self.mse
        return criterion(x, target)
@LossRegistry.register_loss()
class CombinedLoss(BaseLoss):
    """Weighted sum of several loss functions.

    Mixing complementary objectives is a common way to improve image quality, since each term
    addresses a different aspect of visual perception :cite:`zhao2016loss`.
    """

    def __init__(self, losses: Sequence[BaseLoss], weights: list[float]):
        """Store the component losses and their weights.

        Args:
            losses (Sequence[BaseLoss]): Loss modules to combine.
            weights (list[float]): Weight for each entry of ``losses``, looked up by position.
        """
        super().__init__()
        self.losses = losses
        self.weights = weights

    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Accumulate the weighted component losses.

        Args:
            x (torch.Tensor): Predicted tensor.
            target (torch.Tensor): Reference tensor.

        Returns:
            torch.Tensor: Scalar tensor holding the weighted sum of all component losses.
        """
        # Scalar accumulator created on the same device as the prediction.
        total = torch.tensor(0.0, device=x.device)

        for index, component in enumerate(self.losses):
            term = self.weights[index] * component(x, target)

            if not isinstance(term, torch.Tensor):
                # A component returned a plain Python number; lift it to a tensor.
                term = torch.tensor(term, device=x.device)
            elif term.ndim > 0:
                # A component returned an unreduced loss; average it down to a scalar.
                term = term.mean()

            total = total + term

        return total
@LossRegistry.register_loss()
class MSELPIPSLoss(BaseLoss):
    """Weighted combination of MSE and LPIPS.

    The MSE term rewards pixel-wise accuracy while the LPIPS term rewards perceptual quality
    :cite:`zhang2018unreasonable`.
    """

    def __init__(self, mse_weight=1.0, lpips_weight=1.0):
        """Create both component losses and remember their weights.

        Args:
            mse_weight (float): Multiplier applied to the MSE term.
            lpips_weight (float): Multiplier applied to the LPIPS term.
        """
        super().__init__()
        self.mse_loss = MSELoss()
        self.lpips_loss = LPIPSLoss()
        self.mse_weight = mse_weight
        self.lpips_weight = lpips_weight

    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Evaluate the weighted MSE + LPIPS objective.

        Args:
            x (torch.Tensor): Predicted tensor.
            target (torch.Tensor): Reference tensor.

        Returns:
            torch.Tensor: ``mse_weight * MSE + lpips_weight * LPIPS`` for the pair.
        """
        pixel_term = self.mse_loss(x, target)
        perceptual_term = self.lpips_loss(x, target)
        return self.mse_weight * pixel_term + self.lpips_weight * perceptual_term
@LossRegistry.register_loss()
class LPIPSLoss(BaseLoss):
    """Learned Perceptual Image Patch Similarity (LPIPS) loss.

    LPIPS compares images in a deep feature space, which tracks human judgement more closely
    than pixel-based metrics :cite:`zhang2018unreasonable`. The computation is delegated to
    ``LearnedPerceptualImagePatchSimilarity`` from ``kaira.metrics.image``.
    """

    def __init__(self):
        """Instantiate the LPIPS metric with its default configuration."""
        super().__init__()
        self.lpips = LearnedPerceptualImagePatchSimilarity()

    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Evaluate LPIPS between ``x`` and ``target``.

        Args:
            x (torch.Tensor): Predicted tensor.
            target (torch.Tensor): Reference tensor.

        Returns:
            torch.Tensor: Value returned by the wrapped LPIPS metric.
        """
        metric = self.lpips
        return metric(x, target)
@LossRegistry.register_loss()
class SSIMLoss(BaseLoss):
    """Structural Similarity Index Measure (SSIM) loss.

    SSIM scores similarity in terms of luminance, contrast and structure, which matches human
    perception better than pixel differences :cite:`wang2004image`. The loss is ``1 - SSIM``.
    """

    def __init__(self, kernel_size: int = 11, data_range: float = 1.0):
        """Configure the underlying SSIM metric.

        Args:
            kernel_size (int): Size of the Gaussian kernel used in the SSIM calculation.
            data_range (float): Range of the input data (typically 1.0 or 255).
        """
        super().__init__()
        self.ssim = StructuralSimilarityIndexMeasure(data_range=data_range, kernel_size=kernel_size)

    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Evaluate ``1 - SSIM`` between ``x`` and ``target``.

        Args:
            x (torch.Tensor): Predicted tensor.
            target (torch.Tensor): Reference tensor.

        Returns:
            torch.Tensor: One minus the SSIM of the clamped inputs.
        """
        # Both inputs are limited to [-1, 1] before scoring.
        # NOTE(review): these bounds are fixed and do not follow ``data_range``; confirm this is
        # intended for inputs on a 0-255 scale.
        lower, upper = -1.0, 1.0
        similarity = self.ssim(x.clamp(lower, upper), target.clamp(lower, upper))

        # SSIM grows with similarity, so invert it to obtain a quantity to minimise.
        return 1 - similarity
@LossRegistry.register_loss()
class MSSSIMLoss(BaseLoss):
    """Multi-Scale Structural Similarity (MS-SSIM) loss.

    MS-SSIM evaluates SSIM at several scales, which makes it more robust to changes in viewing
    distance :cite:`wang2003multiscale`. The loss is ``1 - MS-SSIM``.
    """

    def __init__(self, kernel_size: int = 11, data_range: float = 1.0):
        """Configure the underlying MS-SSIM metric.

        Args:
            kernel_size (int): Size of the Gaussian kernel used in the SSIM calculation.
            data_range (float): Range of the input data (typically 1.0 or 255).
        """
        super().__init__()
        self.ms_ssim = MultiScaleSSIM(kernel_size=kernel_size, data_range=data_range)

    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Evaluate ``1 - MS-SSIM`` between ``x`` and ``target``.

        Args:
            x (torch.Tensor): Predicted tensor.
            target (torch.Tensor): Reference tensor.

        Returns:
            torch.Tensor: One minus the MS-SSIM of the clamped inputs.
        """
        # Both inputs are limited to [-1, 1] before scoring.
        # NOTE(review): these bounds are fixed and do not follow ``data_range``; confirm this is
        # intended for inputs on a 0-255 scale.
        lower, upper = -1.0, 1.0
        similarity = self.ms_ssim(x.clamp(lower, upper), target.clamp(lower, upper))

        # MS-SSIM grows with similarity, so invert it to obtain a quantity to minimise.
        return 1 - similarity
@LossRegistry.register_loss()
class L1Loss(BaseLoss):
    """L1 (mean absolute error) loss.

    Wraps :class:`torch.nn.L1Loss`, constructed with its default arguments. L1 is often preferred
    over MSE for image restoration because it preserves edges better and is less sensitive to
    outliers :cite:`zhao2016loss`.
    """

    def __init__(self):
        """Build the underlying ``nn.L1Loss`` criterion."""
        super().__init__()
        self.l1 = nn.L1Loss()

    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Evaluate the L1 loss between ``x`` and ``target``.

        Args:
            x (torch.Tensor): Predicted tensor.
            target (torch.Tensor): Reference tensor.

        Returns:
            torch.Tensor: Output of the wrapped ``nn.L1Loss`` for the pair.
        """
        criterion = self.l1
        return criterion(x, target)
@LossRegistry.register_loss()
class VGGLoss(BaseLoss):
    """VGG perceptual loss.

    Compares the prediction and the target in the feature space of a pretrained VGG16 network
    rather than in pixel space, which captures perceptual differences better
    :cite:`johnson2016perceptual` :cite:`dosovitskiy2016generating`.
    """

    def __init__(self, layer_weights=None):
        """Load the frozen VGG16 feature extractor and set the per-layer weights.

        Args:
            layer_weights (dict, optional): Mapping from layer name to weight. Defaults to
                ``{'conv1_2': 0.1, 'conv2_2': 0.2, 'conv3_3': 0.4, 'conv4_3': 0.3}``.
        """
        super().__init__()

        if layer_weights is None:
            self.layer_weights = {"conv1_2": 0.1, "conv2_2": 0.2, "conv3_3": 0.4, "conv4_3": 0.3}
        else:
            self.layer_weights = layer_weights

        # Uses the ``weights`` argument instead of the deprecated ``pretrained`` flag.
        self.vgg = models.vgg16(weights=models.VGG16_Weights.DEFAULT).features.eval()

        # The extractor is a fixed feature network: never train it.
        for param in self.vgg.parameters():
            param.requires_grad = False

        # For test compatibility - handle direct access to _params.
        if hasattr(self.vgg, "_params"):
            for param in self.vgg._params:
                param.requires_grad = False

        # Keys are the names of the sub-modules inside ``self.vgg`` whose output is compared.
        self.layer_name_mapping = {
            "3": "conv1_2",
            "8": "conv2_2",
            "15": "conv3_3",
            "22": "conv4_3",
        }

    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the weighted feature-space MSE between ``x`` and ``target``.

        Args:
            x (torch.Tensor): Predicted images; normalised here with 3-channel mean/std
                statistics broadcast as ``(1, 3, 1, 1)``.
            target (torch.Tensor): Reference images, normalised the same way.

        Returns:
            torch.Tensor: Scalar tensor with the weighted sum of per-layer feature MSEs. It is a
            zero tensor when none of the mapped layers has a weight.
        """
        # Build the normalisation statistics directly on the input's device and dtype so that
        # no extra transfer or dtype promotion happens on every call.
        stat_dtype = x.dtype if x.is_floating_point() else torch.float32
        mean = torch.tensor([0.485, 0.456, 0.406], device=x.device, dtype=stat_dtype).view(1, 3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225], device=x.device, dtype=stat_dtype).view(1, 3, 1, 1)

        x = (x - mean) / std
        target = (target - mean) / std

        loss = torch.zeros((), device=x.device, dtype=stat_dtype)

        # Sub-modules that actually contribute to the loss; once all of them have been visited
        # the remaining (deeper) layers do not need to be evaluated.
        pending = {
            module_name
            for module_name, layer_name in self.layer_name_mapping.items()
            if layer_name in self.layer_weights
        }
        if not pending:
            return loss

        for name, module in self.vgg._modules.items():
            x = module(x)
            target = module(target)

            if name in pending:
                layer_name = self.layer_name_mapping[name]
                loss = loss + self.layer_weights[layer_name] * F.mse_loss(x, target)
                pending.discard(name)
                if not pending:
                    break

        return loss
@LossRegistry.register_loss()
class TotalVariationLoss(BaseLoss):
    """Total variation (TV) loss.

    Penalises squared differences between neighbouring pixels to encourage spatial smoothness.
    TV regularisation reduces noise while preserving edges :cite:`rudin1992nonlinear`
    :cite:`mahendran2015understanding`.
    """

    def __init__(self):
        """Initialise the module; it holds no parameters of its own."""
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute the total variation of ``x``.

        Args:
            x (torch.Tensor): Input tensor of shape (B, C, H, W).

        Returns:
            torch.Tensor: Sum of squared neighbour differences along H and W, divided by the
            batch size.
        """
        # Differences between each row and the one above it, and each column and the one to
        # its left.
        vertical_diff = x[:, :, 1:, :] - x[:, :, :-1, :]
        horizontal_diff = x[:, :, :, 1:] - x[:, :, :, :-1]

        vertical_energy = torch.pow(vertical_diff, 2).sum()
        horizontal_energy = torch.pow(horizontal_diff, 2).sum()

        return (vertical_energy + horizontal_energy) / x.size(0)
@LossRegistry.register_loss()
class GradientLoss(BaseLoss):
    """Gradient (edge-preserving) loss.

    Penalises L1 differences between the Sobel gradients of the prediction and of the target,
    which helps to keep structural information and edges :cite:`mathieu2015deep`.
    """

    def __init__(self):
        """Create the 3x3 Sobel kernels, each shaped ``(1, 1, 3, 3)``."""
        super().__init__()
        self.sobel_x = (
            torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32).unsqueeze(0).unsqueeze(0)
        )
        self.sobel_y = (
            torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32).unsqueeze(0).unsqueeze(0)
        )

    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the gradient loss between ``x`` and ``target``.

        Every channel is filtered independently with the horizontal and the vertical Sobel
        kernel; the result is the channel-averaged sum of the two L1 gradient differences.

        Args:
            x (torch.Tensor): Predicted tensor of shape (B, C, H, W).
            target (torch.Tensor): Reference tensor, laid out like ``x``.

        Returns:
            torch.Tensor: Scalar gradient loss.
        """
        # Keep the cached kernels on the device of the most recent input (as before).
        device = x.device
        self.sobel_x = self.sobel_x.to(device)
        self.sobel_y = self.sobel_y.to(device)

        _, channels, _, _ = x.size()

        # conv2d requires the kernel and the input to share a dtype; the stored kernels are
        # float32, so cast a local copy instead of failing on e.g. float64/float16 inputs.
        kernel_dtype = x.dtype if x.is_floating_point() else self.sobel_x.dtype

        loss = torch.zeros((), device=device, dtype=kernel_dtype)
        for kernel in (self.sobel_x, self.sobel_y):
            # One grouped convolution filters all channels independently in a single call.
            # With equally sized channels, the mean over the whole output equals the average
            # of the per-channel means that a channel-by-channel loop would compute.
            grouped = kernel.to(kernel_dtype).expand(channels, 1, 3, 3)
            x_grad = F.conv2d(x, grouped, padding=1, groups=channels)
            target_grad = F.conv2d(target, grouped, padding=1, groups=channels)
            loss = loss + F.l1_loss(x_grad, target_grad)

        return loss
@LossRegistry.register_loss()
class PSNRLoss(BaseLoss):
    """Peak Signal-to-Noise Ratio (PSNR) loss.

    Returns the negative PSNR so that minimising the loss maximises PSNR. PSNR is a standard
    image quality metric :cite:`hore2010image`, although it does not always correlate well with
    human perception :cite:`huynh2008scope`.
    """

    def __init__(self, max_val=1.0, eps=1e-12):
        """Initialize the PSNRLoss module.

        Args:
            max_val (float): Maximum value of the input tensor. Default is 1.0.
            eps (float): Lower bound applied to the MSE before the logarithm. It keeps the loss
                and its gradient finite when the inputs are identical. Default is 1e-12.
        """
        super().__init__()
        self.max_val = max_val
        self.eps = eps

    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the negative PSNR between ``x`` and ``target``.

        Args:
            x (torch.Tensor): Predicted tensor.
            target (torch.Tensor): Reference tensor.

        Returns:
            torch.Tensor: ``-20 * log10(max_val / sqrt(mse))`` with ``mse`` bounded below by
            ``eps``.
        """
        # Without the lower bound an exact match gives mse == 0, i.e. a division by zero, a
        # loss of -inf and NaN gradients through sqrt/log10.
        mse = torch.clamp(F.mse_loss(x, target), min=self.eps)
        psnr = 20 * torch.log10(self.max_val / torch.sqrt(mse))

        # Negate: a higher PSNR is better, but losses are minimised.
        return -psnr
@LossRegistry.register_loss()
class StyleLoss(BaseLoss):
    """Style loss based on Gram matrices.

    Used in neural style transfer: the loss compares Gram matrices of feature maps, which
    capture texture information independently of spatial arrangement :cite:`gatys2016image`.
    """

    def __init__(self, apply_gram=True, normalize=False, layer_weights=None):
        """Initialize the StyleLoss module.

        Args:
            apply_gram (bool): Whether to compute Gram matrices from image inputs (True) or to
                treat the inputs as precomputed Gram matrices (False). Default is True.
            normalize (bool): Whether to normalize the Gram matrices. Default is False.
            layer_weights (dict, optional): Weights keyed ``"layer_<k>"``, where ``k`` is the
                position of the layer in ``style_layers``. Default is None (weight 1.0 for
                every style layer).
        """
        super().__init__()

        self.apply_gram = apply_gram
        self.normalize = normalize

        try:
            vgg = models.vgg16(weights=models.VGG16_Weights.DEFAULT).features.eval()
            self.feature_extractor = nn.Sequential()
            # Positions inside ``feature_extractor`` whose outputs enter the loss.
            self.style_layers = [0, 5, 10, 17, 24]
            self.layer_weights = layer_weights or {f"layer_{i}": 1.0 for i in range(len(self.style_layers))}

            i = 0
            for layer in vgg.children():
                # Name each layer after its type and the index of the latest convolution.
                if isinstance(layer, nn.Conv2d):
                    i += 1
                    name = f"conv_{i}"
                elif isinstance(layer, nn.ReLU):
                    name = f"relu_{i}"
                    # Use an out-of-place ReLU so earlier feature maps are not overwritten.
                    layer = nn.ReLU(inplace=False)
                elif isinstance(layer, nn.MaxPool2d):
                    name = f"pool_{i}"
                elif isinstance(layer, nn.BatchNorm2d):
                    name = f"bn_{i}"
                else:
                    # Generic name for unrecognized layers
                    name = f"unknown_{i}"

                self.feature_extractor.add_module(name, layer)

            # Freeze parameters
            for param in self.feature_extractor.parameters():
                param.requires_grad = False

        except Exception as exc:
            # Graceful degradation is kept, but it must not be silent: with an empty extractor
            # the image-based loss is always zero.
            import warnings

            warnings.warn(
                f"StyleLoss could not build the VGG16 feature extractor ({exc!r}); "
                "falling back to an empty extractor, so the Gram-based loss will be zero.",
                RuntimeWarning,
                stacklevel=2,
            )
            self.feature_extractor = nn.Sequential()
            self.style_layers = [0]
            self.layer_weights = layer_weights or {"layer_0": 1.0}

    def gram_matrix(self, x):
        """Calculate the Gram matrix of a feature tensor.

        Args:
            x (torch.Tensor): Feature tensor of shape (B, C, H, W).

        Returns:
            torch.Tensor: Gram matrices of shape (B, C, C), divided by ``C * H * W`` when
            ``normalize`` is set.
        """
        batch_size, channels, height, width = x.size()

        # reshape() copies only when the tensor is not contiguous.
        features = x.reshape(batch_size, channels, height * width)
        gram = features.bmm(features.transpose(1, 2))

        if self.normalize:
            gram = gram / (channels * height * width)

        return gram

    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the style loss between ``x`` and ``target``.

        Args:
            x (torch.Tensor): Predicted images of shape (B, 3, H, W), or precomputed Gram
                matrices when ``apply_gram`` is False.
            target (torch.Tensor): Reference images or Gram matrices, laid out like ``x``.

        Returns:
            torch.Tensor: Scalar style loss.

        Raises:
            ValueError: If ``apply_gram`` is True and an input is not 4D or does not have
                3 channels.
        """
        # Precomputed Gram matrices are compared directly.
        if not self.apply_gram:
            return F.mse_loss(x, target)

        if x.dim() != 4 or target.dim() != 4:
            raise ValueError("Input tensors must be 4D (batch, channels, height, width)")

        if x.size(1) != 3 or target.size(1) != 3:
            raise ValueError("Input tensors must have 3 channels (RGB)")

        # Build the normalisation statistics directly on the input's device and dtype.
        stat_dtype = x.dtype if x.is_floating_point() else torch.float32
        mean = torch.tensor([0.485, 0.456, 0.406], device=x.device, dtype=stat_dtype).view(1, 3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225], device=x.device, dtype=stat_dtype).view(1, 3, 1, 1)

        x = (x - mean) / std
        target = (target - mean) / std

        loss = torch.zeros((), device=x.device, dtype=stat_dtype)

        # Map extractor position -> rank within ``style_layers`` (first occurrence wins, like
        # ``list.index``).
        ranks = {}
        for rank, position in enumerate(self.style_layers):
            ranks.setdefault(position, rank)
        deepest = max(self.style_layers, default=-1)

        for i, layer in enumerate(self.feature_extractor):
            # Layers after the deepest style layer cannot contribute; skip evaluating them.
            if i > deepest:
                break

            x = layer(x)
            target = layer(target)

            if i in ranks:
                weight = self.layer_weights.get(f"layer_{ranks[i]}", 1.0)
                loss = loss + weight * F.mse_loss(self.gram_matrix(x), self.gram_matrix(target))

        return loss
@LossRegistry.register_loss()
class FocalLoss(BaseLoss):
    """Focal loss for class-imbalanced classification.

    Works for both binary and multi-class problems. Focal loss down-weights well-classified
    examples so that training focuses on the difficult ones :cite:`lin2017focal`.
    """

    def __init__(self, alpha=None, gamma=2.0, reduction="mean"):
        """Initialize the FocalLoss module.

        Args:
            alpha (float or tensor): Class weighting. In the binary case a scalar weight for the
                positive class (the negative class gets ``1 - alpha``). In the multi-class case
                either a scalar applied to every sample or a 1-D tensor of per-class weights.
                Default is None (no weighting).
            gamma (float): Focusing parameter. Default is 2.0.
            reduction (str): ``'mean'``, ``'sum'``; any other value returns the unreduced loss.
                Default is 'mean'.
        """
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Compute the focal loss.

        Args:
            inputs (torch.Tensor): Logits. A 1-D tensor or a tensor whose dimension 1 has size 1
                is treated as binary; otherwise dimension 1 indexes the classes.
            targets (torch.Tensor): Binary labels (binary case) or class indices (multi-class).

        Returns:
            torch.Tensor: The reduced focal loss, or the per-element loss when ``reduction`` is
            neither ``'mean'`` nor ``'sum'``.
        """
        # Check the rank first: indexing shape[1] on a 1-D tensor would raise IndexError.
        if inputs.dim() <= 1 or inputs.shape[1] == 1:
            # Binary classification
            inputs = inputs.reshape(-1)
            targets = targets.reshape(-1).to(inputs.dtype)

            bce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
            pt = torch.exp(-bce_loss)  # probability assigned to the true label
            focal_loss = (1 - pt) ** self.gamma * bce_loss

            if self.alpha is not None:
                alpha = self.alpha
                if isinstance(alpha, torch.Tensor):
                    alpha = alpha.to(inputs.device)
                # alpha for positives, (1 - alpha) for negatives, applied as a multiplier to
                # the whole focal term.
                alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
                focal_loss = alpha_t * focal_loss
        else:
            # Multi-class classification
            log_probs = F.log_softmax(inputs, dim=1)
            ce_loss = F.nll_loss(log_probs, targets, reduction="none")
            pt = torch.exp(-ce_loss)
            focal_loss = (1 - pt) ** self.gamma * ce_loss

            if self.alpha is not None:
                # Accept floats and sequences as well as tensors.
                alpha = torch.as_tensor(self.alpha, dtype=inputs.dtype, device=inputs.device)
                # A scalar weights every sample equally; a vector is looked up per class.
                alpha_t = alpha if alpha.dim() == 0 else alpha[targets]
                focal_loss = alpha_t * focal_loss

        if self.reduction == "mean":
            return focal_loss.mean()
        if self.reduction == "sum":
            return focal_loss.sum()
        return focal_loss  # 'none'
@LossRegistry.register_loss()
class ElasticLoss(BaseLoss):
    """Elastic loss combining L1 and L2 penalties.

    Elastic net regularization combines the benefits of both L1 and L2 penalties, offering
    robustness to outliers while maintaining smoothness :cite:`zou2005regularization`.

    With ``d = x - target`` the per-element loss is:

    * ``|d|`` when ``alpha >= 0.99``;
    * ``d ** 2`` when ``alpha <= 0.01``;
    * otherwise ``alpha * 0.5 * d ** 2 / beta`` where ``|d| < beta`` and
      ``(1 - alpha) * |d| + alpha * beta / 2`` elsewhere.

    NOTE(review): in the mixed case the two pieces differ by ``(1 - alpha) * beta`` at
    ``|d| == beta``, so the transition is not continuous; confirm this is intended.
    """

    def __init__(self, beta=1.0, alpha=0.5, reduction="mean"):
        """Initialize the ElasticLoss module.

        Args:
            beta (float): Threshold on ``|x - target|`` separating the quadratic and the linear
                regime; values below 1e-8 are raised to 1e-8. Default is 1.0.
            alpha (float): Mixing weight selecting the formula described on the class.
                Default is 0.5.
            reduction (str): ``'mean'``, ``'sum'``; any other value returns the unreduced loss.
                Default is 'mean'.
        """
        super().__init__()
        self.beta = max(beta, 1e-8)  # Prevent division by zero
        self.alpha = alpha
        self.reduction = reduction

    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the elastic loss between ``x`` and ``target``.

        Args:
            x (torch.Tensor): Predicted tensor.
            target (torch.Tensor): Reference tensor.

        Returns:
            torch.Tensor: The reduced loss, or the per-element loss when ``reduction`` is
            neither ``'mean'`` nor ``'sum'``.
        """
        diff = x - target
        abs_diff = torch.abs(diff)
        squared_diff = diff**2

        if self.alpha >= 0.99:
            # Close to 1.0: pure L1.
            point_losses = abs_diff
        elif self.alpha <= 0.01:
            # Close to 0.0: pure L2, without a 0.5 factor so that it matches standard MSE.
            point_losses = squared_diff
        else:
            l1_component = abs_diff
            l2_component = 0.5 * squared_diff / self.beta

            # Quadratic penalty for small residuals, linear penalty (plus offset) for large.
            point_losses = torch.where(
                abs_diff < self.beta,
                self.alpha * l2_component,
                (1.0 - self.alpha) * l1_component + self.alpha * self.beta / 2.0,
            )

        if self.reduction == "mean":
            return point_losses.mean()
        if self.reduction == "sum":
            return point_losses.sum()
        return point_losses  # 'none'


# ElasticLoss is registered and public like the other losses, so it is exported as well.
__all__ = [
    "MSELoss",
    "CombinedLoss",
    "MSELPIPSLoss",
    "LPIPSLoss",
    "SSIMLoss",
    "MSSSIMLoss",
    "L1Loss",
    "VGGLoss",
    "TotalVariationLoss",
    "GradientLoss",
    "PSNRLoss",
    "StyleLoss",
    "FocalLoss",
    "ElasticLoss",
]