kaira.losses.adversarial.HingeLoss

class kaira.losses.adversarial.HingeLoss

Bases: BaseLoss

Hinge Loss Module for GANs.

This module implements the hinge loss commonly used in spectral normalization GANs (SN-GAN).
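
For reference, the hinge losses as popularized by SN-GAN are written below, with D the discriminator, G the generator, x a real sample and z a latent input. This page does not spell out the exact formula Kaira uses, so treat the correspondence (including mean reduction over the batch) as an assumption:

```latex
\begin{aligned}
\mathcal{L}_D &= \mathbb{E}_{x \sim p_{\text{data}}}\big[\max(0,\ 1 - D(x))\big]
               + \mathbb{E}_{z \sim p_z}\big[\max(0,\ 1 + D(G(z)))\big] \\
\mathcal{L}_G &= -\,\mathbb{E}_{z \sim p_z}\big[D(G(z))\big]
\end{aligned}
```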

Methods

__init__

Initialize the HingeLoss module.

forward

Forward pass through the HingeLoss module.

forward_discriminator

Forward pass for the discriminator.

forward_generator

Forward pass for the generator.

__init__()

Initialize the HingeLoss module.

forward_discriminator(real_pred: Tensor, fake_pred: Tensor) → Tensor

Forward pass for the discriminator.

Parameters:
  • real_pred (torch.Tensor) – Discriminator outputs for real data.

  • fake_pred (torch.Tensor) – Discriminator outputs for fake data.

Returns:

Discriminator loss.

Return type:

torch.Tensor
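
The following is a minimal PyTorch sketch of the discriminator term, assuming the standard hinge formulation with mean reduction over the batch. `hinge_d_loss` is a hypothetical standalone function for illustration, not Kaira's implementation:

```python
import torch
import torch.nn.functional as F


def hinge_d_loss(real_pred: torch.Tensor, fake_pred: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch of a discriminator hinge loss (assumed mean reduction)."""
    # Real scores below the +1 margin are penalized linearly.
    loss_real = F.relu(1.0 - real_pred).mean()
    # Fake scores above the -1 margin are penalized linearly.
    loss_fake = F.relu(1.0 + fake_pred).mean()
    return loss_real + loss_fake


real_pred = torch.tensor([2.0, 0.5])
fake_pred = torch.tensor([-2.0, 0.0])
print(hinge_d_loss(real_pred, fake_pred).item())  # 0.75
```

Real scores at or above +1 and fake scores at or below -1 contribute zero loss, so the discriminator stops being pushed on samples it already separates by the margin.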

forward_generator(fake_pred: Tensor) → Tensor

Forward pass for the generator.

Parameters:

fake_pred (torch.Tensor) – Discriminator outputs for fake data.

Returns:

Generator loss.

Return type:

torch.Tensor
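
A matching sketch of the generator term, again assuming the standard SN-GAN formulation with mean reduction. `hinge_g_loss` is a hypothetical name used only for illustration:

```python
import torch


def hinge_g_loss(fake_pred: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch of a generator hinge loss (assumed mean reduction)."""
    # The generator is rewarded for raising the discriminator's score on fakes.
    return -fake_pred.mean()


fake_pred = torch.tensor([1.0, -3.0])
print(hinge_g_loss(fake_pred).item())  # 1.0
```

Unlike the discriminator term, this loss has no margin: the generator keeps receiving gradient to increase D's score on fake samples, however high that score already is.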

forward(pred: Tensor, is_real: bool, for_discriminator: bool = True) → Tensor

Forward pass through the HingeLoss module.

Parameters:
  • pred (torch.Tensor) – Discriminator outputs.

  • is_real (bool) – Whether predictions are for real data.

  • for_discriminator (bool) – Whether to compute the loss for the discriminator (True) or for the generator (False). Default is True.

Returns:

The hinge loss.

Return type:

torch.Tensor
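
A hedged sketch of how `forward` might dispatch on its two flags, inferred only from the parameter descriptions above. The name `hinge_loss` is hypothetical, and ignoring `is_real` when `for_discriminator=False` is an assumption about behavior this page does not specify:

```python
import torch
import torch.nn.functional as F


def hinge_loss(pred: torch.Tensor, is_real: bool, for_discriminator: bool = True) -> torch.Tensor:
    """Hypothetical single-entry hinge loss mirroring the forward() parameters."""
    if for_discriminator:
        if is_real:
            # Push real scores up to at least +1.
            return F.relu(1.0 - pred).mean()
        # Push fake scores down to at most -1.
        return F.relu(1.0 + pred).mean()
    # Generator step (assumption: is_real is ignored here).
    return -pred.mean()


pred = torch.tensor([0.5, 2.0])
print(hinge_loss(pred, is_real=True).item())                            # 0.25
print(hinge_loss(pred, is_real=False).item())                           # 2.25
print(hinge_loss(pred, is_real=True, for_discriminator=False).item())   # -1.25
```

Under these assumptions, a discriminator step would sum `hinge_loss(d_real, True)` and `hinge_loss(d_fake, False)`, while a generator step would call `hinge_loss(d_fake, True, for_discriminator=False)` on scores computed from non-detached generator outputs.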