kaira.losses.multimodal.InfoNCELoss

- class kaira.losses.multimodal.InfoNCELoss(temperature=0.07)
Bases: BaseLoss
InfoNCE Loss Module for multimodal contrastive learning.
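This module implements the InfoNCE (Noise Contrastive Estimation) loss, which trains each query embedding to score its positive key higher than the negative samples.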
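To make the description concrete, here is a minimal sketch of an InfoNCE computation in plain PyTorch. It is not kaira's implementation: the function name `info_nce_sketch`, the L2 normalization, and the choice of in-batch negatives with positives on the diagonal are all assumptions made for illustration.

```python
import torch
import torch.nn.functional as F


def info_nce_sketch(query: torch.Tensor, key: torch.Tensor, temperature: float = 0.07) -> torch.Tensor:
    """Illustrative InfoNCE with in-batch negatives (hypothetical helper, not kaira code)."""
    # Assumption: embeddings are L2-normalized, so dot products are cosine similarities.
    q = F.normalize(query, dim=1)
    k = F.normalize(key, dim=1)
    # logits[i, j] is the temperature-scaled similarity between query i and key j.
    logits = q @ k.t() / temperature
    # Assumption: the positive for query i is key i; all other keys act as negatives.
    targets = torch.arange(q.size(0), device=q.device)
    return F.cross_entropy(logits, targets)


torch.manual_seed(0)
query = torch.randn(8, 32)
key = query + 0.05 * torch.randn(8, 32)  # keys closely aligned with their queries
aligned_loss = info_nce_sketch(query, key)
# Rolling the keys by one row pairs every query with the wrong key.
shuffled_loss = info_nce_sketch(query, key.roll(1, dims=0))
```

With well-aligned pairs the loss should be small, and it should grow substantially once the pairing is broken, which is the behaviour a contrastive objective is meant to have.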
Methods
- __init__: Initialize the InfoNCELoss module.
- forward: Forward pass through the InfoNCELoss module.
- __init__(temperature=0.07)
Initialize the InfoNCELoss module.
- Parameters:
temperature (float) – Temperature scaling factor. Default is 0.07.
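As a rough illustration of what a temperature scaling factor does, the sketch below divides a set of similarity scores by the temperature before a softmax. The scores and the helper `positive_prob` are hypothetical, and the sketch assumes the temperature is applied as a divisor on the similarities, which is the usual convention but is not confirmed by this page.

```python
import torch

# Hypothetical similarity scores for one query; index 0 is the positive key.
sims = torch.tensor([0.9, 0.6, 0.1])


def positive_prob(temperature: float) -> float:
    """Softmax probability assigned to the positive after temperature scaling."""
    return torch.softmax(sims / temperature, dim=0)[0].item()


p_sharp = positive_prob(0.07)  # the documented default temperature
p_soft = positive_prob(1.0)    # no sharpening
```

A smaller temperature sharpens the softmax, so the same similarity gap translates into a much more confident prediction for the positive and a stronger penalty for hard negatives.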
- forward(query: Tensor, key: Tensor, queue: Tensor = None, mask: Tensor = None) → Tensor
Forward pass through the InfoNCELoss module.
- Parameters:
query (torch.Tensor) – Query embeddings from one modality.
key (torch.Tensor) – Key embeddings from another modality (positives).
queue (torch.Tensor, optional) – Queue of negative samples. Default is None.
mask (torch.Tensor, optional) – Binary mask defining positive pairs. Default is None. Shape should be [query.size(0), key.size(0)] where 1 indicates a positive pair.
- Returns:
The InfoNCE loss.
- Return type: