Source code for kaira.losses.audio

"""Audio Losses module for Kaira.

This module contains various loss functions for training audio-based communication systems.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio

from .base import BaseLoss
from .registry import LossRegistry


[docs] @LossRegistry.register_loss() class L1AudioLoss(BaseLoss): """L1 Audio Loss Module. This module calculates the L1 loss between the input and target audio signals. """
[docs] def __init__(self): """Initialize the L1AudioLoss module.""" super().__init__() self.l1 = nn.L1Loss()
[docs] def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """Forward pass through the L1AudioLoss module. Args: x (torch.Tensor): The input audio tensor. target (torch.Tensor): The target audio tensor. Returns: torch.Tensor: The L1 loss between the input and target audio. """ return self.l1(x, target)
[docs] @LossRegistry.register_loss() class SpectralConvergenceLoss(BaseLoss): """Spectral Convergence Loss Module. This module calculates the spectral convergence loss between the input and target spectra. """
[docs] def __init__(self): """Initialize the SpectralConvergenceLoss module.""" super().__init__()
[docs] def forward(self, x_mag: torch.Tensor, target_mag: torch.Tensor) -> torch.Tensor: """Forward pass through the SpectralConvergenceLoss module. Args: x_mag (torch.Tensor): The magnitude of the input spectrum. target_mag (torch.Tensor): The magnitude of the target spectrum. Returns: torch.Tensor: The spectral convergence loss. """ return torch.norm(target_mag - x_mag, p="fro") / torch.norm(target_mag, p="fro")
[docs] @LossRegistry.register_loss() class LogSTFTMagnitudeLoss(BaseLoss): """Log STFT Magnitude Loss Module. This module calculates the log STFT magnitude loss between the input and target spectra. """
[docs] def __init__(self): """Initialize the LogSTFTMagnitudeLoss module.""" super().__init__()
[docs] def forward(self, x_mag: torch.Tensor, target_mag: torch.Tensor) -> torch.Tensor: """Forward pass through the LogSTFTMagnitudeLoss module. Args: x_mag (torch.Tensor): The magnitude of the input spectrum. target_mag (torch.Tensor): The magnitude of the target spectrum. Returns: torch.Tensor: The log STFT magnitude loss. """ log_x_mag = torch.log(x_mag + 1e-7) log_target_mag = torch.log(target_mag + 1e-7) return F.l1_loss(log_x_mag, log_target_mag)
[docs] @LossRegistry.register_loss() class STFTLoss(BaseLoss): """STFT Loss Module. This module calculates the STFT loss between the input and target audio signals, combining spectral convergence loss and log STFT magnitude loss. """
[docs] def __init__(self, fft_size=1024, hop_size=256, win_length=1024, window="hann"): """Initialize the STFTLoss module. Args: fft_size (int): FFT size for STFT. Default is 1024. hop_size (int): Hop size for STFT. Default is 256. win_length (int): Window length for STFT. Default is 1024. window (str): Window function type. Default is 'hann'. """ super().__init__() self.fft_size = fft_size self.hop_size = hop_size self.win_length = win_length self.window = window self.spectral_convergence_loss = SpectralConvergenceLoss() self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()
[docs] def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """Forward pass through the STFTLoss module. Args: x (torch.Tensor): The input audio tensor. target (torch.Tensor): The target audio tensor. Returns: torch.Tensor: The combined STFT loss. """ window_fn = getattr(torch, f"{self.window}_window") window = window_fn(self.win_length, dtype=x.dtype, device=x.device) x_stft = torch.stft( x, n_fft=self.fft_size, hop_length=self.hop_size, win_length=self.win_length, window=window, return_complex=True, ) target_stft = torch.stft( target, n_fft=self.fft_size, hop_length=self.hop_size, win_length=self.win_length, window=window, return_complex=True, ) x_mag = torch.abs(x_stft) target_mag = torch.abs(target_stft) sc_loss = self.spectral_convergence_loss(x_mag, target_mag) mag_loss = self.log_stft_magnitude_loss(x_mag, target_mag) return sc_loss + mag_loss
[docs] @LossRegistry.register_loss() class MultiResolutionSTFTLoss(BaseLoss): """Multi-Resolution STFT Loss Module. This module calculates STFT loss at multiple resolutions for better time-frequency coverage. """
[docs] def __init__( self, fft_sizes=[512, 1024, 2048], hop_sizes=[128, 256, 512], win_lengths=[512, 1024, 2048], window="hann", ): """Initialize the MultiResolutionSTFTLoss module. Args: fft_sizes (list): List of FFT sizes for each resolution. Default is [512, 1024, 2048]. hop_sizes (list): List of hop sizes for each resolution. Default is [128, 256, 512]. win_lengths (list): List of window lengths for each resolution. Default is [512, 1024, 2048]. window (str): Window function type. Default is 'hann'. """ super().__init__() assert len(fft_sizes) == len(hop_sizes) == len(win_lengths) self.stft_losses = nn.ModuleList([STFTLoss(fft_size=fft_size, hop_size=hop_size, win_length=win_length, window=window) for fft_size, hop_size, win_length in zip(fft_sizes, hop_sizes, win_lengths)])
[docs] def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """Forward pass through the MultiResolutionSTFTLoss module. Args: x (torch.Tensor): The input audio tensor. target (torch.Tensor): The target audio tensor. Returns: torch.Tensor: The multi-resolution STFT loss. """ loss = 0.0 for stft_loss in self.stft_losses: loss += stft_loss(x, target) return loss / len(self.stft_losses)
[docs] @LossRegistry.register_loss() class MelSpectrogramLoss(BaseLoss): """Mel-Spectrogram Loss Module. This module calculates the loss between mel-spectrograms of input and target audio. """
[docs] def __init__( self, sample_rate=22050, n_fft=1024, hop_length=256, n_mels=80, f_min=0.0, f_max=8000.0, log_mel=True, ): """Initialize the MelSpectrogramLoss module. Args: sample_rate (int): Audio sample rate. Default is 22050. n_fft (int): FFT size. Default is 1024. hop_length (int): Hop size. Default is 256. n_mels (int): Number of mel bands. Default is 80. f_min (float): Minimum frequency. Default is 0.0. f_max (float): Maximum frequency. Default is 8000.0. log_mel (bool): Whether to use log-mel spectrogram. Default is True. """ super().__init__() self.melspec_transform = torchaudio.transforms.MelSpectrogram( sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels, f_min=f_min, f_max=f_max, ) self.log_mel = log_mel
[docs] def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """Forward pass through the MelSpectrogramLoss module. Args: x (torch.Tensor): The input audio tensor. target (torch.Tensor): The target audio tensor. Returns: torch.Tensor: The mel-spectrogram loss. """ x_mel = self.melspec_transform(x) target_mel = self.melspec_transform(target) if self.log_mel: x_mel = torch.log(x_mel + 1e-7) target_mel = torch.log(target_mel + 1e-7) return F.l1_loss(x_mel, target_mel)
[docs] @LossRegistry.register_loss() class FeatureMatchingLoss(BaseLoss): """Feature Matching Loss Module. This module calculates the loss between features extracted from a pretrained model. """
[docs] def __init__(self, model, layers, weights=None): """Initialize the FeatureMatchingLoss module. Args: model (BaseLoss): Pretrained model for feature extraction. layers (list): List of layer indices to extract features from. weights (list, optional): Weights for each layer. Default is None (equal weights). """ super().__init__() self.model = model self.model.eval() self.layers = layers if weights is None: self.weights = [1.0] * len(layers) else: assert len(weights) == len(layers) self.weights = weights for param in self.model.parameters(): param.requires_grad = False
[docs] def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """Forward pass through the FeatureMatchingLoss module. Args: x (torch.Tensor): The input audio tensor. target (torch.Tensor): The target audio tensor. Returns: torch.Tensor: The feature matching loss. """ # Create tensors that require gradient x_with_grad = x.detach().requires_grad_(True) target_with_grad = target.detach().requires_grad_(True) # Register hooks to capture activations activations_x = {} activations_target = {} def get_activation(name): def hook(model, input, output): # Don't detach to allow gradient flow activations_x[name] = output return hook def get_target_activation(name): def hook(model, input, output): # Don't detach to allow gradient flow activations_target[name] = output return hook # Register hooks handles = [] for i, layer_idx in enumerate(self.layers): handles.append(list(self.model.children())[layer_idx].register_forward_hook(get_activation(f"layer_{i}"))) # Forward pass for input self.model(x_with_grad) # Remove hooks for handle in handles: handle.remove() # Register hooks for target handles = [] for i, layer_idx in enumerate(self.layers): handles.append(list(self.model.children())[layer_idx].register_forward_hook(get_target_activation(f"layer_{i}"))) # Forward pass for target self.model(target_with_grad) # Remove hooks for handle in handles: handle.remove() # Calculate loss loss = 0.0 for i in range(len(self.layers)): layer_name = f"layer_{i}" # Use features from activations # We only detach the target activations to prevent training signal # from affecting the feature extractor loss += self.weights[i] * F.l1_loss(activations_x[layer_name], activations_target[layer_name].detach()) return loss
[docs] @LossRegistry.register_loss() class AudioContrastiveLoss(BaseLoss): """Audio Contrastive Loss Module. This module calculates a contrastive loss to bring similar audio samples closer in feature space. It can be used for self-supervised learning of audio representations. """
[docs] def __init__(self, margin=1.0, temperature=0.1, normalize=True, reduction="mean"): """Initialize the AudioContrastiveLoss module. Args: margin (float): Margin for contrastive loss. Default is 1.0. temperature (float): Temperature scaling factor. Default is 0.1. normalize (bool): Whether to normalize features. Default is True. reduction (str): Reduction method ('mean', 'sum', 'none'). Default is 'mean'. """ super().__init__() self.margin = margin self.temperature = temperature self.normalize = normalize self.reduction = reduction
[docs] def forward(self, features: torch.Tensor, target: torch.Tensor = None, projector=None, view_maker=None, labels=None) -> torch.Tensor: """Forward pass through the AudioContrastiveLoss module. Args: features (torch.Tensor): Audio feature embeddings. target (torch.Tensor, optional): Target features for comparison. If None, features are compared with themselves (self-supervised). Default is None. projector (nn.Module, optional): Optional projection network to map features to a lower-dimensional space. Default is None. view_maker (callable, optional): Function to create different views of the same data. Default is None. labels (torch.Tensor, optional): Labels for supervised contrastive learning. Default is None. Returns: torch.Tensor: The contrastive loss. """ # Apply projector if provided if projector is not None: features = projector(features) if target is not None: target = projector(target) # Apply view maker if provided if view_maker is not None: # Create positive pairs using the view maker if target is None: target = view_maker(features) else: target = view_maker(target) # If no target is provided, use the features themselves if target is None: target = features # Normalize features if self.normalize: features = F.normalize(features, p=2, dim=1) target = F.normalize(target, p=2, dim=1) # Compute similarity matrix similarity_matrix = torch.matmul(features, target.t()) / self.temperature # Create mask for positive pairs if labels is not None: # Supervised contrastive learning with provided labels mask_positive = torch.eq(labels.view(-1, 1), labels.view(1, -1)).float() else: # Self-supervised learning (positive pairs are along the diagonal) batch_size = features.size(0) mask_positive = torch.eye(batch_size, device=features.device) # Remove self-comparisons for robustness mask_self = torch.eye(mask_positive.shape[0], device=mask_positive.device) mask_positive = mask_positive - mask_self # Compute loss (InfoNCE / NT-Xent loss) exp_logits = torch.exp(similarity_matrix) * (1 - mask_self) log_prob = similarity_matrix - torch.log(exp_logits.sum(dim=1, keepdim=True) + 1e-10) # Compute mean of positive pairs # Handle the case where there are no positive pairs for some samples positive_per_sample = mask_positive.sum(1) # Avoid division by zero (add small epsilon) positive_per_sample = torch.clamp(positive_per_sample, min=1e-10) mean_log_prob_pos = (mask_positive * log_prob).sum(1) / positive_per_sample # Apply reduction if self.reduction == "mean": loss = -mean_log_prob_pos.mean() elif self.reduction == "sum": loss = -mean_log_prob_pos.sum() else: loss = -mean_log_prob_pos return loss
__all__ = ["L1AudioLoss", "SpectralConvergenceLoss", "LogSTFTMagnitudeLoss", "STFTLoss", "MultiResolutionSTFTLoss", "MelSpectrogramLoss", "FeatureMatchingLoss", "AudioContrastiveLoss"]