kaira.losses.multimodal.ContrastiveLoss

Inheritance diagram of ContrastiveLoss

class kaira.losses.multimodal.ContrastiveLoss(margin=0.2, temperature=0.07)

Bases: BaseLoss

Contrastive Loss Module.

This module computes a contrastive loss between embeddings from two different modalities.

Methods

  • __init__ – Initialize the ContrastiveLoss module.

  • forward – Forward pass through the ContrastiveLoss module.

__init__(margin=0.2, temperature=0.07)

Initialize the ContrastiveLoss module.

Parameters:
  • margin (float) – Margin for contrastive loss. Default is 0.2.

  • temperature (float) – Temperature scaling factor. Default is 0.07.
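This page does not say how `margin` and `temperature` enter the loss. As a rough illustration of what a temperature as small as 0.07 typically does in a softmax-based contrastive objective, the plain-Python sketch below divides similarity scores by the temperature before applying a softmax. The helper `softmax_with_temperature` and the similarity values are hypothetical and are not part of kaira.

```python
import math


def softmax_with_temperature(scores, temperature):
    """Softmax over scores / temperature (hypothetical helper, not kaira API)."""
    scaled = [s / temperature for s in scores]
    peak = max(scaled)  # subtract the maximum for numerical stability
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


# Made-up cosine similarities: one matching pair (0.8) and two mismatches.
similarities = [0.8, 0.5, 0.1]

# A small temperature concentrates almost all probability on the best match.
sharp = softmax_with_temperature(similarities, 0.07)

# A temperature of 1.0 leaves the distribution much flatter.
flat = softmax_with_temperature(similarities, 1.0)
```

Under this assumption, a lower temperature makes the loss more sensitive to small differences in similarity between the matching pair and the mismatched pairs.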

forward(embeddings1: Tensor, embeddings2: Tensor, labels: Tensor = None) -> Tensor

Forward pass through the ContrastiveLoss module.

Parameters:
  • embeddings1 (torch.Tensor) – Embeddings from the first modality.

  • embeddings2 (torch.Tensor) – Embeddings from the second modality.

  • labels (torch.Tensor, optional) – Matching labels. Default is None (assumes paired data).

Returns: The contrastive loss between the modalities.

Return type: torch.Tensor
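To make the paired case (`labels=None`) concrete, here is a minimal sketch of one common formulation: a symmetric, temperature-scaled cross-entropy over cosine similarities, in which row i of the first input is assumed to match row i of the second. This is an assumption about the general technique, not kaira's implementation. It uses plain Python lists instead of tensors so it runs without PyTorch, and it leaves `margin` out because this page does not specify how the margin is applied. The names `normalize`, `log_softmax_at`, and `paired_contrastive_loss` are hypothetical.

```python
import math


def normalize(vector):
    """Scale a non-zero vector to unit length."""
    norm = math.sqrt(sum(x * x for x in vector))
    return [x / norm for x in vector]


def log_softmax_at(row, index):
    """Log-softmax of `row`, evaluated at position `index`."""
    peak = max(row)  # subtract the maximum for numerical stability
    log_total = peak + math.log(sum(math.exp(x - peak) for x in row))
    return row[index] - log_total


def paired_contrastive_loss(embeddings1, embeddings2, temperature=0.07):
    """Symmetric contrastive loss; row i of each input is the positive pair."""
    a = [normalize(v) for v in embeddings1]
    b = [normalize(v) for v in embeddings2]
    n = len(a)
    # sim[i][j]: cosine similarity of a[i] and b[j], divided by the temperature
    sim = [
        [sum(x * y for x, y in zip(a[i], b[j])) / temperature for j in range(n)]
        for i in range(n)
    ]
    # Direction 1 -> 2: each row of sim should peak on its diagonal entry.
    loss_12 = -sum(log_softmax_at(sim[i], i) for i in range(n)) / n
    # Direction 2 -> 1: each column of sim should peak on its diagonal entry.
    loss_21 = -sum(
        log_softmax_at([sim[i][j] for i in range(n)], j) for j in range(n)
    ) / n
    return 0.5 * (loss_12 + loss_21)


first = [[1.0, 0.0], [0.0, 1.0]]

# Second-modality embeddings that line up with `first` row by row.
aligned = paired_contrastive_loss(first, [[0.9, 0.1], [0.1, 0.9]])

# The same embeddings with rows swapped, so every pair is mismatched.
swapped = paired_contrastive_loss(first, [[0.1, 0.9], [0.9, 0.1]])
```

In this sketch the aligned batch yields a loss close to zero, while the swapped batch is penalized heavily, which is the behavior a contrastive objective over paired data is meant to produce.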