"""Adversarial Losses module for Kaira.
This module contains various adversarial loss functions for GAN-based training.
"""
import torch
import torch.nn.functional as F
from .base import BaseLoss
from .registry import LossRegistry
[docs]
@LossRegistry.register_loss()
class VanillaGANLoss(BaseLoss):
    """Binary cross-entropy adversarial loss (original GAN objective).

    Implements the loss from Goodfellow et al. 2014. Every method takes raw
    discriminator logits and scores them with
    ``F.binary_cross_entropy_with_logits`` against an all-ones ("real") or
    all-zeros ("fake") target of the same shape.
    """

    def __init__(self, reduction="mean"):
        """Create the loss and remember the reduction mode.

        Args:
            reduction (str): Reduction forwarded to
                ``F.binary_cross_entropy_with_logits`` ('mean', 'sum', or
                'none'). Default is 'mean'.
        """
        super().__init__()
        self.reduction = reduction

    def forward_discriminator(self, real_logits: torch.Tensor, fake_logits: torch.Tensor) -> torch.Tensor:
        """Compute the discriminator loss.

        Real logits are scored against a target of ones and fake logits
        against a target of zeros; the two terms are added.

        Args:
            real_logits (torch.Tensor): Discriminator outputs for real data.
            fake_logits (torch.Tensor): Discriminator outputs for fake data.

        Returns:
            torch.Tensor: Sum of the real and fake cross-entropy terms.
        """
        real_target = torch.ones_like(real_logits)
        fake_target = torch.zeros_like(fake_logits)
        loss_on_real = F.binary_cross_entropy_with_logits(real_logits, real_target, reduction=self.reduction)
        loss_on_fake = F.binary_cross_entropy_with_logits(fake_logits, fake_target, reduction=self.reduction)
        return loss_on_real + loss_on_fake

    def forward_generator(self, fake_logits: torch.Tensor) -> torch.Tensor:
        """Compute the generator loss.

        The generator is rewarded when fake samples are classified as real,
        so the fake logits are scored against a target of ones.

        Args:
            fake_logits (torch.Tensor): Discriminator outputs for fake data.

        Returns:
            torch.Tensor: Generator loss.
        """
        wanted = torch.ones_like(fake_logits)
        return F.binary_cross_entropy_with_logits(fake_logits, wanted, reduction=self.reduction)

    def forward(self, discriminator_pred: torch.Tensor, is_real: bool) -> torch.Tensor:
        """Score predictions against a constant real/fake target.

        Args:
            discriminator_pred (torch.Tensor): Discriminator outputs (logits).
            is_real (bool): If true the target is all ones, otherwise all
                zeros.

        Returns:
            torch.Tensor: The GAN loss.
        """
        if is_real:
            target = torch.ones_like(discriminator_pred)
        else:
            target = torch.zeros_like(discriminator_pred)
        return F.binary_cross_entropy_with_logits(discriminator_pred, target, reduction=self.reduction)
[docs]
@LossRegistry.register_loss()
class LSGANLoss(BaseLoss):
    """Least Squares GAN Loss Module.

    This module implements the LSGAN loss from Mao et al. 2017: squared
    distance of the discriminator output to 1 (real target) or to 0 (fake
    target), reduced according to ``reduction``.
    """

    def __init__(self, reduction="mean"):
        """Initialize the LSGANLoss module.

        Args:
            reduction (str): Reduction method ('mean', 'sum', or 'none').
                Default is 'mean'.
        """
        super().__init__()
        self.reduction = reduction

    def _reduce(self, squared_error: torch.Tensor) -> torch.Tensor:
        """Apply ``self.reduction`` to an element-wise squared error.

        Args:
            squared_error (torch.Tensor): Element-wise loss values.

        Returns:
            torch.Tensor: Mean, sum, or the unreduced tensor.

        Raises:
            ValueError: If ``self.reduction`` is not 'mean', 'sum' or 'none'.
        """
        if self.reduction == "mean":
            return torch.mean(squared_error)
        if self.reduction == "sum":
            return torch.sum(squared_error)
        if self.reduction == "none":
            return squared_error
        raise ValueError(f"{self.reduction} is not a valid value for reduction")

    def forward_discriminator(self, real_pred: torch.Tensor, fake_pred: torch.Tensor) -> torch.Tensor:
        """Forward pass for discriminator.

        Args:
            real_pred (torch.Tensor): Discriminator outputs for real data.
            fake_pred (torch.Tensor): Discriminator outputs for fake data.

        Returns:
            torch.Tensor: Half the sum of the real term ``(real_pred - 1)^2``
            and the fake term ``fake_pred^2``, each reduced with the
            configured reduction. With ``reduction='none'`` the two inputs
            must have broadcastable shapes because the terms are added
            element-wise.
        """
        real_loss = self._reduce((real_pred - 1) ** 2)
        fake_loss = self._reduce(fake_pred**2)
        return (real_loss + fake_loss) * 0.5

    def forward_generator(self, fake_pred: torch.Tensor) -> torch.Tensor:
        """Forward pass for generator.

        Args:
            fake_pred (torch.Tensor): Discriminator outputs for fake data.

        Returns:
            torch.Tensor: Reduced ``(fake_pred - 1)^2``; the generator wants
            fake samples to be scored as real.
        """
        return self._reduce((fake_pred - 1) ** 2)

    def forward(self, pred: torch.Tensor, is_real: bool, for_discriminator: bool = True) -> torch.Tensor:
        """Forward pass through the LSGANLoss module.

        Args:
            pred (torch.Tensor): Discriminator outputs.
            is_real (bool): Whether predictions are for real data. Ignored
                when ``for_discriminator`` is false.
            for_discriminator (bool): Whether calculating loss for
                discriminator. Default is True.

        Returns:
            torch.Tensor: The LSGAN loss.
        """
        # Only the discriminator's fake branch regresses towards 0; the
        # discriminator's real branch and the generator both regress towards 1.
        if for_discriminator and not is_real:
            return self._reduce(pred**2)
        return self._reduce((pred - 1) ** 2)
[docs]
@LossRegistry.register_loss()
class WassersteinGANLoss(BaseLoss):
    """Wasserstein (critic-score) adversarial loss.

    Implements the WGAN loss from Arjovsky et al. 2017. All methods work on
    the mean of the critic outputs they are given.
    """

    def __init__(self):
        """Create the loss; there is nothing to configure."""
        super().__init__()

    def forward_discriminator(self, real_pred: torch.Tensor, fake_pred: torch.Tensor) -> torch.Tensor:
        """Compute the critic (discriminator) loss.

        Args:
            real_pred (torch.Tensor): Discriminator outputs for real data.
            fake_pred (torch.Tensor): Discriminator outputs for fake data.

        Returns:
            torch.Tensor: Negated difference between the mean real score and
            the mean fake score.
        """
        mean_real = real_pred.mean()
        mean_fake = fake_pred.mean()
        score_gap = mean_real - mean_fake
        return -score_gap

    def forward_generator(self, fake_pred: torch.Tensor) -> torch.Tensor:
        """Compute the generator loss.

        Args:
            fake_pred (torch.Tensor): Discriminator outputs for fake data.

        Returns:
            torch.Tensor: Negated mean of the fake scores.
        """
        mean_fake = fake_pred.mean()
        return -mean_fake

    def forward(self, pred: torch.Tensor, is_real: bool, for_discriminator: bool = True) -> torch.Tensor:
        """Compute a single Wasserstein loss term.

        Args:
            pred (torch.Tensor): Discriminator outputs.
            is_real (bool): Whether predictions are for real data. Has no
                effect when ``for_discriminator`` is false.
            for_discriminator (bool): Whether calculating loss for
                discriminator. Default is True.

        Returns:
            torch.Tensor: The mean score for the discriminator's fake term,
            and the negated mean score in every other case.
        """
        mean_score = pred.mean()
        # The discriminator's fake term is the only one that keeps its sign.
        if for_discriminator and not is_real:
            return mean_score
        return -mean_score
[docs]
@LossRegistry.register_loss()
class HingeLoss(BaseLoss):
    """Hinge adversarial loss.

    Implements the hinge loss commonly used in spectral normalization GAN.
    """

    def __init__(self):
        """Create the loss; there is nothing to configure."""
        super().__init__()

    def forward_discriminator(self, real_pred: torch.Tensor, fake_pred: torch.Tensor) -> torch.Tensor:
        """Compute the discriminator loss.

        Args:
            real_pred (torch.Tensor): Discriminator outputs for real data.
            fake_pred (torch.Tensor): Discriminator outputs for fake data.

        Returns:
            torch.Tensor: ``mean(relu(1 - real_pred)) + mean(relu(1 + fake_pred))``.
        """
        real_violation = F.relu(1.0 - real_pred)
        fake_violation = F.relu(1.0 + fake_pred)
        return torch.mean(real_violation) + torch.mean(fake_violation)

    def forward_generator(self, fake_pred: torch.Tensor) -> torch.Tensor:
        """Compute the generator loss.

        Args:
            fake_pred (torch.Tensor): Discriminator outputs for fake data.

        Returns:
            torch.Tensor: Negated mean of the fake predictions.
        """
        return -torch.mean(fake_pred)

    def forward(self, pred: torch.Tensor, is_real: bool, for_discriminator: bool = True) -> torch.Tensor:
        """Compute a single hinge loss term.

        Args:
            pred (torch.Tensor): Discriminator outputs.
            is_real (bool): Whether predictions are for real data. Has no
                effect when ``for_discriminator`` is false.
            for_discriminator (bool): Whether calculating loss for
                discriminator. Default is True.

        Returns:
            torch.Tensor: The hinge loss.
        """
        # Generator term: no hinge, just the negated mean score.
        if not for_discriminator:
            return -torch.mean(pred)
        # Discriminator term: margin of 1 on the side selected by is_real.
        violation = 1.0 - pred if is_real else 1.0 + pred
        return torch.mean(F.relu(violation))
[docs]
@LossRegistry.register_loss()
class FeatureMatchingLoss(BaseLoss):
    """Feature Matching Loss Module for GANs.

    This module implements the feature matching loss for improved GAN
    training: for each pair of feature tensors, the L1 distance between
    their means over dimension 0, summed over all pairs.
    """

    def __init__(self):
        """Initialize the FeatureMatchingLoss module."""
        super().__init__()

    def forward(self, real_features: list, fake_features: list) -> torch.Tensor:
        """Forward pass through the FeatureMatchingLoss module.

        Args:
            real_features (list): List of discriminator features for real
                data. These are detached, so no gradient flows through them.
            fake_features (list): List of discriminator features for fake
                data.

        Returns:
            torch.Tensor: The feature matching loss. Features are paired with
            ``zip``, so extra entries in the longer list are ignored. When no
            pairs are given, a zero-valued scalar tensor is returned so the
            result is always a tensor.
        """
        per_layer = [
            F.l1_loss(fake_feat.mean(0), real_feat.detach().mean(0))
            for real_feat, fake_feat in zip(real_features, fake_features)
        ]
        if not per_layer:
            # Nothing to compare: keep the declared return type instead of
            # handing back a bare Python float.
            return torch.tensor(0.0)

        # Accumulate out-of-place so no tensor in the autograd graph is
        # modified after it has been produced.
        total = per_layer[0]
        for layer_loss in per_layer[1:]:
            total = total + layer_loss
        return total
[docs]
@LossRegistry.register_loss()
class R1GradientPenalty(BaseLoss):
    """R1 Gradient Penalty Module for GANs.

    This module implements the R1 gradient penalty for GAN training: the
    squared L2 norm of the gradient of the discriminator output with respect
    to the real input, averaged over dimension 0 and scaled by ``gamma / 2``.
    """

    def __init__(self, gamma=10.0):
        """Initialize the R1GradientPenalty module.

        Args:
            gamma (float): Weight for the gradient penalty. Default is 10.0.
        """
        super().__init__()
        self.gamma = gamma

    def forward(self, real_data: torch.Tensor, real_outputs: torch.Tensor) -> torch.Tensor:
        """Forward pass through the R1GradientPenalty module.

        Args:
            real_data (torch.Tensor): Real input data. Must have
                ``requires_grad=True`` and must have been used to produce
                ``real_outputs`` for the penalty to be non-zero.
            real_outputs (torch.Tensor): Discriminator outputs for real data.

        Returns:
            torch.Tensor: The R1 gradient penalty. A zero scalar on
            ``real_data.device`` is returned (with a warning) when
            ``real_data`` does not require gradients, and (silently) when
            autograd reports no gradient for ``real_data``.
        """
        if not real_data.requires_grad:
            # Without requires_grad there is nothing to differentiate
            # against; warn and fall back to a zero penalty.
            import warnings

            warnings.warn("The real_data tensor does not require gradients. The grad will be treated as zero.")
            return torch.tensor(0.0, device=real_data.device)

        # create_graph=True keeps the gradient itself differentiable so the
        # penalty can be backpropagated; allow_unused=True yields None rather
        # than an error when real_outputs does not depend on real_data.
        (grad_real,) = torch.autograd.grad(
            outputs=real_outputs.sum(),
            inputs=real_data,
            create_graph=True,
            retain_graph=True,
            allow_unused=True,
        )

        if grad_real is None:
            return torch.tensor(0.0, device=real_data.device)

        # reshape (not view): the gradient may be non-contiguous, in which
        # case view would raise.
        grad_real = grad_real.reshape(grad_real.size(0), -1)

        # Squared L2 norm per row of the flattened gradient, computed
        # directly as a sum of squares instead of squaring a square root.
        grad_penalty = grad_real.pow(2).sum(dim=1).mean()
        return self.gamma * 0.5 * grad_penalty
# Names exported by ``from <module> import *``.
__all__ = [
    "VanillaGANLoss",
    "LSGANLoss",
    "WassersteinGANLoss",
    "HingeLoss",
    "FeatureMatchingLoss",
    "R1GradientPenalty",
]