kaira.losses.multimodal.InfoNCELoss

class kaira.losses.multimodal.InfoNCELoss(temperature=0.07)

Bases: BaseLoss

InfoNCE Loss Module for multimodal contrastive learning.

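This module implements the InfoNCE (information noise-contrastive estimation) loss: each query embedding is trained to score its positive key higher than the negative samples it is contrasted against.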
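As an illustration of the objective, the PyTorch snippet below computes InfoNCE with in-batch negatives. It is a minimal sketch, not kaira's actual implementation: it assumes that query[i] and key[i] form the positive pair, that every other key in the batch serves as a negative, and that embeddings are L2-normalized so the dot product is cosine similarity. The function name info_nce_sketch is hypothetical.

```python
import torch
import torch.nn.functional as F


def info_nce_sketch(query: torch.Tensor, key: torch.Tensor, temperature: float = 0.07) -> torch.Tensor:
    """Hypothetical InfoNCE sketch with in-batch negatives (not kaira's code).

    Assumes query[i] and key[i] are the positive pair; every key[j] with
    j != i is treated as a negative for query[i].
    """
    # Assumption: L2-normalize so the dot product becomes cosine similarity.
    q = F.normalize(query, dim=1)
    k = F.normalize(key, dim=1)
    # [N, N] similarity matrix divided by the temperature.
    logits = q @ k.t() / temperature
    # The positive for row i sits on the diagonal, at column i.
    targets = torch.arange(q.size(0), device=q.device)
    # Row-wise cross-entropy = -log(softmax probability of the positive), averaged.
    return F.cross_entropy(logits, targets)


torch.manual_seed(0)
query = torch.randn(8, 16)
key = query + 0.05 * torch.randn(8, 16)  # keys that closely match their queries
loss_aligned = info_nce_sketch(query, key)
loss_random = info_nce_sketch(query, torch.randn(8, 16))  # unrelated keys
print(loss_aligned.item(), loss_random.item())
```

Because each row is a softmax cross-entropy over the batch, the loss approaches zero when every query ranks its own key far above the others, and it equals exactly log(N) when all similarities in every row are equal.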

Methods

__init__

Initialize the InfoNCELoss module.

forward

Forward pass through the InfoNCELoss module.

__init__(temperature=0.07)

Initialize the InfoNCELoss module.

Parameters:

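temperature (float) – Temperature by which similarity scores are divided before the softmax; smaller values produce a sharper distribution over the positive and negative samples. Default is 0.07.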
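To illustrate the role of the temperature, the sketch below compares the softmax probability assigned to a positive under temperature 1.0 and under the default 0.07. It assumes, as in the standard InfoNCE formulation, that similarity scores are divided by the temperature before the softmax; the similarity values are made up for illustration and no kaira code is called.

```python
import torch

# Made-up cosine similarities for one query: the positive first, then three negatives.
similarities = torch.tensor([0.8, 0.5, 0.3, 0.1])

p_positive = {}
for temperature in (1.0, 0.07):
    # Dividing by a smaller temperature stretches the gaps between logits,
    # so the softmax puts more of its mass on the highest-scoring entry.
    probs = torch.softmax(similarities / temperature, dim=0)
    p_positive[temperature] = probs[0].item()
    print(f"temperature={temperature}: p(positive)={p_positive[temperature]:.3f}")
```

Here the probability of the positive rises from about 0.35 at temperature 1.0 to about 0.99 at 0.07, so a low temperature makes the loss far more sensitive to negatives whose similarity approaches that of the positive.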

forward(query: Tensor, key: Tensor, queue: Tensor = None, mask: Tensor = None) → Tensor

Forward pass through the InfoNCELoss module.

Parameters:
  • query (torch.Tensor) – Query embeddings from one modality.

  • key (torch.Tensor) – Key embeddings from another modality (positives).

  • queue (torch.Tensor, optional) – Queue of negative samples. Default is None.

  • mask (torch.Tensor, optional) – Binary mask defining positive pairs, of shape [query.size(0), key.size(0)]; an entry of 1 at (i, j) marks query i and key j as a positive pair. Default is None.

Returns:

The InfoNCE loss.

Return type:

torch.Tensor
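The reference above does not specify how queue and mask are combined inside forward. One common design, sketched below under stated assumptions, treats queue entries purely as extra negatives (as in MoCo-style memory queues) and, when the mask marks several positives for a query, averages the log-probabilities of those positives. The function info_nce_with_queue_and_mask, its diagonal default when mask is None, and its normalization are all hypothetical and may differ from kaira's implementation.

```python
import torch
import torch.nn.functional as F


def info_nce_with_queue_and_mask(query, key, queue=None, mask=None, temperature=0.07):
    """Hypothetical sketch, not kaira's implementation.

    query: [N, D]; key: [M, D]; queue: [K, D] extra negatives;
    mask: [N, M] with 1 where (query[i], key[j]) is a positive pair.
    If mask is None, the diagonal pairs (query[i], key[i]) are assumed positive.
    """
    q = F.normalize(query, dim=1)
    k = F.normalize(key, dim=1)
    logits = q @ k.t() / temperature  # [N, M]
    if mask is None:
        mask = torch.eye(q.size(0), k.size(0), device=q.device)
    mask = mask.to(logits.dtype)
    if queue is not None:
        # Queue entries are negatives only: append their logits with mask 0.
        neg = q @ F.normalize(queue, dim=1).t() / temperature  # [N, K]
        logits = torch.cat([logits, neg], dim=1)
        mask = torch.cat([mask, torch.zeros_like(neg)], dim=1)
    log_prob = F.log_softmax(logits, dim=1)
    # Average log-probability over each row's positives; rows without any
    # positive contribute 0 (the clamp avoids division by zero).
    pos_log_prob = (mask * log_prob).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
    return -pos_log_prob.mean()


torch.manual_seed(0)
query = torch.randn(4, 8)
key = torch.randn(4, 8)
queue = torch.randn(16, 8)

base = info_nce_with_queue_and_mask(query, key)
with_queue = info_nce_with_queue_and_mask(query, key, queue=queue)
# Mark query 0 as also matching key 1 (e.g. two captions of the same image).
mask = torch.eye(4)
mask[0, 1] = 1.0
multi_pos = info_nce_with_queue_and_mask(query, key, mask=mask)
print(base.item(), with_queue.item(), multi_pos.item())
```

With mask=None and no queue, this sketch reduces to ordinary cross-entropy over the diagonal targets; adding queue negatives enlarges every softmax denominator, so the loss can only grow for the same query and key embeddings.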