kaira.losses.adversarial.R1GradientPenalty

- class kaira.losses.adversarial.R1GradientPenalty(gamma=10.0)
Bases: BaseLoss
R1 Gradient Penalty Module for GANs.
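This module implements the R1 gradient penalty for GAN training, which penalizes the squared norm of the discriminator's gradient with respect to real input samples.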
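A minimal PyTorch sketch of the penalty, assuming the formulation of Mescheder et al. (2018): (gamma / 2) * E[||∇x D(x)||²] over real samples. The helper `r1_penalty` and the toy discriminator are hypothetical and not part of kaira. Whether this module applies the 1/2 factor or averages over the batch is an assumption, since this page only documents gamma as the penalty weight.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def r1_penalty(real_data: torch.Tensor, real_outputs: torch.Tensor,
               gamma: float = 10.0) -> torch.Tensor:
    """Hypothetical sketch: (gamma / 2) * batch mean of ||dD(x)/dx||^2."""
    # Differentiate the summed outputs w.r.t. the real inputs; this yields
    # per-sample gradients as long as samples do not interact inside D.
    (grad,) = torch.autograd.grad(
        outputs=real_outputs.sum(),
        inputs=real_data,
        create_graph=True,  # keep the graph so the penalty can train D
    )
    # Squared L2 norm per sample over all non-batch dimensions.
    grad_sq_norm = grad.pow(2).reshape(grad.size(0), -1).sum(dim=1)
    return 0.5 * gamma * grad_sq_norm.mean()


# Sanity check: for a linear D(x) = x @ w the input gradient is w for every
# sample, so the penalty should equal (gamma / 2) * ||w||^2.
torch.manual_seed(0)
w = torch.randn(4)
x = torch.randn(8, 4, requires_grad=True)
penalty = r1_penalty(x, x @ w, gamma=10.0)
expected = 0.5 * 10.0 * w.pow(2).sum()
print(torch.allclose(penalty, expected))  # True

# Usage in a discriminator step with a toy (hypothetical) discriminator.
disc = nn.Sequential(nn.Linear(4, 16), nn.LeakyReLU(0.2), nn.Linear(16, 1))
real = torch.randn(8, 4).requires_grad_(True)  # enable before calling D
real_logits = disc(real).squeeze(1)
# Non-saturating loss on real samples plus the R1 term.
d_loss = F.softplus(-real_logits).mean() + r1_penalty(real, real_logits)
d_loss.backward()
print(disc[0].weight.grad is not None)  # True
```

The real inputs need `requires_grad=True` before the discriminator is applied; otherwise `torch.autograd.grad` has no path back to them and raises an error. `create_graph=True` keeps the penalty itself differentiable, so its gradient reaches the discriminator's weights during `backward()`.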
Methods
- __init__([gamma]) – Initialize the R1GradientPenalty module.
- forward(real_data, real_outputs) – Forward pass through the R1GradientPenalty module.
- __init__(gamma=10.0)
Initialize the R1GradientPenalty module.
- Parameters:
gamma (float) – Weight for the gradient penalty. Default is 10.0.
- forward(real_data: Tensor, real_outputs: Tensor) → Tensor
Forward pass through the R1GradientPenalty module: compute the R1 gradient penalty for a batch of real data.
- Parameters:
real_data (torch.Tensor) – Real input data.
real_outputs (torch.Tensor) – Discriminator outputs for real data.
- Returns:
The R1 gradient penalty.
- Return type: