kaira.losses.image.TotalVariationLoss

Inheritance diagram for TotalVariationLoss

class kaira.losses.image.TotalVariationLoss

Bases: BaseLoss

Total Variation Loss Module.

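This module computes the total variation (TV) of a batch of images, a penalty on the differences between neighbouring pixels. Used as a loss term, it encourages spatially smooth outputs. Total variation regularization reduces noise while preserving edges in images [Rudin et al., 1992] [Mahendran and Vedaldi, 2015].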
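The exact formula and normalization are not documented here. A common choice, assumed in the sketch below, is the anisotropic (L1) total variation: the sum of absolute differences between vertically and horizontally adjacent pixels, divided here by the number of elements. `total_variation` is a hypothetical helper, not part of kaira, and the library's normalization may differ.

```python
import torch


def total_variation(x: torch.Tensor) -> torch.Tensor:
    """Mean anisotropic (L1) total variation of a (B, C, H, W) batch.

    Hypothetical helper: the normalization (divide by x.numel()) is an
    assumption and may not match kaira's TotalVariationLoss exactly.
    """
    if x.dim() != 4:
        raise ValueError(f"expected shape (B, C, H, W), got {tuple(x.shape)}")
    # Absolute differences between vertically adjacent pixels: (B, C, H-1, W)
    dh = torch.abs(x[:, :, 1:, :] - x[:, :, :-1, :])
    # Absolute differences between horizontally adjacent pixels: (B, C, H, W-1)
    dw = torch.abs(x[:, :, :, 1:] - x[:, :, :, :-1])
    return (dh.sum() + dw.sum()) / x.numel()


torch.manual_seed(0)
flat = torch.full((1, 1, 8, 8), 0.5)
noisy = flat + 0.1 * torch.randn_like(flat)

# Sharp step edge vs. gradual ramp, both rising from 0 to 1 across the width
step = torch.zeros(1, 1, 8, 8)
step[..., 4:] = 1.0
ramp = torch.linspace(0.0, 1.0, 8).expand(1, 1, 8, 8)

print(total_variation(flat).item())   # 0.0: a constant image has no variation
print(total_variation(noisy).item())  # positive: noise adds variation
print(total_variation(step).item(), total_variation(ramp).item())  # both ≈ 0.125
```

Under this L1 form the step and the ramp cost the same, so the penalty does not favour blurring an edge into a gradient. A squared-difference (L2) penalty would instead charge the step seven times more than the ramp in this example, which is why L1 total variation is associated with edge preservation.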

Methods

__init__

Initialize the TotalVariationLoss module.

forward

Forward pass through the TotalVariationLoss module.

__init__()

Initialize the TotalVariationLoss module.

forward(x: Tensor) → Tensor

Forward pass through the TotalVariationLoss module.

Parameters:

x (torch.Tensor) – The input tensor of shape (B, C, H, W).

Returns:

The total variation loss of the input.

Return type:

torch.Tensor
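Because forward takes a (B, C, H, W) tensor and returns a tensor, the loss can be combined with a data-fidelity term and minimised with autograd, for example to denoise an image. The sketch below uses a hypothetical stand-in, `tv_loss` (same assumed formula: mean anisotropic L1 TV), so that it runs without kaira. With kaira installed, an instance of `TotalVariationLoss()` would take its place, assuming its output is a differentiable scalar. `tv_denoise` and its parameters are likewise illustrative, not part of the library.

```python
import torch
import torch.nn.functional as F


def tv_loss(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for TotalVariationLoss()(x): mean anisotropic L1 TV
    # (normalization by x.numel() is an assumption).
    dh = (x[:, :, 1:, :] - x[:, :, :-1, :]).abs().sum()
    dw = (x[:, :, :, 1:] - x[:, :, :, :-1]).abs().sum()
    return (dh + dw) / x.numel()


def tv_denoise(noisy: torch.Tensor, weight: float = 0.2,
               steps: int = 300, lr: float = 0.05) -> torch.Tensor:
    """Minimise MSE(u, noisy) + weight * TV(u) over the image u with Adam."""
    u = noisy.clone().requires_grad_(True)
    opt = torch.optim.Adam([u], lr=lr)
    for _ in range(steps):
        opt.zero_grad()
        # Data-fidelity term keeps u close to the observation;
        # the TV term penalises differences between neighbouring pixels.
        loss = F.mse_loss(u, noisy) + weight * tv_loss(u)
        loss.backward()
        opt.step()
    return u.detach()


torch.manual_seed(0)
clean = torch.zeros(1, 1, 16, 16)
clean[..., 8:] = 1.0  # piecewise-constant image with one vertical edge
noisy = clean + 0.2 * torch.randn_like(clean)
denoised = tv_denoise(noisy)

print(f"TV:  noisy {tv_loss(noisy).item():.3f}, denoised {tv_loss(denoised).item():.3f}")
print(f"MSE: noisy {F.mse_loss(noisy, clean).item():.4f}, "
      f"denoised {F.mse_loss(denoised, clean).item():.4f}")
```

The weight trades fidelity to the observation against smoothness; the values of `weight`, `steps` and `lr` here are illustrative choices, not tuned defaults.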