kaira.losses.adversarial.LSGANLoss

- class kaira.losses.adversarial.LSGANLoss(reduction='mean')
Bases: BaseLoss
Least Squares GAN Loss Module.
This module implements the LSGAN loss from Mao et al. (2017).
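For reference, the least-squares objectives from Mao et al. (2017) are reproduced below. Here a and b are the target labels for fake and real data, and c is the value the generator wants the discriminator to output for fake data. The common 0-1 coding sets a = 0 and b = c = 1. This page does not state which coding and scaling constants kaira uses.

```latex
\begin{aligned}
\min_D V_{\text{LSGAN}}(D) &= \tfrac{1}{2}\,\mathbb{E}_{x \sim p_{\text{data}}}\big[(D(x) - b)^2\big]
  + \tfrac{1}{2}\,\mathbb{E}_{z \sim p_z}\big[(D(G(z)) - a)^2\big] \\
\min_G V_{\text{LSGAN}}(G) &= \tfrac{1}{2}\,\mathbb{E}_{z \sim p_z}\big[(D(G(z)) - c)^2\big]
\end{aligned}
```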
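Methods
__init__([reduction]) – Initialize the LSGANLoss module.
forward(pred, is_real[, for_discriminator]) – Forward pass through the LSGANLoss module.
forward_discriminator(real_pred, fake_pred) – Forward pass for the discriminator.
forward_generator(fake_pred) – Forward pass for the generator.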
- __init__(reduction='mean')
Initialize the LSGANLoss module.
- Parameters:
reduction (str) – Reduction method (‘mean’, ‘sum’, or ‘none’). Default is ‘mean’.
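The sketch below shows what the three reduction options conventionally mean, without any framework. It assumes kaira follows the usual PyTorch convention (as in `torch.nn.MSELoss`): 'mean' averages the per-element losses, 'sum' adds them, and 'none' returns them unreduced. The helper `reduce_losses` is hypothetical and works on plain Python lists rather than tensors.

```python
from typing import List, Union


def reduce_losses(values: List[float], reduction: str = "mean") -> Union[float, List[float]]:
    """Hypothetical helper: apply a 'mean', 'sum', or 'none' reduction (not part of kaira)."""
    if reduction == "mean":
        return sum(values) / len(values)  # average over all elements
    if reduction == "sum":
        return sum(values)  # total over all elements
    if reduction == "none":
        return list(values)  # keep one loss value per element
    raise ValueError(f"unsupported reduction: {reduction!r}")


per_element = [0.25, 0.0, 1.0, 0.75]
print(reduce_losses(per_element, "mean"))  # 0.5
print(reduce_losses(per_element, "sum"))   # 2.0
print(reduce_losses(per_element, "none"))  # [0.25, 0.0, 1.0, 0.75]
```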
- forward_discriminator(real_pred: Tensor, fake_pred: Tensor) → Tensor
Forward pass for the discriminator.
- Parameters:
real_pred (torch.Tensor) – Discriminator outputs for real data.
fake_pred (torch.Tensor) – Discriminator outputs for fake data.
- Returns:
Discriminator loss.
- Return type:
torch.Tensor
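Below is a plain-Python sketch of the discriminator objective, with lists of floats standing in for tensors. It assumes the common 0-1 target coding (a = 0 for fake, b = 1 for real) and a 'mean' reduction. The names `mean_squared_error` and `lsgan_discriminator_loss` are hypothetical. The paper's 1/2 factor is omitted because this page does not say whether kaira applies it.

```python
from typing import Sequence


def mean_squared_error(preds: Sequence[float], target: float) -> float:
    """Average of (p - target)^2 over all predictions."""
    return sum((p - target) ** 2 for p in preds) / len(preds)


def lsgan_discriminator_loss(real_pred: Sequence[float], fake_pred: Sequence[float]) -> float:
    """Hypothetical LSGAN discriminator loss with 0-1 coding (a = 0, b = 1)."""
    real_term = mean_squared_error(real_pred, 1.0)  # real outputs pushed toward b = 1
    fake_term = mean_squared_error(fake_pred, 0.0)  # fake outputs pushed toward a = 0
    return real_term + fake_term  # 1/2 scaling from the paper omitted (assumption)


print(lsgan_discriminator_loss([1.0, 0.5], [0.0, 0.5]))  # 0.125 + 0.125 -> 0.25
print(lsgan_discriminator_loss([1.0, 1.0], [0.0, 0.0]))  # perfect discriminator -> 0.0
```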
- forward_generator(fake_pred: Tensor) → Tensor
Forward pass for the generator.
- Parameters:
fake_pred (torch.Tensor) – Discriminator outputs for fake data.
- Returns:
Generator loss.
- Return type:
torch.Tensor
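The generator objective can be sketched in the same framework-free way. This assumes the generator target c = 1, a 'mean' reduction, and no 1/2 scaling. `lsgan_generator_loss` is a hypothetical name, not kaira's implementation.

```python
from typing import Sequence


def lsgan_generator_loss(fake_pred: Sequence[float]) -> float:
    """Hypothetical LSGAN generator loss with target c = 1."""
    # The generator wants the discriminator to score fake samples as real (c = 1).
    return sum((p - 1.0) ** 2 for p in fake_pred) / len(fake_pred)


print(lsgan_generator_loss([0.0, 0.5]))  # (1.0 + 0.25) / 2 -> 0.625
print(lsgan_generator_loss([1.0, 1.0]))  # discriminator fully fooled -> 0.0
```

Because the penalty is quadratic in the distance from the target, fake samples scored far from c still produce a large loss and gradient. Mao et al. argue this makes training more stable than the sigmoid cross-entropy GAN loss.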
- forward(pred: Tensor, is_real: bool, for_discriminator: bool = True) → Tensor
Forward pass through the LSGANLoss module.
- Parameters:
pred (torch.Tensor) – Discriminator outputs.
is_real (bool) – Whether the predictions are for real data.
for_discriminator (bool) – Whether the loss is computed for the discriminator (True) or the generator (False). Default is True.
- Returns:
The LSGAN loss.
- Return type:
torch.Tensor
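The single-call `forward` interface can be sketched as choosing a target label from the two flags and taking the mean squared error against it. The following are all assumptions, not confirmed by this page:

* the 0-1 target coding;
* a 'mean' reduction;
* no 1/2 scaling;
* `is_real` being ignored in generator mode.

`lsgan_loss` is a hypothetical name.

```python
from typing import Sequence


def lsgan_loss(pred: Sequence[float], is_real: bool, for_discriminator: bool = True) -> float:
    """Hypothetical single-call LSGAN loss mirroring the forward() signature."""
    if for_discriminator:
        # Discriminator: real predictions target 1, fake predictions target 0.
        target = 1.0 if is_real else 0.0
    else:
        # Generator: predictions on fake data target 1; is_real is ignored (assumption).
        target = 1.0
    return sum((p - target) ** 2 for p in pred) / len(pred)


real_pred, fake_pred = [1.0, 0.5], [0.0, 0.5]
d_loss = lsgan_loss(real_pred, is_real=True) + lsgan_loss(fake_pred, is_real=False)
g_loss = lsgan_loss(fake_pred, is_real=False, for_discriminator=False)
print(d_loss, g_loss)  # 0.25 0.625
```

To get the full discriminator objective under these assumptions, call the function twice and sum the results. Use `is_real=True` on real predictions and `is_real=False` on fake predictions. The generator loss needs only the predictions on fake data.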