"""Losses module for Kaira.
This module contains various loss functions for training communication systems, including MSE loss,
LPIPS loss, and SSIM loss. These loss functions are widely used in image processing and
computer vision tasks :cite:`wang2009mean` :cite:`zhang2018unreasonable`.
"""
from typing import Sequence
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models # type: ignore
from kaira.metrics.image import (
LearnedPerceptualImagePatchSimilarity,
MultiScaleSSIM,
StructuralSimilarityIndexMeasure,
)
from .base import BaseLoss
from .registry import LossRegistry
[docs]
@LossRegistry.register_loss()
class MSELoss(BaseLoss):
    """Mean squared error (MSE) criterion.

    Thin wrapper around :class:`torch.nn.MSELoss` built with its default
    arguments. MSE is the most widely used loss for regression and image
    restoration :cite:`wang2009mean`.
    """

    def __init__(self):
        """Create the wrapped ``nn.MSELoss`` criterion."""
        super().__init__()
        self.mse = nn.MSELoss()

    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Evaluate the MSE criterion on a prediction/target pair.

        Args:
            x (torch.Tensor): The input tensor.
            target (torch.Tensor): The target tensor.

        Returns:
            torch.Tensor: Output of the wrapped ``nn.MSELoss`` for ``(x, target)``.
        """
        criterion = self.mse
        return criterion(x, target)
[docs]
@LossRegistry.register_loss()
class CombinedLoss(BaseLoss):
    """Weighted sum of several loss functions.

    Each component loss is evaluated on the same ``(x, target)`` pair, scaled
    by the weight at the same position, reduced to a scalar and accumulated.
    Combining losses is a common way to address different aspects of visual
    quality at once :cite:`zhao2016loss`.

    NOTE(review): ``losses`` is stored exactly as passed in (it is not wrapped
    in a container such as ``nn.ModuleList``) and the lengths of ``losses`` and
    ``weights`` are not checked here -- confirm both are intended.
    """

    def __init__(self, losses: Sequence[BaseLoss], weights: list[float]):
        """Store the component losses and their weights.

        Args:
            losses (Sequence[BaseLoss]): Loss functions to combine.
            weights (list[float]): One weight per loss, matched by position.
        """
        super().__init__()
        self.losses = losses
        self.weights = weights

    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Evaluate every component loss and return their weighted sum.

        Args:
            x (torch.Tensor): The input tensor.
            target (torch.Tensor): The target tensor.

        Returns:
            torch.Tensor: Scalar tensor holding the accumulated weighted loss.
        """
        # Accumulator lives on the input's device so the additions below do
        # not mix devices.
        total = torch.tensor(0.0, device=x.device)
        for index, criterion in enumerate(self.losses):
            value = criterion(x, target)
            term = self.weights[index] * value
            if not isinstance(term, torch.Tensor):
                # Component returned a plain Python number.
                term = torch.tensor(term, device=x.device)
            elif term.ndim > 0:
                # Component returned an unreduced tensor: average it.
                term = term.mean()
            total = total + term
        return total
[docs]
@LossRegistry.register_loss()
class MSELPIPSLoss(BaseLoss):
"""MSELPIPSLoss Module.
This module combines MSE and LPIPS losses with configurable weights.
This combination balances pixel-wise accuracy (MSE) with perceptual quality (LPIPS)
:cite:`zhang2018unreasonable`.
"""
[docs]
def __init__(self, mse_weight=1.0, lpips_weight=1.0):
"""Initialize the MSELPIPSLoss module.
Args:
mse_weight (float): Weight for the MSE loss component.
lpips_weight (float): Weight for the LPIPS loss component.
"""
super().__init__()
self.mse_loss = MSELoss()
self.lpips_loss = LPIPSLoss()
self.mse_weight = mse_weight
self.lpips_weight = lpips_weight
[docs]
def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""Forward pass through the MSELPIPSLoss module.
Args:
x (torch.Tensor): The input tensor.
target (torch.Tensor): The target tensor.
Returns:
torch.Tensor: The combined MSE and LPIPS loss between the input and target.
"""
mse = self.mse_loss(x, target)
lpips = self.lpips_loss(x, target)
return self.mse_weight * mse + self.lpips_weight * lpips
[docs]
@LossRegistry.register_loss()
class LPIPSLoss(BaseLoss):
"""Learned Perceptual Image Patch Similarity (LPIPS) Loss Module.
This module calculates the LPIPS loss between the input and the target.
LPIPS uses deep features to measure perceptual similarity between images,
which correlates better with human judgment than pixel-based metrics
:cite:`zhang2018unreasonable`.
"""
[docs]
def __init__(self):
"""Initialize the LPIPSLoss module."""
super().__init__()
self.lpips = LearnedPerceptualImagePatchSimilarity()
[docs]
def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""Forward pass through the LPIPSLoss module.
Args:
x (torch.Tensor): The input tensor.
target (torch.Tensor): The target tensor.
Returns:
torch.Tensor: The LPIPS loss between the input and the target.
"""
return self.lpips(x, target)
[docs]
@LossRegistry.register_loss()
class SSIMLoss(BaseLoss):
"""Structural Similarity Index Measure (SSIM) Loss Module.
This module calculates the SSIM loss between the input and the target.
SSIM evaluates image similarity based on luminance, contrast, and structure,
better matching human visual perception :cite:`wang2004image`.
"""
[docs]
def __init__(self, kernel_size: int = 11, data_range: float = 1.0):
"""Initialize the SSIMLoss module.
Args:
kernel_size (int): Size of the Gaussian kernel used in SSIM calculation.
data_range (float): Range of the input data (typically 1.0 or 255).
"""
super().__init__()
self.ssim = StructuralSimilarityIndexMeasure(data_range=data_range, kernel_size=kernel_size)
[docs]
def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""Forward pass through the SSIMLoss module.
Args:
x (torch.Tensor): The input tensor.
target (torch.Tensor): The target tensor.
Returns:
torch.Tensor: The SSIM loss between the input and the target.
"""
# Normalize input data to range [-1, 1] if necessary
x_norm = torch.clamp(x, -1.0, 1.0)
target_norm = torch.clamp(target, -1.0, 1.0)
# 1 - SSIM because higher SSIM means better similarity (we want to minimize loss)
return 1 - self.ssim(x_norm, target_norm)
[docs]
@LossRegistry.register_loss()
class MSSSIMLoss(BaseLoss):
"""Multi-Scale Structural Similarity Index Measure (MS-SSIM) Loss Module.
This module calculates the MS-SSIM loss between the input and the target.
MS-SSIM extends SSIM by evaluating similarity at multiple scales, making it
more robust to viewing distance variations :cite:`wang2003multiscale`.
"""
[docs]
def __init__(self, kernel_size: int = 11, data_range: float = 1.0):
"""Initialize the MSSSIMLoss module.
Args:
kernel_size (int): Size of the Gaussian kernel used in SSIM calculation.
data_range (float): Range of the input data (typically 1.0 or 255).
"""
super().__init__()
self.ms_ssim = MultiScaleSSIM(kernel_size=kernel_size, data_range=data_range)
[docs]
def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""Forward pass through the MSSSIMLoss module.
Args:
x (torch.Tensor): The input tensor.
target (torch.Tensor): The target tensor.
Returns:
torch.Tensor: The MS-SSIM loss between the input and the target.
"""
# Normalize input data to range [-1, 1] if necessary
x_norm = torch.clamp(x, -1.0, 1.0)
target_norm = torch.clamp(target, -1.0, 1.0)
# 1 - MS-SSIM because higher MS-SSIM means better similarity (we want to minimize loss)
return 1 - self.ms_ssim(x_norm, target_norm)
[docs]
@LossRegistry.register_loss()
class L1Loss(BaseLoss):
"""L1 (Mean Absolute Error) Loss Module.
This module calculates the L1 loss between the input and the target.
L1 loss is often preferred over MSE for image restoration tasks as it
preserves edges better and is more robust to outliers :cite:`zhao2016loss`.
"""
[docs]
def __init__(self):
"""Initialize the L1Loss module."""
super().__init__()
self.l1 = nn.L1Loss()
[docs]
def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""Forward pass through the L1Loss module.
Args:
x (torch.Tensor): The input tensor.
target (torch.Tensor): The target tensor.
Returns:
torch.Tensor: The L1 loss between the input and the target.
"""
return self.l1(x, target)
[docs]
@LossRegistry.register_loss()
class VGGLoss(BaseLoss):
"""VGG Perceptual Loss Module.
This module calculates the perceptual loss using features extracted by the VGG network.
VGG loss measures similarity in feature space rather than pixel space, capturing
perceptual differences better :cite:`johnson2016perceptual` :cite:`dosovitskiy2016generating`.
"""
[docs]
def __init__(self, layer_weights=None):
"""Initialize the VGGLoss module.
Args:
layer_weights (dict, optional): Weights for different VGG layers.
Default is {'conv1_2': 0.1, 'conv2_2': 0.2, 'conv3_3': 0.4, 'conv4_3': 0.3}
"""
super().__init__()
if layer_weights is None:
self.layer_weights = {"conv1_2": 0.1, "conv2_2": 0.2, "conv3_3": 0.4, "conv4_3": 0.3}
else:
self.layer_weights = layer_weights
# Updated to use weights parameter instead of deprecated pretrained
self.vgg = models.vgg16(weights=models.VGG16_Weights.DEFAULT).features.eval()
# Freeze VGG parameters - standard way
for param in self.vgg.parameters():
param.requires_grad = False
# For test compatibility - handle direct access to _params
if hasattr(self.vgg, "_params"):
for param in self.vgg._params:
param.requires_grad = False
self.layer_name_mapping = {
"3": "conv1_2",
"8": "conv2_2",
"15": "conv3_3",
"22": "conv4_3",
}
[docs]
def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""Forward pass through the VGGLoss module.
Args:
x (torch.Tensor): The input tensor.
target (torch.Tensor): The target tensor.
Returns:
torch.Tensor: The VGG perceptual loss between the input and the target.
"""
# Normalize to match VGG input requirements
mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(x.device)
std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(x.device)
x = (x - mean) / std
target = (target - mean) / std
loss = 0.0
x_features = {}
target_features = {}
for name, module in self.vgg._modules.items():
x = module(x)
target = module(target)
if name in self.layer_name_mapping:
layer_name = self.layer_name_mapping[name]
x_features[layer_name] = x
target_features[layer_name] = target
if layer_name in self.layer_weights:
loss += self.layer_weights[layer_name] * F.mse_loss(x, target)
return loss
[docs]
@LossRegistry.register_loss()
class TotalVariationLoss(BaseLoss):
"""Total Variation Loss Module.
This module calculates the total variation loss to encourage spatial smoothness.
Total variation regularization reduces noise while preserving edges in images
:cite:`rudin1992nonlinear` :cite:`mahendran2015understanding`.
"""
[docs]
def __init__(self):
"""Initialize the TotalVariationLoss module."""
super().__init__()
[docs]
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass through the TotalVariationLoss module.
Args:
x (torch.Tensor): The input tensor of shape (B, C, H, W).
Returns:
torch.Tensor: The total variation loss of the input.
"""
batch_size = x.size()[0]
h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :-1, :]), 2).sum()
w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :-1]), 2).sum()
return (h_tv + w_tv) / batch_size
[docs]
@LossRegistry.register_loss()
class GradientLoss(BaseLoss):
"""Gradient Loss Module.
This module calculates the gradient loss to preserve edge information.
Gradient loss explicitly penalizes differences in image gradients, helping to
preserve structural information and edges :cite:`mathieu2015deep`.
"""
[docs]
def __init__(self):
"""Initialize the GradientLoss module."""
super().__init__()
self.sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32).unsqueeze(0).unsqueeze(0)
self.sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32).unsqueeze(0).unsqueeze(0)
[docs]
def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""Forward pass through the GradientLoss module.
Args:
x (torch.Tensor): The input tensor.
target (torch.Tensor): The target tensor.
Returns:
torch.Tensor: The gradient loss between the input and the target.
"""
device = x.device
self.sobel_x = self.sobel_x.to(device)
self.sobel_y = self.sobel_y.to(device)
b, c, h, w = x.size()
loss = 0.0
for ch in range(c):
# Extract gradients using Sobel operators
x_grad_x = F.conv2d(x[:, ch : ch + 1, :, :], self.sobel_x, padding=1)
x_grad_y = F.conv2d(x[:, ch : ch + 1, :, :], self.sobel_y, padding=1)
target_grad_x = F.conv2d(target[:, ch : ch + 1, :, :], self.sobel_x, padding=1)
target_grad_y = F.conv2d(target[:, ch : ch + 1, :, :], self.sobel_y, padding=1)
# Calculate differences in gradients
loss += F.l1_loss(x_grad_x, target_grad_x) + F.l1_loss(x_grad_y, target_grad_y)
return loss / c
[docs]
@LossRegistry.register_loss()
class PSNRLoss(BaseLoss):
"""Peak Signal-to-Noise Ratio (PSNR) Loss Module.
This module calculates the negative PSNR (to be minimized) between the input and target.
PSNR is a standard metric for image quality assessment :cite:`hore2010image`,
though it doesn't always correlate well with human perception :cite:`huynh2008scope`.
"""
[docs]
def __init__(self, max_val=1.0):
"""Initialize the PSNRLoss module.
Args:
max_val (float): Maximum value of the input tensor. Default is 1.0.
"""
super().__init__()
self.max_val = max_val
[docs]
def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""Forward pass through the PSNRLoss module.
Args:
x (torch.Tensor): The input tensor.
target (torch.Tensor): The target tensor.
Returns:
torch.Tensor: The negative PSNR loss between the input and the target.
"""
mse = F.mse_loss(x, target)
psnr = 20 * torch.log10(self.max_val / torch.sqrt(mse))
# Return negative PSNR since we want to minimize the loss
return -psnr
[docs]
@LossRegistry.register_loss()
class StyleLoss(BaseLoss):
"""Style Loss Module based on Gram matrices.
This module calculates the style loss used in neural style transfer.
Style loss computes the difference between Gram matrices of feature maps,
capturing texture information independent of spatial arrangement :cite:`gatys2016image`.
"""
[docs]
def __init__(self, apply_gram=True, normalize=False, layer_weights=None):
"""Initialize the StyleLoss module.
Args:
apply_gram (bool): Whether to apply Gram matrix computation (True) or use
precomputed Gram matrices as input (False). Default is True.
normalize (bool): Whether to normalize the Gram matrices. Default is False.
layer_weights (dict, optional): Weights for different layers. Default is None
(equal weights for all layers).
"""
super().__init__()
# Initialize common parameters
self.apply_gram = apply_gram
self.normalize = normalize
# Try to initialize VGG-based feature extractor
try:
vgg = models.vgg16(weights=models.VGG16_Weights.DEFAULT).features.eval()
self.feature_extractor = nn.Sequential()
self.style_layers = [0, 5, 10, 17, 24]
self.layer_weights = layer_weights or {f"layer_{i}": 1.0 for i in range(len(self.style_layers))}
i = 0
for layer in vgg.children():
# Classify layer type and assign appropriate name
if isinstance(layer, nn.Conv2d):
i += 1
name = f"conv_{i}"
elif isinstance(layer, nn.ReLU):
name = f"relu_{i}"
layer = nn.ReLU(inplace=False)
elif isinstance(layer, nn.MaxPool2d):
name = f"pool_{i}"
elif isinstance(layer, nn.BatchNorm2d):
name = f"bn_{i}"
else:
# Generic name for unrecognized layers
name = f"unknown_{i}"
self.feature_extractor.add_module(name, layer)
# Freeze parameters
for param in self.feature_extractor.parameters():
param.requires_grad = False
except Exception:
# Fall back to minimal configuration for graceful degradation
self.feature_extractor = nn.Sequential()
self.style_layers = [0]
self.layer_weights = layer_weights or {"layer_0": 1.0}
[docs]
def gram_matrix(self, x):
"""Calculate Gram matrix from features.
Args:
x (torch.Tensor): Feature tensor.
Returns:
torch.Tensor: Gram matrix.
"""
batch_size, channels, height, width = x.size()
# Make tensor contiguous before reshaping
x_cont = x.contiguous()
features = x_cont.view(batch_size, channels, height * width)
features_t = features.transpose(1, 2)
gram = features.bmm(features_t)
# Normalize if requested
if self.normalize:
gram = gram / (channels * height * width)
return gram
[docs]
def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""Forward pass through the StyleLoss module.
Args:
x (torch.Tensor): The input tensor.
target (torch.Tensor): The target tensor.
Returns:
torch.Tensor: The style loss between the input and the target.
"""
# Handle precomputed Gram matrices case
if not self.apply_gram:
return F.mse_loss(x, target)
# Input shape validation for image inputs
if x.dim() != 4 or target.dim() != 4:
raise ValueError("Input tensors must be 4D (batch, channels, height, width)")
if x.size(1) != 3 or target.size(1) != 3:
raise ValueError("Input tensors must have 3 channels (RGB)")
# Normalize to match VGG input requirements
mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(x.device)
std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(x.device)
x = (x - mean) / std
target = (target - mean) / std
loss = 0.0
# Extract features and calculate gram matrices
for i, layer in enumerate(self.feature_extractor):
x = layer(x)
target = layer(target)
if i in self.style_layers:
layer_idx = self.style_layers.index(i)
layer_name = f"layer_{layer_idx}"
weight = self.layer_weights.get(layer_name, 1.0)
x_gram = self.gram_matrix(x)
target_gram = self.gram_matrix(target)
loss += weight * F.mse_loss(x_gram, target_gram)
return loss
[docs]
@LossRegistry.register_loss()
class FocalLoss(BaseLoss):
"""Focal Loss Module for dealing with class imbalance.
This implementation works for both binary and multi-class problems.
Focal loss addresses class imbalance by down-weighting well-classified examples,
focusing training on difficult examples :cite:`lin2017focal`.
"""
[docs]
def __init__(self, alpha=None, gamma=2.0, reduction="mean"):
"""Initialize the FocalLoss module.
Args:
alpha (float or tensor): Weighting factor for the rare class. Default is None.
gamma (float): Focusing parameter. Default is 2.0.
reduction (str): Specifies the reduction to apply to the output. Default is 'mean'.
"""
super().__init__()
self.alpha = alpha
self.gamma = gamma
self.reduction = reduction
[docs]
def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
"""Forward pass through the FocalLoss module.
Args:
inputs (torch.Tensor): The input logits tensor.
targets (torch.Tensor): The target tensor.
Returns:
torch.Tensor: The focal loss between the input and the target.
"""
# For binary classification
if inputs.shape[1] == 1 or len(inputs.shape) == 1:
inputs = inputs.view(-1)
targets = targets.view(-1)
bce_loss = F.binary_cross_entropy_with_logits(inputs, targets.float(), reduction="none")
pt = torch.exp(-bce_loss)
focal_loss = (1 - pt) ** self.gamma * bce_loss
if self.alpha is not None:
focal_loss = self.alpha * targets + (1 - self.alpha) * (1 - targets) * focal_loss
# For multi-class classification
else:
log_softmax = F.log_softmax(inputs, dim=1)
ce_loss = F.nll_loss(log_softmax, targets, reduction="none")
pt = torch.exp(-ce_loss)
focal_loss = (1 - pt) ** self.gamma * ce_loss
if self.alpha is not None:
alpha_tensor = self.alpha.to(inputs.device)
alpha_t = alpha_tensor.gather(0, targets)
focal_loss = alpha_t * focal_loss
if self.reduction == "mean":
return focal_loss.mean()
elif self.reduction == "sum":
return focal_loss.sum()
else: # 'none'
return focal_loss
@LossRegistry.register_loss()
class ElasticLoss(BaseLoss):
"""Elastic Loss combines L1 and L2 losses.
This loss function smoothly transitions between L1 and L2 behavior.
Elastic net regularization combines the benefits of both L1 and L2 penalties,
offering robustness to outliers while maintaining smoothness :cite:`zou2005regularization`.
"""
def __init__(self, beta=1.0, alpha=0.5, reduction="mean"):
"""Initialize the ElasticLoss module.
Args:
beta (float): Balance parameter between L1 and L2. Default is 1.0.
alpha (float): Weight parameter controlling L1 vs L2 contribution (0.5 means equal mix). Default is 0.5.
reduction (str): Reduction method ('mean', 'sum', 'none'). Default is 'mean'.
"""
super().__init__()
self.beta = max(beta, 1e-8) # Prevent division by zero
self.alpha = alpha
self.reduction = reduction
def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""Forward pass through the ElasticLoss module.
Args:
x (torch.Tensor): The input tensor.
target (torch.Tensor): The target tensor.
Returns:
torch.Tensor: The elastic loss between the input and the target.
"""
diff = x - target
abs_diff = torch.abs(diff)
squared_diff = diff**2
# Handle edge cases based on alpha and beta values
if self.alpha >= 0.99: # Close to 1.0, act like pure L1
point_losses = abs_diff
elif self.alpha <= 0.01: # Close to 0.0, act like pure L2
point_losses = squared_diff # Removed 0.5 factor to match standard MSE
else:
# Compute weighted combination of L1 and L2 loss
l1_component = abs_diff
l2_component = 0.5 * squared_diff / self.beta
# Apply smooth transition between L1 and L2 based on difference magnitude
point_losses = torch.where(abs_diff < self.beta, self.alpha * l2_component, (1.0 - self.alpha) * l1_component + self.alpha * self.beta / 2.0)
# Apply reduction
if self.reduction == "mean":
return point_losses.mean()
elif self.reduction == "sum":
return point_losses.sum()
else: # 'none'
return point_losses
# Public API of the module. ``ElasticLoss`` is defined and registered above, so
# it is exported together with the other losses.
__all__ = [
    "MSELoss",
    "CombinedLoss",
    "MSELPIPSLoss",
    "LPIPSLoss",
    "SSIMLoss",
    "MSSSIMLoss",
    "L1Loss",
    "VGGLoss",
    "TotalVariationLoss",
    "GradientLoss",
    "PSNRLoss",
    "StyleLoss",
    "FocalLoss",
    "ElasticLoss",
]