kaira.losses.adversarial.WassersteinGANLoss

- class kaira.losses.adversarial.WassersteinGANLoss
Bases: BaseLoss
Wasserstein GAN Loss Module.
This module implements the Wasserstein GAN (WGAN) loss from Arjovsky et al. (2017).
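For reference, the standard WGAN formulation from Arjovsky et al. (2017) gives the two losses below. The notation (critic $D$, generator $G$, real data distribution $P_r$, latent prior $p_z$) is assumed here and is not taken from this page. The discriminator and generator methods presumably compute batch-mean estimates of these expectations.

```latex
\mathcal{L}_D = \mathbb{E}_{z \sim p_z}\left[D(G(z))\right] - \mathbb{E}_{x \sim P_r}\left[D(x)\right],
\qquad
\mathcal{L}_G = -\,\mathbb{E}_{z \sim p_z}\left[D(G(z))\right]
```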
Methods
- __init__ – Initialize the WassersteinGANLoss module.
- forward – Forward pass through the WassersteinGANLoss module.
- forward_discriminator – Forward pass for the discriminator.
- forward_generator – Forward pass for the generator.
- forward_discriminator(real_pred: Tensor, fake_pred: Tensor) → Tensor
Forward pass for the discriminator (critic): computes its loss from the outputs on real and fake data.
- Parameters:
real_pred (torch.Tensor) – Discriminator outputs for real data.
fake_pred (torch.Tensor) – Discriminator outputs for fake data.
- Returns:
Discriminator loss.
- Return type:
torch.Tensor
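Under the standard WGAN convention, the discriminator (critic) loss is the mean critic score on fake data minus the mean score on real data. The sketch below assumes this module follows that convention; `wgan_discriminator_loss` is a hypothetical stand-alone helper for illustration, not part of kaira's API.

```python
import torch


def wgan_discriminator_loss(real_pred: torch.Tensor, fake_pred: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: critic loss E[D(fake)] - E[D(real)] (assumed convention)."""
    # Minimizing this pushes critic scores up on real data and down on fake data.
    return fake_pred.mean() - real_pred.mean()


real_pred = torch.tensor([2.0, 4.0])   # critic scores for a real batch (mean 3.0)
fake_pred = torch.tensor([-1.0, 1.0])  # critic scores for a fake batch (mean 0.0)
loss = wgan_discriminator_loss(real_pred, fake_pred)
print(loss.item())  # -3.0
```

A negative value means the critic currently scores real data higher than fake data, which is what it is trained to do.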
- forward_generator(fake_pred: Tensor) → Tensor
Forward pass for the generator: computes its loss from the discriminator outputs on fake data.
- Parameters:
fake_pred (torch.Tensor) – Discriminator outputs for fake data.
- Returns:
Generator loss.
- Return type:
torch.Tensor
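In the standard WGAN formulation the generator loss is the negated mean critic score on fake data, so minimizing it raises the critic's score on generated samples. The sketch below assumes this module follows that convention; `wgan_generator_loss` is a hypothetical helper, and the gradient check only illustrates the direction of the update.

```python
import torch


def wgan_generator_loss(fake_pred: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: generator loss -E[D(fake)] (assumed convention)."""
    # Negating the mean means gradient descent increases the critic's score on fakes.
    return -fake_pred.mean()


fake_pred = torch.tensor([-1.0, 3.0], requires_grad=True)
loss = wgan_generator_loss(fake_pred)
loss.backward()
print(loss.item())     # -1.0
print(fake_pred.grad)  # tensor([-0.5000, -0.5000])
```

Each gradient entry is -1/N for a batch of size N, so a descent step moves every fake score upward by the same amount.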
- forward(pred: Tensor, is_real: bool, for_discriminator: bool = True) → Tensor
Forward pass through the WassersteinGANLoss module.
- Parameters:
pred (torch.Tensor) – Discriminator outputs.
is_real (bool) – Whether the predictions are for real data.
for_discriminator (bool) – Whether to compute the loss for the discriminator (True) or the generator (False). Default is True.
- Returns:
The Wasserstein loss.
- Return type:
torch.Tensor
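The single-tensor forward interface can be sketched as follows, assuming the standard WGAN signs: for the discriminator, real predictions contribute -mean and fake predictions +mean; for the generator, the predictions are assumed to come from fake data and contribute -mean. `wgan_loss` is a hypothetical helper mirroring the documented signature, and ignoring `is_real` in the generator branch is an assumption, not documented behavior.

```python
import torch


def wgan_loss(pred: torch.Tensor, is_real: bool, for_discriminator: bool = True) -> torch.Tensor:
    """Hypothetical single-tensor WGAN loss mirroring forward() (assumed semantics)."""
    if for_discriminator:
        # Critic term: reward high scores on real data, low scores on fake data.
        return -pred.mean() if is_real else pred.mean()
    # Generator term: pred is assumed to be critic output on fake data; is_real is ignored here.
    return -pred.mean()


real_pred = torch.tensor([2.0, 4.0])   # mean 3.0
fake_pred = torch.tensor([-1.0, 3.0])  # mean 1.0

# Two separate calls sum to the combined critic loss E[D(fake)] - E[D(real)].
d_loss = wgan_loss(real_pred, is_real=True) + wgan_loss(fake_pred, is_real=False)
g_loss = wgan_loss(fake_pred, is_real=False, for_discriminator=False)
print(d_loss.item(), g_loss.item())  # -2.0 -1.0
```

Splitting the critic loss into per-batch calls lets real and fake batches be scored in separate passes, while the sum matches the two-argument discriminator form.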