kaira.losses.adversarial.HingeLoss

- class kaira.losses.adversarial.HingeLoss
Bases: BaseLoss
Hinge Loss Module for GANs.
This module implements the hinge loss commonly used to train GANs with spectral normalization (SN-GAN).
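For reference, the standard GAN hinge objective is written out below. This assumes kaira follows the common SN-GAN formulation, which the text above suggests but does not spell out; D is the discriminator, G the generator, and in practice each expectation is estimated by a mean over the batch:

```latex
\mathcal{L}_D = \mathbb{E}_{x \sim p_{\text{data}}}\big[\max(0,\, 1 - D(x))\big]
              + \mathbb{E}_{z \sim p_z}\big[\max(0,\, 1 + D(G(z)))\big]

\mathcal{L}_G = -\,\mathbb{E}_{z \sim p_z}\big[D(G(z))\big]
```

The discriminator is only penalized while real scores sit below +1 or fake scores sit above -1, whereas the generator term is linear in the fake scores.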
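Methods
- __init__() – Initialize the HingeLoss module.
- forward(pred, is_real[, for_discriminator]) – Forward pass through the HingeLoss module.
- forward_discriminator(real_pred, fake_pred) – Forward pass for the discriminator.
- forward_generator(fake_pred) – Forward pass for the generator.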
- forward_discriminator(real_pred: Tensor, fake_pred: Tensor) → Tensor
Forward pass for the discriminator.
- Parameters:
real_pred (torch.Tensor) – Discriminator outputs for real data.
fake_pred (torch.Tensor) – Discriminator outputs for fake data.
- Returns:
Discriminator loss.
- Return type:
torch.Tensor
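A minimal sketch of what forward_discriminator computes, assuming the standard hinge formulation. It uses plain Python floats so it runs without PyTorch; `hinge_discriminator_loss` is a hypothetical name, not part of kaira. With tensors, the same terms would be `torch.relu(1.0 - real_pred).mean()` and `torch.relu(1.0 + fake_pred).mean()`.

```python
from statistics import fmean
from typing import Sequence


def hinge_discriminator_loss(real_pred: Sequence[float], fake_pred: Sequence[float]) -> float:
    """Hypothetical stand-in for HingeLoss.forward_discriminator (assumed formula).

    Real scores are penalized while below +1, fake scores while above -1.
    """
    real_loss = fmean([max(0.0, 1.0 - p) for p in real_pred])  # mean(relu(1 - D(x)))
    fake_loss = fmean([max(0.0, 1.0 + p) for p in fake_pred])  # mean(relu(1 + D(G(z))))
    return real_loss + fake_loss


# Confident, correctly signed scores incur no loss; undecided scores (0.0) cost 1 each.
print(hinge_discriminator_loss([1.5, 2.0], [-1.2, -3.0]))  # 0.0
print(hinge_discriminator_loss([0.0], [0.0]))  # 2.0
```

Because of the hinge, samples that are already beyond the margin contribute zero gradient, so the discriminator stops pushing scores it has already classified confidently.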
- forward_generator(fake_pred: Tensor) → Tensor
Forward pass for the generator.
- Parameters:
fake_pred (torch.Tensor) – Discriminator outputs for fake data.
- Returns:
Generator loss.
- Return type:
torch.Tensor
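A minimal sketch of the generator side, assuming the standard hinge formulation where the generator simply maximizes the discriminator's mean score on fake data (no hinge is applied here). `hinge_generator_loss` is a hypothetical name used only for illustration; plain floats stand in for tensors.

```python
from statistics import fmean
from typing import Sequence


def hinge_generator_loss(fake_pred: Sequence[float]) -> float:
    """Hypothetical stand-in for HingeLoss.forward_generator (assumed formula).

    Returns the negated mean discriminator score on generated samples, so
    minimizing it pushes D(G(z)) upward without any margin.
    """
    return -fmean(fake_pred)


print(hinge_generator_loss([0.5, 1.5]))  # -1.0
print(hinge_generator_loss([-2.0, -4.0]))  # 3.0
```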
- forward(pred: Tensor, is_real: bool, for_discriminator: bool = True) → Tensor
Forward pass through the HingeLoss module.
- Parameters:
pred (torch.Tensor) – Discriminator outputs.
is_real (bool) – Whether the predictions are for real data.
for_discriminator (bool) – Whether to compute the loss for the discriminator (True) or for the generator (False). Default is True.
- Returns:
The hinge loss.
- Return type:
torch.Tensor
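A minimal sketch of how the single-input forward might dispatch, assuming it applies the standard hinge terms one side at a time and that is_real is ignored on the generator branch (neither is confirmed by the text above). `hinge_loss` is a hypothetical stand-in; plain floats replace tensors so the example runs without PyTorch.

```python
from statistics import fmean
from typing import Sequence


def hinge_loss(pred: Sequence[float], is_real: bool, for_discriminator: bool = True) -> float:
    """Hypothetical stand-in for HingeLoss.forward (assumed dispatch)."""
    if for_discriminator:
        if is_real:
            # Real samples: penalize scores below +1.
            return fmean([max(0.0, 1.0 - p) for p in pred])
        # Fake samples: penalize scores above -1.
        return fmean([max(0.0, 1.0 + p) for p in pred])
    # Generator branch: assumed to ignore is_real and negate the mean score.
    return -fmean(pred)


real_pred = [0.5, 2.0]
fake_pred = [0.5, -2.0]

# Discriminator step: real and fake terms computed separately, then summed.
d_loss = hinge_loss(real_pred, is_real=True) + hinge_loss(fake_pred, is_real=False)
# Generator step: only the scores on generated samples are used.
g_loss = hinge_loss(fake_pred, is_real=False, for_discriminator=False)
print(d_loss, g_loss)  # 1.0 0.75
```

Under this assumption, summing the real and fake discriminator calls reproduces forward_discriminator, which is convenient when real and fake batches are scored in separate passes.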