kaira.losses.adversarial.R1GradientPenalty

class kaira.losses.adversarial.R1GradientPenalty(gamma=10.0)

Bases: BaseLoss

R1 Gradient Penalty Module for GANs.

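This module implements the R1 gradient penalty for GAN training: a regularizer on the discriminator that penalizes the squared L2 norm of its gradient with respect to real input samples, weighted by gamma. Keeping these gradients small discourages the discriminator from becoming overly sharp around the real data and is commonly used to stabilize GAN training.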
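The penalty can be sketched in PyTorch as follows. This is not kaira's implementation: the function name r1_penalty is hypothetical, and the 0.5 * gamma scaling is an assumption that follows the common convention from the R1 literature (kaira may scale by gamma alone). The example uses a linear discriminator D(x) = w . x + b, whose gradient with respect to every input is exactly w, so the penalty should equal 0.5 * gamma * ||w||^2.

```python
import torch
from torch import Tensor


def r1_penalty(real_data: Tensor, real_outputs: Tensor, gamma: float = 10.0) -> Tensor:
    """Hypothetical stand-in for the R1 penalty (0.5 * gamma scaling assumed)."""
    # Gradient of the summed discriminator scores w.r.t. the real inputs.
    # Summing is valid as long as each score depends only on its own sample
    # (no batch statistics inside the discriminator).
    (grad,) = torch.autograd.grad(
        outputs=real_outputs.sum(),
        inputs=real_data,
        create_graph=True,  # keep the graph so the penalty can be backpropagated
    )
    # Squared L2 norm per sample, flattening all non-batch dimensions.
    grad_sq_norm = grad.pow(2).reshape(grad.size(0), -1).sum(dim=1)
    return 0.5 * gamma * grad_sq_norm.mean()


# Linear discriminator: its input gradient is the weight row w for every sample.
torch.manual_seed(0)
disc = torch.nn.Linear(4, 1)
x = torch.randn(8, 4, requires_grad=True)
penalty = r1_penalty(x, disc(x), gamma=10.0)
expected = 0.5 * 10.0 * disc.weight.pow(2).sum()
print(torch.allclose(penalty, expected))
```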

Methods

__init__ – Initialize the R1GradientPenalty module.

forward – Compute the R1 gradient penalty.

__init__(gamma=10.0)

Initialize the R1GradientPenalty module.

Parameters:

gamma (float) – Weight for the gradient penalty. Default is 10.0.

forward(real_data: Tensor, real_outputs: Tensor) → Tensor

Compute the R1 gradient penalty from a batch of real samples and the discriminator's outputs on that batch.

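Parameters:

real_data (torch.Tensor) – Batch of real samples fed to the discriminator. The gradient is taken with respect to this tensor, so it must have requires_grad=True before the discriminator is applied to it.

real_outputs (torch.Tensor) – Discriminator outputs computed from real_data.

Returns: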

The R1 gradient penalty.

Return type:

torch.Tensor
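A minimal sketch of one discriminator step that prepares the two forward arguments and adds the penalty to the loss. To keep the snippet runnable without kaira, a hypothetical stand-in r1_penalty (0.5 * gamma scaling assumed) replaces the call to R1GradientPenalty; with the library you would instead pass the same two tensors to an R1GradientPenalty instance's forward. The fake-sample term of the discriminator loss is omitted for brevity.

```python
import torch
from torch import Tensor, nn


def r1_penalty(real_data: Tensor, real_outputs: Tensor, gamma: float = 10.0) -> Tensor:
    # Hypothetical stand-in for R1GradientPenalty.forward (0.5 * gamma scaling assumed).
    (grad,) = torch.autograd.grad(real_outputs.sum(), real_data, create_graph=True)
    return 0.5 * gamma * grad.pow(2).reshape(grad.size(0), -1).sum(dim=1).mean()


torch.manual_seed(0)
disc = nn.Linear(4, 1)
opt = torch.optim.SGD(disc.parameters(), lr=0.1)
real = torch.randn(8, 4)

# 1. Mark the real batch as requiring grad before the discriminator sees it,
#    so real_outputs is connected to real_data in the autograd graph.
real_data = real.detach().requires_grad_(True)
real_outputs = disc(real_data)

# 2. Non-saturating discriminator loss on real samples: -log(sigmoid(D(x))).
adv_loss = nn.functional.softplus(-real_outputs).mean()

# 3. Add the penalty and backpropagate through both terms.
opt.zero_grad()
loss = adv_loss + r1_penalty(real_data, real_outputs, gamma=10.0)
loss.backward()
opt.step()
```

The create_graph=True flag is what lets loss.backward() reach the discriminator weights through the gradient term; without it the penalty would be a constant and contribute nothing to the update.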