kaira.losses.adversarial.WassersteinGANLoss

class kaira.losses.adversarial.WassersteinGANLoss

Bases: BaseLoss

Wasserstein GAN Loss Module.

This module implements the Wasserstein GAN (WGAN) loss introduced by Arjovsky et al. (2017).
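For reference, the standard WGAN objectives can be sketched as below. This is a minimal sketch that assumes the usual sign convention: the critic (discriminator) minimizes E[D(fake)] - E[D(real)] and the generator minimizes -E[D(fake)]. The helper names `critic_loss` and `generator_loss` are hypothetical and not part of kaira, and the library's exact implementation may differ. Note that WGAN critic outputs are raw, unbounded scores rather than probabilities.

```python
import torch


def critic_loss(real_pred: torch.Tensor, fake_pred: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: standard WGAN critic (discriminator) loss.

    The critic maximizes E[D(real)] - E[D(fake)], so we minimize its negation.
    """
    return fake_pred.mean() - real_pred.mean()


def generator_loss(fake_pred: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: standard WGAN generator loss.

    The generator maximizes E[D(fake)], so we minimize its negation.
    """
    return -fake_pred.mean()


real_pred = torch.tensor([1.0, 3.0])  # critic scores on real samples, mean 2.0
fake_pred = torch.tensor([0.0, 1.0])  # critic scores on generated samples, mean 0.5
print(critic_loss(real_pred, fake_pred).item())  # -1.5
print(generator_loss(fake_pred).item())          # -0.5
```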

Methods

  • __init__ – Initialize the WassersteinGANLoss module.

  • forward – Forward pass through the WassersteinGANLoss module.

  • forward_discriminator – Forward pass for the discriminator.

  • forward_generator – Forward pass for the generator.

__init__()

Initialize the WassersteinGANLoss module.

forward_discriminator(real_pred: Tensor, fake_pred: Tensor) → Tensor

Forward pass for the discriminator.

Parameters:
  • real_pred (torch.Tensor) – Discriminator outputs for real data.

  • fake_pred (torch.Tensor) – Discriminator outputs for fake data.

Returns:

Discriminator loss.

Return type:

torch.Tensor
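A discriminator (critic) update typically scores a real batch and a detached generated batch, then backpropagates the combined loss into the critic only. The sketch below assumes a toy `nn.Linear` critic and generator, with random tensors standing in for data; `critic_loss` is a hypothetical re-implementation of the standard objective, used in place of `forward_discriminator` so the example runs without kaira.

```python
import torch
from torch import nn


def critic_loss(real_pred: torch.Tensor, fake_pred: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for forward_discriminator: standard WGAN critic loss.
    return fake_pred.mean() - real_pred.mean()


torch.manual_seed(0)
critic = nn.Linear(4, 1)     # toy critic: raw, unbounded scores (no sigmoid)
generator = nn.Linear(2, 4)  # toy generator mapping 2-D noise to 4-D samples

real = torch.randn(8, 4)     # stand-in for a batch of real data
fake = generator(torch.randn(8, 2))

# Detach the generated batch so this step sends no gradients into the generator.
loss = critic_loss(critic(real), critic(fake.detach()))
loss.backward()

print(loss.dim())                      # 0 (scalar loss)
print(critic.weight.grad is not None)  # True
print(generator.weight.grad is None)   # True
```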

forward_generator(fake_pred: Tensor) → Tensor

Forward pass for the generator.

Parameters:

fake_pred (torch.Tensor) – Discriminator outputs for fake data.

Returns:

Generator loss.

Return type:

torch.Tensor
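For the generator step, the generated batch is not detached, because the gradient of the critic's score must flow back into the generator. This is a minimal sketch under the same assumptions as a toy setup: `nn.Linear` modules, random noise, and a hypothetical `generator_loss` helper standing in for `forward_generator`. Only the generator's optimizer is stepped, so the critic's weights stay fixed.

```python
import torch
from torch import nn


def generator_loss(fake_pred: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for forward_generator: standard WGAN generator loss.
    return -fake_pred.mean()


torch.manual_seed(0)
critic = nn.Linear(4, 1)
generator = nn.Linear(2, 4)
gen_opt = torch.optim.SGD(generator.parameters(), lr=0.01)
critic_before = critic.weight.detach().clone()

fake = generator(torch.randn(8, 2))  # no detach: gradients must reach the generator
loss = generator_loss(critic(fake))

gen_opt.zero_grad()
loss.backward()  # also fills critic grads; a full loop zeroes them before the next critic step
gen_opt.step()   # updates only the generator's parameters

print(generator.weight.grad is not None)           # True
print(torch.equal(critic.weight, critic_before))   # True
```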

forward(pred: Tensor, is_real: bool, for_discriminator: bool = True) → Tensor

Forward pass through the WassersteinGANLoss module.

Parameters:
  • pred (torch.Tensor) – Discriminator outputs.

  • is_real (bool) – Whether the predictions are for real data.

  • for_discriminator (bool) – Whether to compute the loss for the discriminator rather than the generator. Default is True.

Returns:

The Wasserstein loss.

Return type:

torch.Tensor
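The single-batch signature of `forward` can be related to the two-argument methods by computing one term per batch. The sketch below is a hypothetical function, `wasserstein_loss`, that assumes the standard WGAN sign convention; the library's actual branching, in particular how `is_real` is treated when `for_discriminator` is False, may differ. It shows that summing the real and fake critic terms recovers the combined critic loss.

```python
import torch


def wasserstein_loss(pred: torch.Tensor, is_real: bool, for_discriminator: bool = True) -> torch.Tensor:
    """Hypothetical single-batch form of the standard WGAN loss.

    Assumed sign convention (not taken from kaira):
    - critic, real batch:      -E[D(real)]
    - critic, generated batch: +E[D(fake)]
    - generator:               -E[D(fake)]  (is_real is ignored in this sketch)
    """
    if for_discriminator:
        return -pred.mean() if is_real else pred.mean()
    return -pred.mean()


real_pred = torch.tensor([1.0, 3.0])
fake_pred = torch.tensor([0.0, 1.0])

# Summing the real and fake terms gives E[D(fake)] - E[D(real)].
d_loss = wasserstein_loss(real_pred, is_real=True) + wasserstein_loss(fake_pred, is_real=False)
g_loss = wasserstein_loss(fake_pred, is_real=False, for_discriminator=False)
print(d_loss.item(), g_loss.item())  # -1.5 -0.5
```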