kaira.losses.CompositeLoss

- class kaira.losses.CompositeLoss(losses: Dict[str, BaseLoss], weights: Dict[str, float] | None = None)
Bases: BaseLoss
A loss that combines multiple loss functions with optional weighting.
This class builds a custom loss by combining several individual losses, each scaled by its own weight. It is useful when training must optimize multiple objectives at once, such as a pixel-wise reconstruction loss combined with a perceptual or adversarial loss.
Weighting lets you trade off competing loss terms. For example, an L1 loss rewards pixel-level accuracy, while a perceptual loss rewards visual quality; combining them pushes the model toward outputs that satisfy both criteria.
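Written out, and assuming the normalization described for the weights parameter of __init__, the composite objective for component losses $\mathcal{L}_i$ with user weights $w_i$ (equal $w_i$ when weights is None) is the normalized weighted sum:

```latex
\tilde{w}_i = \frac{w_i}{\sum_{j} w_j},
\qquad
\mathcal{L}_{\text{composite}}(x, y) = \sum_{i} \tilde{w}_i \, \mathcal{L}_i(x, y),
\qquad
\sum_{i} \tilde{w}_i = 1
```

For the weights used in the example that follows (1.0, 0.5 and 0.1, summing to 1.6), the effective weights work out to 0.625, 0.3125 and 0.0625.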
Example
>>> from kaira.losses import L1Loss, SSIMLoss, PerceptualLoss
>>> from kaira.losses.composite import CompositeLoss
>>>
>>> # Create individual losses
>>> l1_loss = L1Loss()
>>> ssim_loss = SSIMLoss()
>>> perceptual_loss = PerceptualLoss()
>>>
>>> # Create a composite loss with custom weights
>>> losses = {"l1": l1_loss, "ssim": ssim_loss, "perceptual": perceptual_loss}
>>> weights = {"l1": 1.0, "ssim": 0.5, "perceptual": 0.1}
>>> composite_loss = CompositeLoss(losses=losses, weights=weights)
>>>
>>> # Train a model with the composite loss
>>> # (model, input_data, target and optimizer are defined elsewhere)
>>> optimizer.zero_grad()
>>> output = model(input_data)
>>> loss = composite_loss(output, target)
>>> loss.backward()
>>> optimizer.step()
Methods
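__init__(losses[, weights]) – Initialize composite loss with component losses and their weights.
add_loss(name, loss[, weight]) – Add a new loss to the composite loss.
compute_individual(x, target) – Compute all individual losses separately without combining them.
forward(x, target) – Compute the weighted combination of all component losses.
get_individual_losses(x, target) – Compute all individual losses separately without combining them.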
- __init__(losses: Dict[str, BaseLoss], weights: Dict[str, float] | None = None)
Initialize composite loss with component losses and their weights.
- Parameters:
losses (Dict[str, BaseLoss]) – Dictionary mapping loss names to loss objects. Each loss should be a subclass of BaseLoss.
weights (Optional[Dict[str, float]]) – Dictionary mapping loss names to their relative importance. If None, equal weights are assigned to all losses. Weights are automatically normalized to sum to 1.0.
- Raises:
ValueError – If the weights dictionary contains keys that are not present in the losses dictionary.
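The weight handling described above can be sketched in plain Python. normalize_weights is a hypothetical helper, not part of kaira; it follows the documented rules (equal weights when weights is None, ValueError for unknown keys, normalization to 1.0) and assumes that, when weights is given, it lists every loss. The zero-sum guard is an addition of this sketch, not documented behavior.

```python
from typing import Dict, Mapping, Optional


def normalize_weights(
    losses: Mapping[str, object],
    weights: Optional[Dict[str, float]] = None,
) -> Dict[str, float]:
    # Hypothetical helper (not kaira API) mirroring the documented __init__ rules.
    if weights is None:
        # No weights given: every loss gets the same raw weight.
        weights = {name: 1.0 for name in losses}
    # Weights may only refer to losses that were actually supplied.
    unknown = set(weights) - set(losses)
    if unknown:
        raise ValueError(f"weights refer to unknown losses: {sorted(unknown)}")
    total = sum(weights.values())
    if total <= 0:
        # Guard added for this sketch only; the docs do not cover a zero total.
        raise ValueError("weights must sum to a positive value")
    # Rescale so the weights sum to 1.0.
    return {name: w / total for name, w in weights.items()}


# Placeholder loss objects; only the keys matter for weight handling.
losses = {"l1": object(), "ssim": object(), "perceptual": object()}
custom = normalize_weights(losses, {"l1": 1.0, "ssim": 0.5, "perceptual": 0.1})
equal = normalize_weights(losses)
print(custom)
print(equal)
```

Normalizing up front means only the ratios between weights matter: {"l1": 2.0, "ssim": 1.0} and {"l1": 1.0, "ssim": 0.5} describe the same objective.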
- forward(x: Tensor, target: Tensor) → Tensor
Compute the weighted combination of all component losses.
Evaluates each loss on the input tensors and combines them according to the normalized weights specified during initialization.
- Parameters:
x (torch.Tensor) – First input tensor, typically the prediction or generated output.
target (torch.Tensor) – Second input tensor, typically the target or ground truth.
- Returns:
Weighted sum of all loss values as a single scalar tensor.
- Return type:
torch.Tensor
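A framework-free sketch of the combination step: each component is evaluated on the same (x, target) pair and scaled by its already-normalized weight. The stand-in losses l1 and mse and the helper composite_forward are hypothetical and operate on Python lists rather than torch tensors; with real BaseLoss modules the same weighted sum would yield a scalar tensor.

```python
from typing import Callable, Dict, List

# Hypothetical stand-in type: a loss maps (prediction, target) to a float.
# In kaira the components are BaseLoss modules returning torch tensors.
LossFn = Callable[[List[float], List[float]], float]


def l1(x: List[float], target: List[float]) -> float:
    # Mean absolute error over paired elements.
    return sum(abs(a - b) for a, b in zip(x, target)) / len(x)


def mse(x: List[float], target: List[float]) -> float:
    # Mean squared error over paired elements.
    return sum((a - b) ** 2 for a, b in zip(x, target)) / len(x)


def composite_forward(
    losses: Dict[str, LossFn],
    normalized_weights: Dict[str, float],
    x: List[float],
    target: List[float],
) -> float:
    # Evaluate every component on the same inputs and accumulate the weighted sum.
    return sum(normalized_weights[name] * fn(x, target) for name, fn in losses.items())


value = composite_forward(
    {"l1": l1, "mse": mse},
    {"l1": 0.75, "mse": 0.25},  # weights already normalized to sum to 1.0
    [1.0, 2.0, 3.0],
    [1.5, 2.0, 2.0],
)
print(f"composite loss: {value:.4f}")
```

Because the result is a plain weighted sum, its gradient is the same weighted sum of the component gradients, so a single backward pass trains against every term.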
- get_individual_losses(x: Tensor, target: Tensor) → Dict[str, Tensor]
Compute all individual losses separately without combining them.
This method is an alias for compute_individual, kept for backward compatibility.
- Parameters:
x (torch.Tensor) – First input tensor, typically the prediction or generated output.
target (torch.Tensor) – Second input tensor, typically the target or ground truth.
- Returns:
Dictionary mapping loss names to their computed values.
- Return type:
Dict[str, torch.Tensor]
- compute_individual(x: Tensor, target: Tensor) → Dict[str, Tensor]
Compute all individual losses separately without combining them.
This method is useful for debugging and monitoring individual loss components during training.
- Parameters:
x (torch.Tensor) – First input tensor, typically the prediction or generated output.
target (torch.Tensor) – Second input tensor, typically the target or ground truth.
- Returns:
Dictionary mapping loss names to their computed values.
- Return type:
Dict[str, torch.Tensor]
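Since the composite value is documented as the normalized weighted sum of its components, the per-component values can be logged and recombined as a sanity check. This plain-Python sketch uses hypothetical stand-ins (individual_losses is a free function, not the kaira method; mean_abs_error and max_abs_error are toy losses on lists).

```python
from typing import Callable, Dict, List

LossFn = Callable[[List[float], List[float]], float]


def mean_abs_error(x: List[float], target: List[float]) -> float:
    # Toy stand-in for an L1 loss.
    return sum(abs(a - b) for a, b in zip(x, target)) / len(x)


def max_abs_error(x: List[float], target: List[float]) -> float:
    # Toy stand-in for a second, differently behaved loss term.
    return max(abs(a - b) for a, b in zip(x, target))


def individual_losses(losses: Dict[str, LossFn], x, target) -> Dict[str, float]:
    # One unweighted value per named component, e.g. for logging.
    return {name: fn(x, target) for name, fn in losses.items()}


losses = {"l1": mean_abs_error, "max": max_abs_error}
weights = {"l1": 0.5, "max": 0.5}  # already normalized
parts = individual_losses(losses, [1.0, 2.0, 3.0], [1.5, 2.0, 2.0])
for name, value in parts.items():
    print(f"{name}: {value:.4f}")

# Recombining the parts with the normalized weights gives the composite value.
total = sum(weights[name] * value for name, value in parts.items())
print(f"total: {total:.4f}")
```

With real tensors, converting each value with .item() before logging stores plain Python numbers instead of tensors attached to the autograd graph.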
- add_loss(name: str, loss, weight: float = 1.0)
Add a new loss to the composite loss.
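- Parameters:
name (str) – Name under which the new loss is registered.
loss – Loss object to add, typically a subclass of BaseLoss.
weight (float) – Weight assigned to the new loss. Defaults to 1.0.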
- Returns:
Nothing; the loss and weight dictionaries are updated in place.
- Return type:
None
- Raises:
ValueError – If a loss with the given name already exists.
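The documented contract of add_loss (duplicate names raise ValueError, the dictionaries change in place, nothing is returned) can be sketched with a hypothetical TinyComposite class. It stores the raw weight; whether CompositeLoss re-normalizes the full weight set after an addition is not stated on this page, so that step is deliberately left out rather than guessed.

```python
from typing import Callable, Dict


class TinyComposite:
    """Hypothetical stand-in (not kaira API) illustrating the add_loss contract."""

    def __init__(self) -> None:
        self.losses: Dict[str, Callable] = {}
        self.weights: Dict[str, float] = {}

    def add_loss(self, name: str, loss: Callable, weight: float = 1.0) -> None:
        # Duplicate names are rejected, as documented.
        if name in self.losses:
            raise ValueError(f"Loss '{name}' already exists")
        # Both dictionaries are updated in place; nothing is returned.
        # Re-normalization is intentionally not modeled here (undocumented).
        self.losses[name] = loss
        self.weights[name] = weight


comp = TinyComposite()
comp.add_loss("l1", lambda x, target: 0.0)
comp.add_loss("ssim", lambda x, target: 0.0, weight=0.5)
try:
    comp.add_loss("l1", lambda x, target: 0.0)
except ValueError as err:
    print(err)
print(comp.weights)
```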