"""Multimodal Losses module for Kaira.

This module contains various loss functions for training multimodal systems.
"""
import torch
import torch.nn.functional as F
from .base import BaseLoss
from .registry import LossRegistry
@LossRegistry.register_loss()
class ContrastiveLoss(BaseLoss):
    """Cross-modal contrastive loss over a batch of embeddings.

    Both inputs are L2-normalised along ``dim=1``; their pairwise dot products,
    divided by ``temperature``, are used as logits for ``F.cross_entropy``
    (one row of logits per entry of ``embeddings1``).
    """

    def __init__(self, margin=0.2, temperature=0.07):
        """Store the loss hyper-parameters.

        Args:
            margin (float): Margin value kept on the instance as ``self.margin``.
                ``forward`` does not read it. Default is 0.2.
            temperature (float): Divisor applied to the similarity matrix before
                the cross-entropy. Default is 0.07.
        """
        super().__init__()
        self.margin = margin
        self.temperature = temperature

    def forward(self, embeddings1: torch.Tensor, embeddings2: torch.Tensor, labels: torch.Tensor = None) -> torch.Tensor:
        """Compute the contrastive loss between two sets of embeddings.

        Args:
            embeddings1 (torch.Tensor): Embeddings from the first modality.
            embeddings2 (torch.Tensor): Embeddings from the second modality.
            labels (torch.Tensor, optional): For each row of ``embeddings1``, the
                index of the matching row in ``embeddings2``. When omitted, row
                ``i`` is matched with row ``i`` (paired data).

        Returns:
            torch.Tensor: The cross-entropy over the temperature-scaled
            similarity matrix.
        """
        # Unit-length rows turn the dot product below into a cosine similarity.
        unit1 = F.normalize(embeddings1, p=2, dim=1)
        unit2 = F.normalize(embeddings2, p=2, dim=1)
        logits = torch.mm(unit1, unit2.t()) / self.temperature

        if labels is not None:
            # cross_entropy expects integer (Long) class indices.
            targets = labels.long()
        else:
            # Paired data: the positive for row i sits on the diagonal.
            targets = torch.arange(logits.size(0), device=logits.device)

        return F.cross_entropy(logits, targets)
@LossRegistry.register_loss()
class TripletLoss(BaseLoss):
    """Triplet loss for multimodal data with optional hard-negative mining.

    The loss is ``mean(clamp(d(anchor, positive) - d(anchor, negative) + margin, min=0))``
    where ``d`` is either cosine distance (``1 - cosine similarity``) or
    ``torch.pairwise_distance``. When no explicit negatives are given, the
    hardest (closest) negative for each anchor is mined from the *anchor batch
    itself* using ``labels``.
    """

    def __init__(self, margin=0.3, distance="cosine"):
        """Initialize the TripletLoss module.

        Args:
            margin (float): Margin for triplet loss. Default is 0.3.
            distance (str): Distance metric ('cosine' or 'euclidean'). Default is 'cosine'.

        Raises:
            ValueError: If ``distance`` is not one of the supported metrics.
        """
        super().__init__()
        # Validate before storing so a bad value never ends up on the instance.
        if distance not in ["cosine", "euclidean"]:
            raise ValueError(f"Unsupported distance metric: {distance}")
        self.margin = margin
        self.distance = distance

    def forward(
        self,
        anchor: torch.Tensor,
        positive: torch.Tensor,
        negative: torch.Tensor = None,
        labels: torch.Tensor = None,
    ) -> torch.Tensor:
        """Compute the triplet loss.

        Args:
            anchor (torch.Tensor): Anchor embeddings.
            positive (torch.Tensor): Positive embeddings, row-aligned with ``anchor``.
            negative (torch.Tensor, optional): Explicit negative embeddings,
                row-aligned with ``anchor``. Takes precedence over ``labels``.
            labels (torch.Tensor, optional): Labels for online mining; rows of
                ``anchor`` with a different label are negative candidates.
                Default is None.

        Returns:
            torch.Tensor: The mean triplet loss. If mining finds no negatives at
            all, the mean anchor-positive distance is returned instead.

        Raises:
            ValueError: If neither ``negative`` nor ``labels`` is provided.
        """
        # Fail fast, before doing any tensor work.
        if negative is None and labels is None:
            raise ValueError("Either negative samples or labels must be provided")

        if self.distance == "cosine":
            # Unit-length rows make the dot product a cosine similarity.
            anchor = F.normalize(anchor, p=2, dim=1)
            positive = F.normalize(positive, p=2, dim=1)
            if negative is not None:
                negative = F.normalize(negative, p=2, dim=1)

        pos_dist = self._paired_distance(anchor, positive)

        if negative is not None:
            neg_dist = self._paired_distance(anchor, negative)
        else:
            neg_dist = self._hardest_negative_distance(anchor, labels)
            if neg_dist is None:
                # No negatives found: only pull the positives closer.
                return pos_dist.mean()

        loss = torch.clamp(pos_dist - neg_dist + self.margin, min=0.0)
        return loss.mean()

    def _paired_distance(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Return the row-wise distance between ``x`` and ``y`` for the configured metric.

        For the cosine metric the inputs are expected to be L2-normalised already.
        """
        if self.distance == "cosine":
            return 1.0 - torch.sum(x * y, dim=1)
        return torch.pairwise_distance(x, y)

    def _hardest_negative_distance(self, anchor: torch.Tensor, labels: torch.Tensor):
        """Return, per anchor, the distance to its closest differently-labelled anchor.

        Returns ``None`` when no anchor has any negative candidate.
        """
        hardest = []
        for i in range(anchor.size(0)):
            neg_mask = labels != labels[i]
            if not torch.any(neg_mask):
                # For 1-D labels this only happens when the whole batch shares
                # one label, so the list ends up either full-length or empty.
                continue
            candidates = anchor[neg_mask]
            current = anchor[i].unsqueeze(0).expand(candidates.size(0), -1)
            # The hardest negative is the one closest to the anchor.
            hardest.append(torch.min(self._paired_distance(current, candidates)))
        if not hardest:
            return None
        return torch.stack(hardest)
@LossRegistry.register_loss()
class InfoNCELoss(BaseLoss):
    """InfoNCE Loss Module for multimodal contrastive learning.

    This module implements the Noise Contrastive Estimation loss: for every
    query a positive logit is placed at column 0, negative logits follow, and
    the temperature-scaled logits are scored with cross-entropy against class 0.
    """

    def __init__(self, temperature=0.07):
        """Initialize the InfoNCELoss module.

        Args:
            temperature (float): Temperature scaling factor. Default is 0.07.
        """
        super().__init__()
        self.temperature = temperature

    def forward(self, query: torch.Tensor, key: torch.Tensor, queue: torch.Tensor = None, mask: torch.Tensor = None) -> torch.Tensor:
        """Forward pass through the InfoNCELoss module.

        Three modes are supported, checked in this order:

        * ``queue`` given: ``key[i]`` is the positive for ``query[i]`` and every
          row of ``queue`` is a negative. ``mask`` is ignored in this mode.
        * ``mask`` given: for each query the positive logit is the highest
          similarity among the keys flagged in ``mask``; unflagged keys are
          negatives, except the key at the query's own index, which is never
          used as a negative.
        * neither given: ``key[i]`` is the positive for ``query[i]`` and all
          other keys are negatives.

        Args:
            query (torch.Tensor): Query embeddings from one modality.
            key (torch.Tensor): Key embeddings from another modality (positives).
            queue (torch.Tensor, optional): Queue of negative samples. Default is None.
            mask (torch.Tensor, optional): Binary mask defining positive pairs. Default is None.
                Shape should be [query.size(0), key.size(0)] where 1 indicates a positive pair.
                Every query should have at least one positive; a row without
                one yields a ``-inf`` positive logit and a non-finite loss.

        Returns:
            torch.Tensor: The InfoNCE loss. In mask mode, if there is no
            negative pair at all, ``-mean(positive similarity)`` is returned.
        """
        neg_inf = float("-inf")

        # Normalize embeddings so dot products are cosine similarities.
        query = F.normalize(query, p=2, dim=1)
        key = F.normalize(key, p=2, dim=1)

        if queue is not None:
            # Positive logits: row-aligned query/key pairs.
            l_pos = torch.einsum("nc,nc->n", [query, key]).unsqueeze(-1)
            # Negative logits: every query against every queue entry.
            queue = F.normalize(queue, p=2, dim=1)
            l_neg = torch.einsum("nc,kc->nk", [query, queue])
        else:
            # All pairwise query/key similarities.
            similarities = torch.einsum("nc,kc->nk", [query, key])
            num_queries, num_keys = similarities.shape

            if mask is not None:
                assert mask.shape == similarities.shape, "Mask shape must match similarity matrix shape"
                positive_mask = mask.bool()

                # Positive logit: best-matching key among each query's positives.
                l_pos = similarities.masked_fill(~positive_mask, neg_inf).max(dim=1, keepdim=True)[0]

                # Rectangular identity so a non-square mask (N queries, K keys)
                # is supported; the same-index key is never used as a negative.
                diag_mask = torch.eye(num_queries, num_keys, device=similarities.device).bool()
                negative_mask = ~positive_mask & ~diag_mask

                if not torch.any(negative_mask):
                    # No negatives anywhere: just maximise positive similarity.
                    return -l_pos.mean()

                # Keep the full [N, K] layout and blank out non-negatives with
                # -inf (they contribute exp(-inf) = 0 to the softmax). Unlike
                # masked_select + reshape, this stays correct when queries have
                # different numbers of negatives.
                l_neg = similarities.masked_fill(~negative_mask, neg_inf)
            else:
                # Default behavior: diagonal elements are the positives ...
                l_pos = torch.diag(similarities).unsqueeze(-1)
                # ... and are blanked out of the negatives.
                diag_mask = torch.eye(num_queries, device=similarities.device).bool()
                l_neg = similarities.masked_fill(diag_mask, neg_inf)

        # Positive logit sits at column 0 in every mode, so the target class is 0.
        logits = torch.cat([l_pos, l_neg], dim=1) / self.temperature
        labels = torch.zeros(logits.shape[0], dtype=torch.long, device=query.device)
        return F.cross_entropy(logits, labels)
@LossRegistry.register_loss()
class CMCLoss(BaseLoss):
    """Cross-Modal Consistency (CMC) loss.

    Projects the features of two modalities, L2-normalises the projections and
    applies a cross-entropy in both directions (1->2 and 2->1) in which row
    ``i`` of one modality must pick row ``i`` of the other.
    """

    def __init__(self, lambda_cmc=1.0):
        """Store the loss weight.

        Args:
            lambda_cmc (float): Multiplier applied to the final loss. Default is 1.0.
        """
        super().__init__()
        self.lambda_cmc = lambda_cmc

    def forward(self, x1: torch.Tensor, x2: torch.Tensor, proj1: BaseLoss, proj2: BaseLoss) -> torch.Tensor:
        """Compute the weighted, direction-averaged consistency loss.

        Args:
            x1 (torch.Tensor): Features from the first modality.
            x2 (torch.Tensor): Features from the second modality.
            proj1 (BaseLoss): Callable projection head applied to ``x1``.
            proj2 (BaseLoss): Callable projection head applied to ``x2``.

        Returns:
            torch.Tensor: ``lambda_cmc`` times the mean of the two directional
            cross-entropy terms.
        """
        # Project, then scale every row to unit length.
        z1 = F.normalize(proj1(x1), p=2, dim=1)
        z2 = F.normalize(proj2(x2), p=2, dim=1)

        # Row i of one modality should match row i of the other.
        targets = torch.arange(z1.size(0), device=z1.device)

        loss_1to2 = F.cross_entropy(torch.mm(z1, z2.t()), targets)
        loss_2to1 = F.cross_entropy(torch.mm(z2, z1.t()), targets)

        averaged = (loss_1to2 + loss_2to1) / 2
        return self.lambda_cmc * averaged
@LossRegistry.register_loss()
class AlignmentLoss(BaseLoss):
    """Alignment Loss for multimodal embeddings.

    This module aligns embeddings from different modalities by penalising their
    L1, L2 (mean squared error) or cosine discrepancy, optionally after passing
    both through one shared, bias-free linear projection.
    """

    def __init__(self, alignment_type="l2", projection_dim=None):
        """Initialize the AlignmentLoss module.

        Args:
            alignment_type (str): Type of alignment ('l1', 'l2', or 'cosine'). Default is 'l2'.
            projection_dim (int, optional): Dimension to project embeddings to before computing loss.
                If None, no projection is performed. Default is None.

        Raises:
            ValueError: If ``alignment_type`` is not supported.
        """
        super().__init__()
        self.alignment_type = alignment_type
        self.projection_dim = projection_dim
        if alignment_type not in ["l1", "l2", "cosine"]:
            raise ValueError(f"Unsupported alignment type: {alignment_type}")

        self.projector = None
        if self.projection_dim is not None:
            # in_features=1 is only a placeholder: the input width is unknown
            # here, so forward() swaps in a correctly sized layer on first use.
            self.projector = torch.nn.Linear(in_features=1, out_features=projection_dim, bias=False)

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        """Forward pass through the AlignmentLoss module.

        Args:
            x1 (torch.Tensor): Embeddings from the first modality.
            x2 (torch.Tensor): Embeddings from the second modality. The same
                projector is applied to both inputs, so when projection is
                enabled ``x2`` must have the same feature width as ``x1``.

        Returns:
            torch.Tensor: The alignment loss.

        Raises:
            ValueError: If ``self.alignment_type`` was changed to an unsupported value.
        """
        if self.projection_dim is not None:
            if self.projector.in_features != x1.shape[1]:
                # (Re)build the projector for the observed input width.
                # NOTE: this creates new parameters; an optimizer constructed
                # before this call does not hold them.
                self.projector = torch.nn.Linear(
                    in_features=x1.shape[1],
                    out_features=self.projection_dim,
                    bias=False,
                ).to(x1.device)
                # Orthogonal weights for better preservation of distances.
                torch.nn.init.orthogonal_(self.projector.weight)
                # Match the input's floating dtype (e.g. float64 / half) so the
                # linear layer does not fail on a dtype mismatch. The cast comes
                # after the init, which therefore runs in the default dtype.
                if x1.is_floating_point():
                    self.projector.to(dtype=x1.dtype)

            x1 = self.projector(x1)
            x2 = self.projector(x2)

        # Compute alignment loss based on the chosen type.
        if self.alignment_type == "l1":
            return F.l1_loss(x1, x2)
        elif self.alignment_type == "l2":
            return F.mse_loss(x1, x2)
        elif self.alignment_type == "cosine":
            x1 = F.normalize(x1, p=2, dim=1)
            x2 = F.normalize(x2, p=2, dim=1)
            # 1 - mean cosine similarity of the row-aligned pairs.
            return 1 - torch.mean(torch.sum(x1 * x2, dim=1))
        else:
            raise ValueError(f"Unsupported alignment type: {self.alignment_type}")
# Names exported by a star-import of this module.
__all__ = ["ContrastiveLoss", "TripletLoss", "InfoNCELoss", "CMCLoss", "AlignmentLoss"]