kaira.losses.adversarial.VanillaGANLoss


class kaira.losses.adversarial.VanillaGANLoss(reduction='mean')

Bases: BaseLoss

Vanilla GAN Loss Module.

This module implements the original GAN loss from Goodfellow et al. (2014), "Generative Adversarial Nets".
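For reference, Goodfellow et al. (2014) define the two-player minimax objective below, where D(x) is the probability D assigns to its input being real and G maps noise z to samples. The paper also proposes training G to maximize log D(G(z)) instead of minimizing log(1 - D(G(z))), because the former gives stronger gradients early in training. This page does not state which generator form VanillaGANLoss uses, so the per-sample losses below are given as the standard textbook forms, not as kaira's exact implementation.

```latex
\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]

% Per-sample losses to minimize, with D(\cdot) = \sigma(\text{logit}):
\mathcal{L}_D = -\log D(x) - \log\big(1 - D(G(z))\big), \qquad \mathcal{L}_G^{\text{non-sat}} = -\log D(G(z))
```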

Methods

__init__ – Initialize the VanillaGANLoss module.
forward – Compute the GAN loss for a batch of discriminator outputs.
forward_discriminator – Compute the discriminator loss.
forward_generator – Compute the generator loss.

__init__(reduction='mean')

Initialize the VanillaGANLoss module.

Parameters:

reduction (str) – Reduction method (‘mean’, ‘sum’, or ‘none’). Default is ‘mean’.
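The three reduction modes can be illustrated with PyTorch's `torch.nn.functional.binary_cross_entropy_with_logits`, which accepts the same `'mean'`, `'sum'`, and `'none'` strings. This is a minimal sketch, assuming VanillaGANLoss passes `reduction` through to a BCE-with-logits computation; the example logits are made up.

```python
import torch
import torch.nn.functional as F

# Made-up discriminator logits for a batch of four real samples.
logits = torch.tensor([2.0, -1.0, 0.5, 0.0])
targets = torch.ones_like(logits)  # label 1 means "real"

# 'none' keeps one loss value per element.
per_element = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
# 'mean' averages those values; 'sum' adds them up.
mean_loss = F.binary_cross_entropy_with_logits(logits, targets, reduction="mean")
sum_loss = F.binary_cross_entropy_with_logits(logits, targets, reduction="sum")

print(per_element.shape)  # torch.Size([4])
print(torch.allclose(mean_loss, per_element.mean()))  # True
print(torch.allclose(sum_loss, per_element.sum()))  # True
```

`'none'` is mainly useful when per-sample weights or masks are applied before reducing manually.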

forward_discriminator(real_logits: Tensor, fake_logits: Tensor) → Tensor

Compute the discriminator loss from its outputs on real and fake data.

Parameters:
  • real_logits (torch.Tensor) – Discriminator outputs for real data.

  • fake_logits (torch.Tensor) – Discriminator outputs for fake data.

Returns:

Discriminator loss.

Return type:

torch.Tensor
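A common way to compute this loss, sketched here under the assumption that `real_logits` and `fake_logits` are raw (pre-sigmoid) scores, is binary cross-entropy with real samples labeled 1 and fake samples labeled 0, which matches L_D = -log D(x) - log(1 - D(G(z))). The helper name `discriminator_loss` is hypothetical, and this page does not say whether kaira sums or averages the two terms; the sketch sums them.

```python
import torch
import torch.nn.functional as F
from torch import Tensor


def discriminator_loss(real_logits: Tensor, fake_logits: Tensor, reduction: str = "mean") -> Tensor:
    """Hypothetical sketch of a vanilla GAN discriminator loss on raw logits."""
    # Real samples should be classified as 1: -log(sigmoid(real_logits)).
    real_term = F.binary_cross_entropy_with_logits(
        real_logits, torch.ones_like(real_logits), reduction=reduction
    )
    # Fake samples should be classified as 0: -log(1 - sigmoid(fake_logits)).
    fake_term = F.binary_cross_entropy_with_logits(
        fake_logits, torch.zeros_like(fake_logits), reduction=reduction
    )
    # Assumption: the two terms are summed (some implementations average them).
    return real_term + fake_term


# An undecided discriminator (all logits 0) outputs D = 0.5 everywhere.
undecided = discriminator_loss(torch.zeros(3), torch.zeros(3))
# A confident, correct discriminator drives the loss toward 0.
confident = discriminator_loss(torch.full((3,), 20.0), torch.full((3,), -20.0))
print(undecided.item())  # about 1.3863, i.e. 2 * log(2)
```

With `reduction="none"` the two per-element terms are added elementwise, so this sketch then expects real and fake batches of matching (or broadcastable) shape.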

forward_generator(fake_logits: Tensor) → Tensor

Compute the generator loss from the discriminator's outputs on generated (fake) data.

Parameters:

fake_logits (torch.Tensor) – Discriminator outputs for fake data.

Returns:

Generator loss.

Return type:

torch.Tensor
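Goodfellow et al. (2014) describe two generator objectives: the minimax form, which minimizes log(1 - D(G(z))), and the non-saturating form, which minimizes -log D(G(z)). This page does not say which one `forward_generator` implements. The sketch below, assuming raw logits and using hypothetical helper names, computes both and compares their gradients when the discriminator confidently rejects a fake sample.

```python
import torch
import torch.nn.functional as F
from torch import Tensor


def generator_loss_non_saturating(fake_logits: Tensor, reduction: str = "mean") -> Tensor:
    """Hypothetical sketch: -log D(G(z)), i.e. BCE against the "real" label 1."""
    return F.binary_cross_entropy_with_logits(
        fake_logits, torch.ones_like(fake_logits), reduction=reduction
    )


def generator_loss_minimax(fake_logits: Tensor, reduction: str = "mean") -> Tensor:
    """Hypothetical sketch: log(1 - D(G(z))), the negated BCE against label 0."""
    return -F.binary_cross_entropy_with_logits(
        fake_logits, torch.zeros_like(fake_logits), reduction=reduction
    )


# A fake sample the discriminator confidently rejects (D(G(z)) is about 4.5e-5).
logit = torch.tensor([-10.0], requires_grad=True)
(grad_ns,) = torch.autograd.grad(generator_loss_non_saturating(logit), logit)
(grad_mm,) = torch.autograd.grad(generator_loss_minimax(logit), logit)
print(grad_ns.item(), grad_mm.item())  # roughly -1.0 versus -4.5e-05
```

Both losses push the logit in the same direction, but the minimax gradient almost vanishes here, which is the motivation the paper gives for the non-saturating form.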

forward(discriminator_pred: Tensor, is_real: bool) → Tensor

Compute the GAN loss for a batch of discriminator outputs that are all for real data or all for fake data.

Parameters:
  • discriminator_pred (torch.Tensor) – Discriminator outputs.

  • is_real (bool) – True if the predictions are for real data, False if they are for generated (fake) data.

Returns:

The GAN loss.

Return type:

torch.Tensor
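One plausible reading of this signature, sketched here as an assumption rather than kaira's actual code, is that `is_real` selects the target label: every prediction is compared against 1 when `is_real` is True and against 0 otherwise. The function name `gan_loss` is hypothetical and the predictions are assumed to be raw logits. Under that reading, both the discriminator and generator losses can be assembled from this single call.

```python
import torch
import torch.nn.functional as F
from torch import Tensor


def gan_loss(discriminator_pred: Tensor, is_real: bool, reduction: str = "mean") -> Tensor:
    """Hypothetical sketch of a single-target vanilla GAN loss on raw logits."""
    # Every prediction gets the same target: 1 for real data, 0 for fake data.
    fill = 1.0 if is_real else 0.0
    target = torch.full_like(discriminator_pred, fill)
    return F.binary_cross_entropy_with_logits(discriminator_pred, target, reduction=reduction)


# Made-up logits standing in for discriminator outputs.
real_logits = torch.tensor([1.5, 2.0, 0.3])
fake_logits = torch.tensor([-1.0, 0.2, -2.5])

# Discriminator step: real batch labeled real, fake batch labeled fake.
d_loss = gan_loss(real_logits, True) + gan_loss(fake_logits, False)
# Generator step: fake batch labeled real (the non-saturating form).
g_loss = gan_loss(fake_logits, True)
```

In an actual training loop, the fake logits used for the discriminator step are usually computed from detached generator outputs so that the discriminator update does not backpropagate into the generator.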