Note
Go to the end to download the full example code, or to run this example in your browser via Binder.
Multimodal Losses for Cross-Modal Learning
This example demonstrates the various multimodal losses available in kaira for training models that work with multiple modalities (e.g., text-image, audio-video).
We’ll cover three losses: Contrastive Loss, Triplet Loss, and InfoNCE Loss (Information Noise-Contrastive Estimation).
First, let’s import the necessary modules
import torch
import torch.nn as nn
from kaira.losses import LossRegistry
Let’s create some sample embeddings to simulate features from different modalities
def create_sample_embeddings(n_samples=100, n_dim=128):
    """Build random, unit-norm toy embeddings for the loss demonstrations.

    Three random draws are made, in this order: the anchors, the noise that
    turns anchors into positives, and the negatives. Every returned embedding
    row is L2-normalized.

    Args:
        n_samples (int): Number of rows in each embedding tensor. Default is 100.
        n_dim (int): Width of each embedding vector. Default is 128.

    Returns:
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
            - anchors: Reference embeddings (e.g., image features).
            - positives: Anchors plus Gaussian noise scaled by 0.1, re-normalized.
            - negatives: Independently drawn embeddings.
            - labels: ``torch.arange(n_samples)``, i.e. one distinct label per row.
    """

    def _unit_rows(matrix):
        # Scale every row to unit Euclidean length.
        return nn.functional.normalize(matrix, p=2, dim=1)

    # Anchors: plain Gaussian directions projected onto the unit sphere.
    anchors = _unit_rows(torch.randn(n_samples, n_dim))

    # Positives: a small perturbation of each anchor, so they stay close to it.
    jitter = torch.randn(n_samples, n_dim)
    positives = _unit_rows(anchors + 0.1 * jitter)

    # Negatives: drawn independently of the anchors.
    negatives = _unit_rows(torch.randn(n_samples, n_dim))

    # Row i of every tensor carries label i.
    labels = torch.arange(n_samples)

    return anchors, positives, negatives, labels
# Draw one batch with the default sizes (100 samples, 128 dimensions).
# These module-level tensors are reused by the loss computations and the
# t-SNE plots further down, so the names must stay as they are.
anchors, positives, negatives, labels = create_sample_embeddings()
Now let’s compute different multimodal losses
# Contrastive loss: looked up in the registry by its lowercase key and built
# with margin=0.5. Here it is called with the labels as a third argument.
contrastive_loss = LossRegistry.create("contrastiveloss", margin=0.5)
contrastive_value = contrastive_loss(anchors, positives, labels)
# NOTE(review): the ":.4f" format presumably relies on the loss being a
# scalar (0-dim) tensor — confirm against kaira's loss implementations.
print(f"Contrastive Loss: {contrastive_value:.4f}")
# Triplet loss: called as (anchor, positive, negative), built with margin=0.3.
triplet_loss = LossRegistry.create("tripletloss", margin=0.3)
triplet_value = triplet_loss(anchors, positives, negatives)
print(f"Triplet Loss: {triplet_value:.4f}")
# InfoNCE loss: the registry key is all lowercase ("infonceloss", not
# "infoNCEloss"). Only anchors and positives are passed.
# NOTE(review): presumably the other rows of the batch serve as negatives —
# TODO confirm against the kaira implementation.
infonce_loss = LossRegistry.create("infonceloss", temperature=0.07)
infonce_value = infonce_loss(anchors, positives)
print(f"InfoNCE Loss: {infonce_value:.4f}")
Contrastive Loss: 0.0188
Triplet Loss: 0.0000
InfoNCE Loss: 0.0188
Let’s visualize how these losses behave with different similarity values
def compute_similarity_losses(similarity):
    """Compute each demo loss for a pair of 2-D unit vectors with a given cosine similarity.

    The first vector is fixed at ``(1, 0)``; the second is
    ``(s, sqrt(1 - s**2))``, whose cosine similarity with the first is ``s``.
    Both are expanded to a batch of 10 identical rows before being passed to
    the module-level ``contrastive_loss``, ``triplet_loss`` and
    ``infonce_loss`` objects. The triplet loss uses the negated second vector
    as its negative.

    Args:
        similarity (float): Target cosine similarity. Values outside
            ``[-1, 1]`` are clamped to that range, so inputs that overshoot
            by rounding error still yield a valid unit vector instead of NaN.

    Returns:
        Dict[str, float]: Loss values keyed by ``"Contrastive"``,
        ``"Triplet"`` and ``"InfoNCE"``.
    """
    # Without the clamp, |similarity| slightly above 1 makes 1 - s**2 negative
    # and np.sqrt returns NaN, which would poison every loss value.
    similarity = float(np.clip(similarity, -1.0, 1.0))

    # Create vectors with the specified cosine similarity; both use float32 so
    # the losses never see mixed dtypes.
    v1 = torch.tensor([[1.0, 0.0]], dtype=torch.float32)
    v2 = torch.tensor([[similarity, np.sqrt(1.0 - similarity**2)]], dtype=torch.float32)

    # Expand to a batch of identical rows (a view, no data is copied).
    v1_batch = v1.expand(10, 2)
    v2_batch = v2.expand(10, 2)

    # Compute losses; .item() turns each result into a plain Python float.
    losses = {
        "Contrastive": contrastive_loss(v1_batch, v2_batch).item(),
        "Triplet": triplet_loss(v1_batch, v2_batch, -v2_batch).item(),
        "InfoNCE": infonce_loss(v1_batch, v2_batch).item(),
    }
    return losses
# Sweep the cosine similarity over its whole range, [-1, 1], in 100 steps and
# record one curve per loss.
similarities = np.linspace(-1, 1, 100)
loss_curves: Dict[str, List[float]] = {"Contrastive": [], "Triplet": [], "InfoNCE": []}
for sim in similarities:
    losses = compute_similarity_losses(sim)
    # Append this similarity's value to the matching curve.
    for name in losses:
        loss_curves[name].append(losses[name])
Plot how losses vary with cosine similarity
# One line per loss on a shared axis: x is the cosine similarity, y the loss.
plt.figure(figsize=(10, 6))
for name in loss_curves:
    plt.plot(similarities, loss_curves[name], label=name)
plt.title("Loss Response to Embedding Similarity")
plt.xlabel("Cosine Similarity")
plt.ylabel("Loss Value")
plt.grid(True)
# The legend is built after plotting so every labelled line is included.
plt.legend()
plt.tight_layout()
plt.show()

Let’s examine the clustering behavior of these losses
def plot_embedding_clusters(embeddings, labels, title):
    """Draw a 2-D t-SNE projection of the embeddings on the current matplotlib axes.

    The function does not create a figure or subplot; the scatter plot, title
    and colorbar go onto whichever axes are current, so callers select the
    target axes (e.g. with ``plt.subplot``) before calling.

    Args:
        embeddings (torch.Tensor): 2D tensor of shape (n_samples, n_dim) representing the embeddings.
        labels (torch.Tensor): 1D tensor of shape (n_samples,) used to color the points.
        title (str): Title for the plot.

    Raises:
        AssertionError: If input tensors are not of the expected shape or type.
    """
    # Input validation. Kept as assertions because AssertionError is the
    # documented contract; note these checks are skipped under ``python -O``.
    assert torch.is_tensor(embeddings), "Embeddings must be a torch tensor"
    assert torch.is_tensor(labels), "Labels must be a torch tensor"
    assert embeddings.dim() == 2, f"Expected 2D embeddings, got {embeddings.dim()}D"
    assert labels.dim() == 1, f"Expected 1D labels, got {labels.dim()}D"
    assert embeddings.shape[0] == labels.shape[0], "Number of embeddings and labels must match"

    # Imported here so scikit-learn is only required when this plot is drawn.
    from sklearn.manifold import TSNE

    # detach() drops autograd tracking and cpu() moves the data to host memory;
    # Tensor.numpy() only works for CPU tensors, and matplotlib needs
    # array-like colors, so both inputs are converted the same way.
    points = embeddings.detach().cpu().numpy()
    point_labels = labels.detach().cpu().numpy()

    # Use t-SNE for visualization; the fixed seed keeps the layout repeatable.
    tsne = TSNE(n_components=2, random_state=42)
    embeddings_2d = tsne.fit_transform(points)

    plt.scatter(embeddings_2d[:, 0], embeddings_2d[:, 1], c=point_labels, cmap="tab10")
    plt.title(title)
    plt.colorbar(label="Class")
# Plot original embeddings
plt.figure(figsize=(15, 5))
plt.subplot(131)
plot_embedding_clusters(anchors, labels, "Anchor Embeddings")
plt.subplot(132)
plot_embedding_clusters(positives, labels, "Positive Embeddings")
plt.subplot(133)
plot_embedding_clusters(negatives, labels, "Negative Embeddings")
plt.tight_layout()
plt.show()

Let’s also visualize the effect of the margin parameter in triplet loss
# Sweep a positive sample around the unit circle and record the triplet loss
# against a fixed anchor at (1, 0) for several margins. The negative is always
# the point diametrically opposite the positive.
margins = [0.1, 0.3, 0.5, 1.0]
anchor_point = torch.tensor([[1.0, 0.0]], dtype=torch.float32)
theta = np.linspace(0, 2 * np.pi, 100)
loss_values: Dict[float, List[float]] = {margin: [] for margin in margins}

# Build each loss once per margin rather than once per (angle, margin) pair;
# the loss object does not depend on the angle.
triplet_losses_by_margin = {margin: LossRegistry.create("tripletloss", margin=margin) for margin in margins}

# The anchor batch is the same for every angle, so expand it a single time.
anchor_batch = anchor_point.expand(10, 2)

for t in theta:
    # np.cos / np.sin return float64; request float32 explicitly so the point
    # matches the anchor's dtype instead of mixing precisions.
    point = torch.tensor([[np.cos(t), np.sin(t)]], dtype=torch.float32)
    point_batch = point.expand(10, 2)
    for margin, triplet_loss_margin in triplet_losses_by_margin.items():
        loss = triplet_loss_margin(anchor_batch, point_batch, -point_batch).item()
        loss_values[margin].append(loss)
# Polar view of the sweep: the angle is the positive sample's position on the
# unit circle and the radius is the triplet loss, with one curve per margin.
plt.figure(figsize=(10, 10))
ax = plt.subplot(111, projection="polar")
for margin in loss_values:
    ax.plot(theta, loss_values[margin], label=f"Margin={margin}")
plt.title("Triplet Loss Values Around Unit Circle")
plt.legend()
plt.show()

This example demonstrates various losses used in multimodal learning:
Contrastive Loss brings similar embeddings closer while pushing dissimilar ones apart, useful for tasks like face verification or image retrieval.
Triplet Loss ensures that an anchor is closer to a positive example than to a negative example by a margin, commonly used in few-shot learning and metric learning.
InfoNCE Loss is particularly effective for self-supervised learning and contrastive representation learning, as it can handle multiple negative examples efficiently.
The visualizations show how these losses respond to different similarity values and how the margin parameter affects the triplet loss behavior.
Total running time of the script: (0 minutes 1.501 seconds)