kaira.losses.adversarial.VanillaGANLoss

- class kaira.losses.adversarial.VanillaGANLoss(reduction='mean')
Bases: BaseLoss
Vanilla GAN Loss Module.
This module implements the original GAN loss from Goodfellow et al. (2014). Its methods take discriminator logits, not probabilities, as inputs.
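In the original formulation the discriminator outputs a probability D(x) = sigmoid(l) for a logit l, and the losses are built from -log D(x) and -log(1 - D(G(z))). When the inputs are logits, both terms can be written with softplus(t) = log(1 + exp(t)), which stays finite even when the probability rounds to exactly 0 or 1. The pure-Python sketch below checks these two identities numerically; `sigmoid` and `softplus` are hypothetical helpers for illustration, not part of Kaira.

```python
import math


def sigmoid(t: float) -> float:
    """Logistic function mapping a logit to a probability."""
    return 1.0 / (1.0 + math.exp(-t))


def softplus(t: float) -> float:
    """Numerically stable log(1 + exp(t))."""
    # max(t, 0) + log1p(exp(-|t|)) never overflows for large |t|
    return max(t, 0.0) + math.log1p(math.exp(-abs(t)))


# For a logit l: -log(sigmoid(l)) == softplus(-l) and -log(1 - sigmoid(l)) == softplus(l)
for logit in (-3.0, -0.5, 0.0, 0.5, 3.0):
    assert math.isclose(-math.log(sigmoid(logit)), softplus(-logit), rel_tol=1e-9)
    assert math.isclose(-math.log(1.0 - sigmoid(logit)), softplus(logit), rel_tol=1e-9)

# A very confident logit: sigmoid(800.0) rounds to exactly 1.0, so the naive
# form -log(1 - sigmoid(800.0)) would call math.log(0.0) and raise ValueError.
# The softplus form stays finite.
print(softplus(800.0))  # 800.0
```

This is why GAN losses are usually computed from logits rather than from probabilities.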
Methods
- __init__ – Initialize the VanillaGANLoss module.
- forward – Forward pass through the VanillaGANLoss module.
- forward_discriminator – Forward pass for the discriminator.
- forward_generator – Forward pass for the generator.
- __init__(reduction='mean')
Initialize the VanillaGANLoss module.
- Parameters:
reduction (str) – Reduction method (‘mean’, ‘sum’, or ‘none’). Default is ‘mean’.
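The option names match the usual PyTorch convention: 'mean' averages the per-sample losses, 'sum' adds them, and 'none' returns them unreduced. That Kaira applies exactly these semantics is an assumption based on the option names. The hypothetical helper `apply_reduction` below only illustrates the convention on a plain list of per-sample losses.

```python
from typing import List, Union


def apply_reduction(losses: List[float], reduction: str = "mean") -> Union[float, List[float]]:
    """Hypothetical helper illustrating 'mean' / 'sum' / 'none' reduction."""
    if reduction == "mean":
        return sum(losses) / len(losses)
    if reduction == "sum":
        return sum(losses)
    if reduction == "none":
        return list(losses)  # one loss value per sample, unreduced
    raise ValueError(f"unsupported reduction: {reduction!r}")


per_sample = [0.5, 1.0, 1.5]
print(apply_reduction(per_sample, "mean"))  # 1.0
print(apply_reduction(per_sample, "sum"))   # 3.0
print(apply_reduction(per_sample, "none"))  # [0.5, 1.0, 1.5]
```

With 'mean' the loss scale does not depend on batch size, which is why it is the common default.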
- forward_discriminator(real_logits: Tensor, fake_logits: Tensor) → Tensor
Forward pass for the discriminator.
- Parameters:
real_logits (torch.Tensor) – Discriminator outputs for real data.
fake_logits (torch.Tensor) – Discriminator outputs for fake data.
- Returns:
Discriminator loss.
- Return type:
torch.Tensor
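In the vanilla formulation the discriminator learns to label real samples 1 and fake samples 0, i.e. it minimizes -log D(x) - log(1 - D(G(z))). On logits this becomes softplus(-l_real) + softplus(l_fake), with softplus(t) = log(1 + exp(t)). The pure-Python sketch below assumes 'mean' reduction over each batch and assumes the real and fake terms are simply added; whether Kaira adds, averages, or weights the two terms is not stated on this page, and `discriminator_loss` is a hypothetical name, not the library's method.

```python
import math
from typing import Sequence


def softplus(t: float) -> float:
    """Numerically stable log(1 + exp(t))."""
    return max(t, 0.0) + math.log1p(math.exp(-abs(t)))


def discriminator_loss(real_logits: Sequence[float], fake_logits: Sequence[float]) -> float:
    """Vanilla GAN discriminator loss on logits with 'mean' reduction (sketch)."""
    # BCE with target 1 on real samples: -log(sigmoid(l)) == softplus(-l)
    real_term = sum(softplus(-logit) for logit in real_logits) / len(real_logits)
    # BCE with target 0 on fake samples: -log(1 - sigmoid(l)) == softplus(l)
    fake_term = sum(softplus(logit) for logit in fake_logits) / len(fake_logits)
    return real_term + fake_term


# An undecided discriminator (all logits 0) pays log 2 on each term
print(discriminator_loss([0.0, 0.0], [0.0, 0.0]))  # 2 * log(2) ≈ 1.386
# A discriminator that separates real from fake well has a much lower loss
print(discriminator_loss([4.0, 5.0], [-4.0, -5.0]))
```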
- forward_generator(fake_logits: Tensor) → Tensor
Forward pass for the generator.
- Parameters:
fake_logits (torch.Tensor) – Discriminator outputs for fake data.
- Returns:
Generator loss.
- Return type:
torch.Tensor
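Goodfellow et al. (2014) state the generator objective as minimizing log(1 - D(G(z))), but they also suggest the non-saturating alternative of maximizing log D(G(z)), which gives much stronger gradients early in training when the discriminator easily rejects fakes. This page does not say which form VanillaGANLoss uses. The pure-Python sketch below implements both on logits (softplus(t) = log(1 + exp(t))) and compares their slopes with a finite difference; all function names are hypothetical.

```python
import math
from typing import Callable, Sequence


def softplus(t: float) -> float:
    """Numerically stable log(1 + exp(t))."""
    return max(t, 0.0) + math.log1p(math.exp(-abs(t)))


def generator_loss_non_saturating(fake_logits: Sequence[float]) -> float:
    """-log D(G(z)) on logits, 'mean' reduction: BCE with target 1 on fakes."""
    return sum(softplus(-logit) for logit in fake_logits) / len(fake_logits)


def generator_loss_saturating(fake_logits: Sequence[float]) -> float:
    """log(1 - D(G(z))) on logits, 'mean' reduction: the original minimax form."""
    return sum(-softplus(logit) for logit in fake_logits) / len(fake_logits)


def slope(loss_fn: Callable[[Sequence[float]], float], logit: float, eps: float = 1e-6) -> float:
    """Central finite-difference derivative of a loss w.r.t. a single logit."""
    return (loss_fn([logit + eps]) - loss_fn([logit - eps])) / (2 * eps)


# Early in training the discriminator confidently rejects fakes (logit -6).
# The non-saturating loss still has a slope near -1; the saturating one is
# almost flat, so the generator would receive very little gradient signal.
print(slope(generator_loss_non_saturating, -6.0))
print(slope(generator_loss_saturating, -6.0))
```

Both losses push the fake logit upward (negative slope), but only the non-saturating form does so strongly when the generator is losing.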