kaira.losses.CompositeLoss

Inheritance diagram of CompositeLoss


class kaira.losses.CompositeLoss(losses: Dict[str, BaseLoss], weights: Dict[str, float] | None = None)

Bases: BaseLoss

A loss that combines multiple loss functions with optional weighting.

This class builds a custom loss function by combining several individual losses with specified weights. It is useful when training must optimize several objectives at once, such as combining a pixel-wise reconstruction loss with perceptual or adversarial losses.

The composite approach can balance the trade-offs between different loss terms. For example, L1 loss promotes pixel accuracy, while perceptual loss promotes visual quality. By combining them, you can achieve outputs that satisfy multiple criteria.

Example

>>> from kaira.losses import L1Loss, SSIMLoss, PerceptualLoss
>>> from kaira.losses.composite import CompositeLoss
>>>
>>> # Create individual losses
>>> l1_loss = L1Loss()
>>> ssim_loss = SSIMLoss()
>>> perceptual_loss = PerceptualLoss()
>>>
>>> # Create a composite loss with custom weights
>>> losses = {"l1": l1_loss, "ssim": ssim_loss, "perceptual": perceptual_loss}
>>> weights = {"l1": 1.0, "ssim": 0.5, "perceptual": 0.1}
>>> composite_loss = CompositeLoss(losses=losses, weights=weights)
>>>
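>>> # Train a model with the composite loss
>>> # (model, input_data, target, and optimizer are assumed to be defined elsewhere)
>>> optimizer.zero_grad()
>>> output = model(input_data)
>>> loss = composite_loss(output, target)
>>> loss.backward()
>>> optimizer.step()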

Methods

__init__

Initialize composite loss with component losses and their weights.

add_loss

Add a new loss to the composite loss.

compute_individual

Compute all individual losses separately without combining them.

forward

Compute the weighted combination of all component losses.

get_individual_losses

Alias of compute_individual, kept for backward compatibility.

__init__(losses: Dict[str, BaseLoss], weights: Dict[str, float] | None = None)

Initialize composite loss with component losses and their weights.

Parameters:
  • losses (Dict[str, BaseLoss]) – Dictionary mapping loss names to loss objects. Each value should be an instance of a BaseLoss subclass.

  • weights (Optional[Dict[str, float]]) – Dictionary mapping loss names to their relative importance. If None, equal weights are assigned to all losses. Weights are automatically normalized to sum to 1.0.

Raises:

ValueError – If weights dictionary contains keys not present in losses dictionary.
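
The weight handling described above can be sketched in plain Python. This is a minimal illustration rather than kaira's actual implementation: normalize_weights is a hypothetical helper, the zero-sum check is an added assumption, and the sketch simply omits losses that have no entry in weights because the documentation does not say how that case is treated.

```python
from typing import Dict, Iterable, Optional


def normalize_weights(
    loss_names: Iterable[str],
    weights: Optional[Dict[str, float]] = None,
) -> Dict[str, float]:
    """Hypothetical helper (not part of kaira): validate and normalize weights."""
    names = list(loss_names)
    if weights is None:
        # No weights given: every loss gets the same share.
        weights = {name: 1.0 for name in names}
    unknown = set(weights) - set(names)
    if unknown:
        # Mirrors the documented ValueError for weight keys without a matching loss.
        raise ValueError(f"Weights given for unknown losses: {sorted(unknown)}")
    total = sum(weights.values())
    if total == 0:
        # Assumption of this sketch; the documentation does not cover this case.
        raise ValueError("Weights must not sum to zero")
    # Scale so that the returned weights sum to 1.0.
    return {name: value / total for name, value in weights.items()}


# The weights from the class example: 1.0, 0.5, 0.1 out of a total of 1.6.
normalized = normalize_weights(
    ["l1", "ssim", "perceptual"],
    {"l1": 1.0, "ssim": 0.5, "perceptual": 0.1},
)

# With weights omitted, each of the two losses receives an equal share.
equal = normalize_weights(["l1", "ssim"])
```

Under these assumptions, the relative importance of the terms is unchanged by normalization; only the overall scale of the combined loss differs.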

forward(x: Tensor, target: Tensor) → Tensor

Compute the weighted combination of all component losses.

Evaluates each loss on the input tensors and combines them according to the normalized weights specified during initialization.

Parameters:
  • x (torch.Tensor) – First input tensor, typically the prediction or generated output

  • target (torch.Tensor) – Second input tensor, typically the target or ground truth

Returns:

Weighted sum of all loss values as a single scalar tensor.

Return type:

torch.Tensor
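
As a worked illustration of the weighted combination, the sketch below uses plain floats and lambda functions in place of tensors and BaseLoss modules. The helper weighted_total and the two toy losses are hypothetical and exist only for this example; the real method operates on torch.Tensor inputs.

```python
from typing import Callable, Dict

# Plain float functions stand in for BaseLoss modules (illustration only).
LossFn = Callable[[float, float], float]


def weighted_total(
    losses: Dict[str, LossFn],
    weights: Dict[str, float],
    x: float,
    target: float,
) -> float:
    """Sum each loss value scaled by its weight's share of the total weight."""
    total_weight = sum(weights.values())
    return sum(
        (weights[name] / total_weight) * loss_fn(x, target)
        for name, loss_fn in losses.items()
    )


losses = {
    "abs": lambda x, t: abs(x - t),   # L1-style term
    "sq": lambda x, t: (x - t) ** 2,  # squared-error term
}
weights = {"abs": 3.0, "sq": 1.0}     # normalizes to 0.75 and 0.25

total = weighted_total(losses, weights, x=3.0, target=1.0)
# abs term = 2.0, sq term = 4.0, so total = 0.75 * 2.0 + 0.25 * 4.0 = 2.5
print(total)
```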

get_individual_losses(x: Tensor, target: Tensor) → Dict[str, Tensor]

Compute all individual losses separately without combining them.

This method is an alias for compute_individual, kept for backward compatibility.

Parameters:
  • x (torch.Tensor) – First input tensor, typically the prediction or generated output

  • target (torch.Tensor) – Second input tensor, typically the target or ground truth

Returns:

Dictionary mapping loss names to their computed values.

Return type:

Dict[str, torch.Tensor]

compute_individual(x: Tensor, target: Tensor) → Dict[str, Tensor]

Compute all individual losses separately without combining them.

This method is useful for debugging and monitoring individual loss components during training.

Parameters:
  • x (torch.Tensor) – First input tensor, typically the prediction or generated output

  • target (torch.Tensor) – Second input tensor, typically the target or ground truth

Returns:

Dictionary mapping loss names to their computed values.

Return type:

Dict[str, torch.Tensor]
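
One way to use the per-component values for monitoring is to keep a history for each named loss. The sketch below is a stand-in written with plain floats: the free function compute_individual here only imitates the documented return shape (a name-to-value dictionary), and the toy losses and the three "predictions" are invented for illustration.

```python
from typing import Callable, Dict, List

LossFn = Callable[[float, float], float]


def compute_individual(
    losses: Dict[str, LossFn], x: float, target: float
) -> Dict[str, float]:
    """Stand-in for the documented method: one unweighted value per named loss."""
    return {name: loss_fn(x, target) for name, loss_fn in losses.items()}


losses = {
    "abs": lambda x, t: abs(x - t),
    "sq": lambda x, t: (x - t) ** 2,
}

# Keep a per-component history so a stalled or dominating term is easy to spot.
history: Dict[str, List[float]] = {name: [] for name in losses}
for prediction in (3.0, 2.0, 1.5):  # pretend outputs from three training steps
    step_values = compute_individual(losses, prediction, target=1.0)
    for name, value in step_values.items():
        history[name].append(value)

for name, values in history.items():
    print(name, values)
```

With real tensors, the logged values would typically be converted to Python numbers (for example with Tensor.item()) before being stored.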

add_loss(name: str, loss, weight: float = 1.0)

Add a new loss to the composite loss.

Parameters:
  • name (str) – Name for the loss

  • loss (BaseLoss) – Loss module to add

  • weight (float) – Weight for the new loss (will be preserved exactly as provided)

Returns:

None. The loss and weight dictionaries are updated in place.

Return type:

None

Raises:

ValueError – If a loss with the given name already exists
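
The add_loss contract (in-place update, weight stored as given, duplicate names rejected) can be sketched as follows. LossRegistry is a hypothetical stand-in, not the kaira class, and it uses plain callables instead of BaseLoss modules. How a weight added this way interacts with the normalization applied in __init__ is not stated on this page, so the sketch does not model it.

```python
from typing import Callable, Dict

LossFn = Callable[[float, float], float]


class LossRegistry:
    """Hypothetical stand-in that mirrors only the documented add_loss contract."""

    def __init__(self) -> None:
        self.losses: Dict[str, LossFn] = {}
        self.weights: Dict[str, float] = {}

    def add_loss(self, name: str, loss: LossFn, weight: float = 1.0) -> None:
        if name in self.losses:
            # Documented behavior: a duplicate name is an error, not an overwrite.
            raise ValueError(f"Loss '{name}' already exists")
        self.losses[name] = loss
        self.weights[name] = weight  # stored as given, no renormalization here


registry = LossRegistry()
registry.add_loss("abs", lambda x, t: abs(x - t))                 # default weight 1.0
registry.add_loss("sq", lambda x, t: (x - t) ** 2, weight=0.25)

# Adding a second loss under an existing name should be rejected.
try:
    registry.add_loss("abs", lambda x, t: 0.0)
    duplicate_rejected = False
except ValueError:
    duplicate_rejected = True
```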