kaira.losses.multimodal.ContrastiveLoss

- class kaira.losses.multimodal.ContrastiveLoss(margin=0.2, temperature=0.07)
Bases: BaseLoss
Contrastive Loss Module.
This module calculates contrastive loss between embeddings from different modalities.
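This page does not give the formula the module uses. A common way to compute a contrastive loss between two modalities is the InfoNCE (CLIP-style) objective. It normalizes both embedding sets and scores every cross-modal pair by cosine similarity divided by a temperature. It then applies cross-entropy so that each sample in the first modality picks out its partner in the second. The following is a minimal, dependency-free sketch under that assumption. It is not necessarily how kaira implements ContrastiveLoss, and `contrastive_loss_sketch` and `_normalize` are hypothetical names. The `margin` argument is left out because this page does not say how it is used.

```python
import math
from typing import List, Optional, Sequence


def _normalize(v: Sequence[float]) -> List[float]:
    """Scale a vector to unit L2 norm so that a dot product equals cosine similarity."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def contrastive_loss_sketch(
    embeddings1: Sequence[Sequence[float]],
    embeddings2: Sequence[Sequence[float]],
    labels: Optional[Sequence[int]] = None,
    temperature: float = 0.07,
) -> float:
    """Hypothetical InfoNCE-style loss (assumed formulation, not kaira's code).

    Row i of embeddings1 is expected to match row labels[i] of embeddings2.
    """
    a = [_normalize(v) for v in embeddings1]
    b = [_normalize(v) for v in embeddings2]
    if labels is None:
        # Paired data: the i-th sample in each modality belongs together.
        labels = list(range(len(a)))
    total = 0.0
    for i, anchor in enumerate(a):
        # Temperature-scaled cosine similarity against every candidate in modality 2.
        logits = [sum(x * y for x, y in zip(anchor, cand)) / temperature for cand in b]
        # Numerically stable log-sum-exp for the softmax denominator.
        m = max(logits)
        log_denom = m + math.log(sum(math.exp(z - m) for z in logits))
        # Cross-entropy term: negative log-probability of the matching candidate.
        total += log_denom - logits[labels[i]]
    # Average over the samples of the first modality.
    return total / len(a)


if __name__ == "__main__":
    img = [[1.0, 0.0], [0.0, 1.0]]
    txt = [[1.0, 0.0], [0.0, 1.0]]
    print(contrastive_loss_sketch(img, txt))        # near zero: pairs are aligned
    print(contrastive_loss_sketch(img, txt[::-1]))  # large: pairs are swapped
```

With the default `labels=None`, the sketch treats row i of each input as a matching pair, which mirrors the "assumes paired data" note on `forward`. Passing explicit `labels` lets unordered batches point each first-modality row at its partner's index.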
Methods
__init__(margin=0.2, temperature=0.07) – Initialize the ContrastiveLoss module.
forward(embeddings1, embeddings2[, labels]) – Forward pass through the ContrastiveLoss module.
- forward(embeddings1: Tensor, embeddings2: Tensor, labels: Tensor = None) → Tensor
Forward pass through the ContrastiveLoss module.
- Parameters:
embeddings1 (torch.Tensor) – Embeddings from the first modality.
embeddings2 (torch.Tensor) – Embeddings from the second modality.
labels (torch.Tensor, optional) – Matching labels. Default is None (assumes paired data).
- Returns:
The contrastive loss between the modalities.
- Return type: