# Source code for kaira.losses.text

"""Text Losses module for Kaira.

This module contains various loss functions for training text-based systems.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

from .base import BaseLoss
from .registry import LossRegistry


@LossRegistry.register_loss()
class CrossEntropyLoss(BaseLoss):
    """Cross-entropy criterion for classification-style text tasks.

    A thin wrapper that stores a :class:`torch.nn.CrossEntropyLoss` instance
    and delegates the forward computation to it.
    """

    def __init__(self, weight=None, ignore_index=-100, label_smoothing=0.0):
        """Create the wrapped cross-entropy criterion.

        Args:
            weight (torch.Tensor, optional): Class weights. Default is None.
            ignore_index (int): Index to ignore. Default is -100.
            label_smoothing (float): Label smoothing value. Default is 0.0.
        """
        super().__init__()
        criterion_options = {
            "weight": weight,
            "ignore_index": ignore_index,
            "label_smoothing": label_smoothing,
        }
        # All options are forwarded untouched to the PyTorch criterion.
        self.ce = nn.CrossEntropyLoss(**criterion_options)

    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Evaluate the cross entropy between logits and targets.

        Args:
            x (torch.Tensor): The input logits tensor.
            target (torch.Tensor): The target tensor.

        Returns:
            torch.Tensor: The cross entropy loss produced by the wrapped
            criterion.
        """
        loss = self.ce(x, target)
        return loss
@LossRegistry.register_loss()
class LabelSmoothingLoss(BaseLoss):
    """Label Smoothing Loss Module.

    Mixes the negative log-likelihood of the target class with the mean
    negative log-probability over all classes, which discourages
    over-confident predictions.
    """

    def __init__(self, smoothing=0.1, classes=0, dim=-1):
        """Initialize the LabelSmoothingLoss module.

        Args:
            smoothing (float): Smoothing factor. Default is 0.1.
            classes (int): Number of classes. A falsy value (the default, 0)
                means the count is taken from the size of the logits along
                ``dim`` at call time.
            dim (int): Axis of the logits that indexes the classes.
                Default is -1.
        """
        super().__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.classes = classes
        self.dim = dim

    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Forward pass through the LabelSmoothingLoss module.

        Args:
            x (torch.Tensor): The input logits tensor; its ``dim`` axis holds
                the per-class scores.
            target (torch.Tensor): Class indices. Expected to have the shape
                of ``x`` with the ``dim`` axis removed.

        Returns:
            torch.Tensor: The label smoothing loss averaged over all elements.

        Raises:
            ValueError: If ``classes`` was configured and does not match the
                size of ``x`` along ``dim``.
        """
        observed_classes = x.size(self.dim)
        # With no configured class count, fall back to what the logits say so
        # the default constructor yields a usable module.
        num_classes = self.classes if self.classes else observed_classes
        if observed_classes != num_classes:
            # Explicit raise rather than `assert`, which is stripped under -O.
            raise ValueError(
                f"Expected {num_classes} classes along dim {self.dim}, "
                f"but got logits of shape {tuple(x.shape)}"
            )

        log_probs = F.log_softmax(x, dim=self.dim)

        # Hard targets: pick the log-probability of the true class. The index
        # is expanded on the class axis itself (not a fixed axis 1) so that
        # both channels-first and channels-last layouts work.
        nll_loss = -log_probs.gather(dim=self.dim, index=target.unsqueeze(self.dim))
        nll_loss = nll_loss.squeeze(self.dim)

        # Smoothed targets: summed negative log-probability over all classes;
        # the division by num_classes below turns it into a mean.
        smooth_loss = -log_probs.sum(dim=self.dim)

        # Combine losses
        loss = self.confidence * nll_loss + self.smoothing * smooth_loss / num_classes
        return loss.mean()
@LossRegistry.register_loss()
class CosineSimilarityLoss(BaseLoss):
    """Hinge-style loss on the cosine similarity between two embedding sets.

    Each pair contributes ``max(margin - cos(x, target), 0)``; the result is
    the mean over all pairs.
    """

    def __init__(self, margin=0.0):
        """Store the similarity margin.

        Args:
            margin (float): Margin for similarity. Default is 0.0.
        """
        super().__init__()
        self.margin = margin

    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the margin-based cosine similarity loss.

        Args:
            x (torch.Tensor): The input embeddings tensor.
            target (torch.Tensor): The target embeddings tensor.

        Returns:
            torch.Tensor: The cosine similarity loss.
        """
        # Scale both inputs to unit L2 norm along axis 1, so their
        # element-wise product summed over that axis is the cosine similarity.
        unit_x = F.normalize(x, p=2, dim=1)
        unit_target = F.normalize(target, p=2, dim=1)
        similarity = (unit_x * unit_target).sum(dim=1)

        # Only pairs whose similarity falls below the margin are penalized.
        shortfall = (self.margin - similarity).clamp(min=0.0)
        return shortfall.mean()
@LossRegistry.register_loss()
class Word2VecLoss(BaseLoss):
    """Negative-sampling objective in the style of Word2Vec.

    Owns two embedding tables (input words and output/context words) and
    scores observed pairs against randomly drawn negative words.
    """

    def __init__(self, embedding_dim, vocab_size, n_negatives=5):
        """Create the embedding tables and initialize their weights.

        Args:
            embedding_dim (int): Dimensionality of embeddings.
            vocab_size (int): Size of vocabulary.
            n_negatives (int): Number of negative samples. Default is 5.
        """
        super().__init__()
        self.embedding_dim = embedding_dim
        self.vocab_size = vocab_size
        self.n_negatives = n_negatives

        # Input-word and context-word lookup tables.
        self.in_embed = nn.Embedding(vocab_size, embedding_dim)
        self.out_embed = nn.Embedding(vocab_size, embedding_dim)

        # Both tables start uniform in [-0.5 / dim, 0.5 / dim].
        init_bound = 0.5 / embedding_dim
        for table in (self.in_embed, self.out_embed):
            nn.init.uniform_(table.weight, -init_bound, init_bound)

    def forward(self, input_idx: torch.Tensor, output_idx: torch.Tensor) -> torch.Tensor:
        """Compute the negative-sampling loss for a batch of word pairs.

        Negative words are drawn uniformly at random from the whole
        vocabulary on every call, with no exclusion of the true context word.

        Args:
            input_idx (torch.Tensor): Input word indices.
            output_idx (torch.Tensor): Output context word indices.

        Returns:
            torch.Tensor: The Word2Vec loss.
        """
        n_rows = input_idx.size(0)

        center_vecs = self.in_embed(input_idx)  # [batch_size, embed_dim]
        context_vecs = self.out_embed(output_idx)  # [batch_size, embed_dim]

        # Observed pairs: log-sigmoid of the dot product.
        positive_term = F.logsigmoid((center_vecs * context_vecs).sum(dim=1))

        # Random negatives, created on the same device as the indices.
        noise_idx = torch.randint(
            0,
            self.vocab_size,
            (n_rows, self.n_negatives),
            device=input_idx.device,
        )
        noise_vecs = self.out_embed(noise_idx)  # [batch_size, n_negatives, embed_dim]

        # Dot product of every negative with its row's input embedding.
        noise_scores = torch.bmm(noise_vecs, center_vecs.unsqueeze(2)).squeeze(2)  # [batch_size, n_negatives]
        negative_term = F.logsigmoid(-noise_scores).sum(1)

        # Maximizing both log-likelihood terms means minimizing their negated mean.
        return -(positive_term + negative_term).mean()
# Public API of this module.
__all__ = [
    "CrossEntropyLoss",
    "LabelSmoothingLoss",
    "CosineSimilarityLoss",
    "Word2VecLoss",
]