# Source code for kaira.losses.composite

"""Composite loss module for combining multiple loss functions.

This module provides functionality to create composite losses that combine
multiple individual loss functions with customizable weights. This is particularly
useful for cases where training requires optimizing multiple objectives
simultaneously, such as reconstruction loss combined with adversarial loss
or perceptual loss.

The composite approach addresses several common challenges in training:
- Different losses capture different aspects of the desired output
- Some applications require balancing multiple objectives
- Custom training schemes may need to emphasize certain properties over others
"""

from typing import Dict, Optional

import torch
from torch import nn

from .base import BaseLoss


class CompositeLoss(BaseLoss):
    """A loss that combines multiple loss functions with optional weighting.

    This class allows for the creation of custom loss functions by combining
    multiple individual losses with specified weights. It's useful when
    training requires optimizing multiple objectives simultaneously, such as
    combining pixel-wise reconstruction loss with perceptual or adversarial
    losses.

    The composite approach can balance the trade-offs between different loss
    terms. For example, L1 loss promotes pixel accuracy, while perceptual loss
    promotes visual quality. By combining them, you can achieve outputs that
    satisfy multiple criteria.

    Attributes:
        losses (nn.ModuleDict): The component losses, keyed by name.
        weights (Dict[str, float]): Weight per loss name. After construction
            the weights are normalized so that they sum to 1.0.

    Example:
        >>> from kaira.losses import L1Loss, SSIMLoss, PerceptualLoss
        >>> from kaira.losses.composite import CompositeLoss
        >>>
        >>> # Create individual losses
        >>> l1_loss = L1Loss()
        >>> ssim_loss = SSIMLoss()
        >>> perceptual_loss = PerceptualLoss()
        >>>
        >>> # Create a composite loss with custom weights
        >>> losses = {"l1": l1_loss, "ssim": ssim_loss, "perceptual": perceptual_loss}
        >>> weights = {"l1": 1.0, "ssim": 0.5, "perceptual": 0.1}
        >>> composite_loss = CompositeLoss(losses=losses, weights=weights)
        >>>
        >>> # Train a model with the composite loss
        >>> output = model(input_data)
        >>> loss = composite_loss(output, target)
        >>> loss.backward()
        >>> optimizer.step()
    """

    def __init__(self, losses: Dict[str, BaseLoss], weights: Optional[Dict[str, float]] = None):
        """Initialize composite loss with component losses and their weights.

        Args:
            losses (Dict[str, BaseLoss]): Dictionary mapping loss names to loss
                objects. Each loss should be a subclass of BaseLoss.
            weights (Optional[Dict[str, float]]): Dictionary mapping loss names
                to their relative importance. If None (or empty), equal weights
                are assigned to all losses. Weights are automatically
                normalized to sum to 1.0. The caller's dictionary is not
                modified.

        Raises:
            ValueError: If the weights dictionary contains keys not present in
                the losses dictionary, or if the weights sum to zero and can
                therefore not be normalized.
        """
        super().__init__()
        self.losses = nn.ModuleDict(losses)

        # Every weight must refer to a known loss.
        if weights is not None:
            for name in weights:
                if name not in losses:
                    raise ValueError(f"Weight key '{name}' not found in losses dictionary")

        # `or` deliberately treats an empty dict like None: fall back to equal weights.
        raw_weights = weights or {name: 1.0 for name in losses}

        # Normalize into a fresh dict so the sum is 1.0. A zero sum would
        # otherwise surface as an opaque ZeroDivisionError.
        total = sum(raw_weights.values())
        if raw_weights and total == 0:
            raise ValueError("Loss weights must not sum to zero")
        self.weights = {name: value / total for name, value in raw_weights.items()}

    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the weighted combination of all component losses.

        Evaluates each weighted loss on the input tensors and combines the
        results according to the weights stored in ``self.weights``. Losses
        without an entry in ``self.weights`` are skipped and not evaluated.

        Args:
            x (torch.Tensor): First input tensor, typically the prediction or
                generated output.
            target (torch.Tensor): Second input tensor, typically the target or
                ground truth.

        Returns:
            torch.Tensor: Weighted sum of all loss values. If no loss
            contributes, this is a zero tensor on ``x``'s device.
        """
        # Start the accumulator on the same device as the input.
        result = torch.tensor(0.0, device=x.device)
        for name, loss in self.losses.items():
            if name not in self.weights:
                continue
            loss_value = loss(x, target)
            if isinstance(loss_value, tuple):
                # Only the first element of a tuple result enters the sum.
                loss_value = loss_value[0]
            result = result + self.weights[name] * loss_value
        return result

    def get_individual_losses(self, x: torch.Tensor, target: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Compute all individual losses separately without combining them.

        This method is an alias for :meth:`compute_individual`, kept for
        backward compatibility.

        Args:
            x (torch.Tensor): First input tensor, typically the prediction or
                generated output.
            target (torch.Tensor): Second input tensor, typically the target or
                ground truth.

        Returns:
            Dict[str, torch.Tensor]: Dictionary mapping loss names to their
            computed values.
        """
        return self.compute_individual(x, target)

    def compute_individual(self, x: torch.Tensor, target: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Compute all individual losses separately without combining them.

        This method is useful for debugging and monitoring individual loss
        components during training. Every registered loss is evaluated,
        whether or not it has a weight, and no weighting is applied.

        Args:
            x (torch.Tensor): First input tensor, typically the prediction or
                generated output.
            target (torch.Tensor): Second input tensor, typically the target or
                ground truth.

        Returns:
            Dict[str, torch.Tensor]: Dictionary mapping loss names to their
            computed values.
        """
        # NOTE(review): unlike forward(), a tuple returned by a component loss
        # is passed through unchanged here rather than reduced to its first
        # element -- confirm whether callers rely on that.
        return {name: loss(x, target) for name, loss in self.losses.items()}

    def add_loss(self, name: str, loss: BaseLoss, weight: float = 1.0) -> None:
        """Add a new loss to the composite loss.

        The new loss keeps exactly the given weight. The existing weights are
        rescaled proportionally so that they sum to ``1.0 - weight``. With the
        default ``weight=1.0`` this rescales every existing weight to zero.

        Args:
            name (str): Name for the loss.
            loss (BaseLoss): Loss module to add.
            weight (float): Weight for the new loss (will be preserved exactly
                as provided).

        Returns:
            None: Updates the loss and weight dictionaries in-place.

        Raises:
            ValueError: If a loss with the given name already exists.
        """
        if name in self.losses:
            raise ValueError(f"Loss '{name}' already exists in the composite loss")

        self.losses[name] = loss

        # Share of the total weight left over for the previously existing losses.
        remaining_weight = 1.0 - weight
        current_total = sum(self.weights.values())

        # Rescale only when there is something to rescale; this also avoids
        # dividing by zero when there are no existing weights or they sum to 0.
        if current_total > 0:
            for key in self.weights:
                self.weights[key] = self.weights[key] * remaining_weight / current_total

        self.weights[name] = weight

    def __str__(self) -> str:
        """Get a string representation of the composite loss with weights.

        Returns:
            str: String representation of the composite loss including
            components and weights. Losses without a weight are shown with
            weight 0.000.
        """
        # NOTE(review): the whitespace inside these literals is reproduced as
        # seen; confirm the intended indentation of the rendered lines.
        parts = [f"{self.__class__.__name__}(\n", " (losses): ModuleDict(\n"]
        for name, loss in self.losses.items():
            weight = self.weights.get(name, 0.0)
            parts.append(f" ({name}): {loss.__class__.__name__} # weight={weight:.3f}\n")
        parts.append(" )\n)")
        return "".join(parts)