kaira.losses.multimodal.AlignmentLoss

- class kaira.losses.multimodal.AlignmentLoss(alignment_type='l2', projection_dim=None)
Bases: BaseLoss
Alignment loss for multimodal embeddings.
This module aligns embeddings from different modalities.
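To make "aligning embeddings" concrete, here is a minimal, dependency-free sketch of an L2-style alignment objective on paired embeddings. It assumes (this page does not confirm it) that `alignment_type='l2'` means the mean squared Euclidean distance between row i of the first modality and row i of the second. The function name `l2_alignment` is hypothetical and not part of kaira.

```python
from typing import Sequence


def l2_alignment(x1: Sequence[Sequence[float]], x2: Sequence[Sequence[float]]) -> float:
    """Hypothetical L2 alignment: mean over pairs of the squared Euclidean distance.

    Assumes x1[i] and x2[i] embed the same sample in two different modalities.
    """
    if not x1 or len(x1) != len(x2):
        raise ValueError("x1 and x2 must hold the same, non-zero number of paired embeddings")
    total = 0.0
    for a, b in zip(x1, x2):
        if len(a) != len(b):
            raise ValueError("paired embeddings must have the same dimension")
        # Squared distance between the two views of one sample
        total += sum((ai - bi) ** 2 for ai, bi in zip(a, b))
    return total / len(x1)


# Identical embeddings are perfectly aligned; moving one view apart raises the loss
image_emb = [[1.0, 0.0], [0.0, 1.0]]
text_emb = [[1.0, 0.0], [0.0, 2.0]]
print(l2_alignment(image_emb, image_emb))  # 0.0
print(l2_alignment(image_emb, text_emb))   # 0.5
```

Minimizing such a loss pulls each pair of embeddings together, which is the sense in which the two modalities become "aligned".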
Methods
- __init__(alignment_type='l2', projection_dim=None): Initialize the AlignmentLoss module.
- forward(x1, x2): Forward pass through the AlignmentLoss module.
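The sketch below is a hypothetical PyTorch stand-in that mirrors the documented constructor arguments and the `forward(x1, x2)` contract, to show how a loss of this shape slots into a training step. It is not kaira's implementation: the class name `SketchAlignmentLoss` is made up, `'l2'` is assumed to mean the mean squared distance between paired rows, and `projection_dim` is assumed to add one learnable linear projection per modality (built with `torch.nn.LazyLinear`, which infers its input size on first use). Only `'l2'` is implemented because this page lists no other `alignment_type` values.

```python
from typing import Optional

import torch
from torch import Tensor, nn


class SketchAlignmentLoss(nn.Module):
    """Hypothetical stand-in for AlignmentLoss; NOT kaira's implementation.

    Assumptions: 'l2' = mean squared Euclidean distance between paired rows;
    projection_dim (if set) adds one learnable projection per modality.
    """

    def __init__(self, alignment_type: str = "l2", projection_dim: Optional[int] = None):
        super().__init__()
        if alignment_type != "l2":
            raise ValueError("this sketch only implements alignment_type='l2'")
        if projection_dim is None:
            # Without a projection, x1 and x2 must already share a dimension
            self.proj1: nn.Module = nn.Identity()
            self.proj2: nn.Module = nn.Identity()
        else:
            # LazyLinear infers each modality's input size on the first forward call
            self.proj1 = nn.LazyLinear(projection_dim)
            self.proj2 = nn.LazyLinear(projection_dim)

    def forward(self, x1: Tensor, x2: Tensor) -> Tensor:
        z1, z2 = self.proj1(x1), self.proj2(x2)
        # Squared distance per pair, averaged over the batch -> 0-d tensor
        return (z1 - z2).pow(2).sum(dim=-1).mean()


torch.manual_seed(0)

# Same-dimensional embeddings, no projection
loss_fn = SketchAlignmentLoss()
x1 = torch.randn(4, 8, requires_grad=True)  # e.g. image embeddings for 4 samples
x2 = torch.randn(4, 8)                      # e.g. text embeddings for the same 4 samples
loss = loss_fn(x1, x2)
loss.backward()                             # gradients flow back to x1
print(loss.shape)                           # torch.Size([])

# Different-dimensional embeddings mapped into a shared 16-d space
proj_loss_fn = SketchAlignmentLoss(projection_dim=16)
proj_loss = proj_loss_fn(torch.randn(4, 512), torch.randn(4, 300))
```

The real class may reduce, normalize, or project differently, so check the kaira source before relying on any of these details.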
- forward(x1: Tensor, x2: Tensor) → Tensor
Forward pass through the AlignmentLoss module.
- Parameters:
x1 (torch.Tensor) – Embeddings from the first modality.
x2 (torch.Tensor) – Embeddings from the second modality.
- Returns:
The alignment loss.
- Return type:
torch.Tensor