kaira.losses.image.CombinedLoss

Inheritance diagram for CombinedLoss

class kaira.losses.image.CombinedLoss(losses: Sequence[BaseLoss], weights: list[float])

Bases: BaseLoss

Combined Loss Module.

This module combines multiple loss functions into a single objective: each component loss is scaled by its corresponding weight, and the weighted losses are summed. Combining losses is a common way to improve image quality, because each component can target a different aspect of visual perception [Zhao et al., 2016].
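Assuming each weight scales its loss and the scaled losses are added, the combined objective for component losses \mathcal{L}_1 through \mathcal{L}_n with weights w_1 through w_n, input x and target y can be written as follows (the symbols are illustrative notation, not taken from the library):

```latex
\mathcal{L}_{\text{combined}}(x, y) = \sum_{i=1}^{n} w_i \, \mathcal{L}_i(x, y)
```

Under this reading, a weight of 0 disables a component, and the ratios between the weights set the trade-off between components.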

Methods

__init__ – Initialize the CombinedLoss module.

forward – Forward pass through the CombinedLoss module.

__init__(losses: Sequence[BaseLoss], weights: list[float])

Initialize the CombinedLoss module.

Parameters:
  • losses (Sequence[BaseLoss]) – The loss functions to combine.

  • weights (list[float]) – The weight for each loss function, one per entry in losses and in the same order.
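The sketch below illustrates how the two constructor arguments fit together. It is plain PyTorch, not kaira's implementation: the class name WeightedSumLoss is hypothetical, torch.nn.MSELoss and torch.nn.L1Loss stand in for kaira's BaseLoss subclasses, and the length check is an added assumption that the documented signature does not confirm.

```python
from typing import Sequence

import torch
import torch.nn as nn


class WeightedSumLoss(nn.Module):
    """Hypothetical stand-in for CombinedLoss: a weighted sum of component losses."""

    def __init__(self, losses: Sequence[nn.Module], weights: Sequence[float]):
        super().__init__()
        # Assumption: exactly one weight per loss. Whether CombinedLoss itself
        # validates this is not stated in its documentation.
        if len(losses) != len(weights):
            raise ValueError(
                f"expected one weight per loss, got {len(losses)} losses "
                f"and {len(weights)} weights"
            )
        # ModuleList registers the component losses as submodules, so calls
        # such as .to(device) also reach any parameters or buffers they hold.
        self.losses = nn.ModuleList(losses)
        self.weights = [float(w) for w in weights]

    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # Scale each component loss by its weight and add the results.
        total = torch.zeros((), dtype=x.dtype, device=x.device)
        for weight, loss_fn in zip(self.weights, self.losses):
            total = total + weight * loss_fn(x, target)
        return total


# Example: 80% mean squared error plus 20% mean absolute error.
criterion = WeightedSumLoss([nn.MSELoss(), nn.L1Loss()], weights=[0.8, 0.2])
```

In this sketch the weights are plain floats rather than learnable parameters, so they act as fixed hyperparameters and are never updated by an optimizer.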

forward(x: Tensor, target: Tensor) → Tensor

Forward pass through the CombinedLoss module.

Parameters:
  • x (Tensor) – The input tensor.

  • target (Tensor) – The target tensor.

Returns:

The combined loss between the input and the target.

Return type:

torch.Tensor
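To make the return value concrete, the worked example below computes by hand what a weighted-sum combination of an MSE term and an L1 term would return for a toy 2×2 input, assuming weights [0.8, 0.2]. It uses torch.nn.MSELoss and torch.nn.L1Loss as stand-ins for kaira losses; with their default mean reduction each term, and therefore the combination, is a 0-dimensional tensor.

```python
import torch
import torch.nn as nn

# Toy prediction and all-zero target, chosen so the arithmetic is easy to follow.
x = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
target = torch.zeros(2, 2)

mse = nn.MSELoss()(x, target)  # mean of squares: (0 + 1 + 4 + 9) / 4 = 3.5
l1 = nn.L1Loss()(x, target)    # mean of absolute values: (0 + 1 + 2 + 3) / 4 = 1.5

# Assumed combination rule: weighted sum with weights [0.8, 0.2].
combined = 0.8 * mse + 0.2 * l1  # 2.8 + 0.3 = 3.1

print(combined.dim(), round(combined.item(), 4))  # prints: 0 3.1
```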