kaira.losses.adversarial.LSGANLoss

class kaira.losses.adversarial.LSGANLoss(reduction='mean')

Bases: BaseLoss

Least Squares GAN Loss Module.

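This module implements the least squares GAN (LSGAN) loss of Mao et al. (2017), which replaces the sigmoid cross-entropy loss used by the discriminator of a standard GAN with a least squares loss.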
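For reference, Mao et al. (2017) write the LSGAN objectives with target values a (label for fake data), b (label for real data) and c (the value the generator wants the discriminator to assign to fake data). A common choice is the 0-1 coding a = 0, b = c = 1. This page does not state which coding or scaling LSGANLoss uses, so treat the 0-1 coding and the factor 1/2 as assumptions.

```latex
\begin{aligned}
\min_D V(D) &= \tfrac{1}{2}\,\mathbb{E}_{x \sim p_{\mathrm{data}}}\big[(D(x) - b)^2\big]
             + \tfrac{1}{2}\,\mathbb{E}_{z \sim p_z}\big[(D(G(z)) - a)^2\big] \\
\min_G V(G) &= \tfrac{1}{2}\,\mathbb{E}_{z \sim p_z}\big[(D(G(z)) - c)^2\big]
\end{aligned}
```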

Methods

  • __init__ – Initialize the LSGANLoss module.

  • forward – Forward pass through the LSGANLoss module.

  • forward_discriminator – Compute the LSGAN loss for the discriminator.

  • forward_generator – Compute the LSGAN loss for the generator.

__init__(reduction='mean')

Initialize the LSGANLoss module.

Parameters:

reduction (str) – Reduction method (‘mean’, ‘sum’, or ‘none’). Default is ‘mean’.
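The three reduction modes presumably follow the usual PyTorch convention: 'none' keeps the elementwise losses, 'sum' adds them, and 'mean' averages them. This snippet illustrates that convention with torch.nn.functional.mse_loss on made-up discriminator outputs; it does not call LSGANLoss, and whether LSGANLoss reduces exactly this way is an assumption.

```python
import torch
import torch.nn.functional as F

# Made-up discriminator outputs, compared against the "real" target 1.
pred = torch.tensor([0.5, 1.0, 2.0])
target = torch.ones_like(pred)

# Elementwise squared errors: (0.5-1)^2 = 0.25, (1-1)^2 = 0.0, (2-1)^2 = 1.0
per_elem = F.mse_loss(pred, target, reduction="none")
# Sum of the elementwise errors: 1.25
total = F.mse_loss(pred, target, reduction="sum")
# Mean of the elementwise errors: 1.25 / 3
mean = F.mse_loss(pred, target, reduction="mean")

print(per_elem, total, mean)
```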

forward_discriminator(real_pred: Tensor, fake_pred: Tensor) → Tensor

Compute the LSGAN loss for the discriminator.

Parameters:
  • real_pred (torch.Tensor) – Discriminator outputs for real data.

  • fake_pred (torch.Tensor) – Discriminator outputs for fake data.

Returns:

Discriminator loss.

Return type:

torch.Tensor
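A minimal sketch of a discriminator loss with this signature, assuming the 0-1 coding from Mao et al. (2017): real outputs are pushed toward 1, fake outputs toward 0, and each term is scaled by 1/2. The function name lsgan_discriminator_loss is hypothetical and this is not Kaira's source code.

```python
import torch


def lsgan_discriminator_loss(real_pred: torch.Tensor, fake_pred: torch.Tensor,
                             reduction: str = "mean") -> torch.Tensor:
    """Hypothetical LSGAN discriminator loss (0-1 coding, 1/2 scaling assumed)."""
    real_term = (real_pred - 1.0) ** 2  # real data should score 1
    fake_term = fake_pred ** 2          # fake data should score 0
    if reduction == "mean":
        return 0.5 * (real_term.mean() + fake_term.mean())
    if reduction == "sum":
        return 0.5 * (real_term.sum() + fake_term.sum())
    if reduction == "none":
        # Elementwise result; assumes both inputs have the same shape.
        return 0.5 * (real_term + fake_term)
    raise ValueError(f"unknown reduction: {reduction!r}")


# One correct and one wrong score on each side of the batch.
real_pred = torch.tensor([1.0, 0.0])
fake_pred = torch.tensor([0.0, 1.0])
d_loss = lsgan_discriminator_loss(real_pred, fake_pred)  # 0.5 * (0.5 + 0.5) = 0.5
```

A discriminator that scores every real sample 1 and every fake sample 0 drives this loss to exactly zero.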

forward_generator(fake_pred: Tensor) → Tensor

Compute the LSGAN loss for the generator.

Parameters:

fake_pred (torch.Tensor) – Discriminator outputs for fake data.

Returns:

Generator loss.

Return type:

torch.Tensor
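A matching sketch for the generator side, again assuming the 0-1 coding (target c = 1) and the 1/2 factor; lsgan_generator_loss is a hypothetical name, not part of Kaira. The autograd check shows that the gradient with respect to each output is proportional to its distance from the target, the property Mao et al. use to argue that the least squares loss mitigates vanishing gradients.

```python
import torch


def lsgan_generator_loss(fake_pred: torch.Tensor) -> torch.Tensor:
    """Hypothetical LSGAN generator loss: 0.5 * mean((D(G(z)) - 1)^2)."""
    # The generator wants the discriminator to score its samples as real (1).
    return 0.5 * ((fake_pred - 1.0) ** 2).mean()


fake_pred = torch.tensor([0.0, 1.0], requires_grad=True)
g_loss = lsgan_generator_loss(fake_pred)  # 0.5 * mean([1.0, 0.0]) = 0.25
g_loss.backward()
# d g_loss / d fake_pred_i = (fake_pred_i - 1) / N with N = 2, i.e. [-0.5, 0.0]
print(fake_pred.grad)
```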

forward(pred: Tensor, is_real: bool, for_discriminator: bool = True) → Tensor

Forward pass through the LSGANLoss module.

Parameters:
  • pred (torch.Tensor) – Discriminator outputs.

  • is_real (bool) – Whether predictions are for real data.

  • for_discriminator (bool) – Whether to compute the discriminator loss (True) or the generator loss (False). Default is True.

Returns:

The LSGAN loss.

Return type:

torch.Tensor
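The documented forward signature covers both players through two flags. The hypothetical class LSGANLossSketch shows one plausible dispatch: for the discriminator the target is 1 when is_real is True and 0 otherwise; for the generator the target is always 1. The target values, the 1/2 factor, and ignoring is_real in the generator branch are assumptions, not Kaira's documented behavior.

```python
import torch
from torch import nn


class LSGANLossSketch(nn.Module):
    """Hypothetical stand-in mirroring the documented LSGANLoss interface.

    Assumes 0-1 target coding and a 1/2 factor; not Kaira's source code.
    """

    def __init__(self, reduction: str = "mean") -> None:
        super().__init__()
        if reduction not in ("mean", "sum", "none"):
            raise ValueError(f"unknown reduction: {reduction!r}")
        self.reduction = reduction

    def _reduce(self, loss: torch.Tensor) -> torch.Tensor:
        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "sum":
            return loss.sum()
        return loss  # "none": elementwise losses

    def forward(self, pred: torch.Tensor, is_real: bool,
                for_discriminator: bool = True) -> torch.Tensor:
        if for_discriminator:
            # Discriminator: target 1 for real data, 0 for fake data.
            target = 1.0 if is_real else 0.0
        else:
            # Generator: wants its fakes scored as real, so the target is 1.
            # is_real is ignored here (an assumption of this sketch).
            target = 1.0
        return self._reduce(0.5 * (pred - target) ** 2)


loss_fn = LSGANLossSketch()
real_pred = torch.tensor([0.9, 1.1])
fake_pred = torch.tensor([0.2, -0.2])

# Discriminator step: both halves of the objective, each scaled by 1/2.
d_loss = loss_fn(real_pred, is_real=True) + loss_fn(fake_pred, is_real=False)
# Generator step: score the same fakes against the "real" target.
g_loss = loss_fn(fake_pred, is_real=True, for_discriminator=False)
```

In a training loop, fake_pred would be recomputed (and detached for the discriminator step) rather than shared between the two calls as in this toy example.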