kaira.losses.multimodal.AlignmentLoss

Inheritance diagram for AlignmentLoss

class kaira.losses.multimodal.AlignmentLoss(alignment_type='l2', projection_dim=None)

Bases: BaseLoss

Alignment Loss for multimodal embeddings.

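This module computes a loss that encourages paired embeddings from two different modalities to agree. The discrepancy between them is measured with the distance selected by alignment_type (L1, L2, or cosine).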

Methods

  • __init__ – Initialize the AlignmentLoss module.

  • forward – Compute the alignment loss between two batches of embeddings.

__init__(alignment_type='l2', projection_dim=None)

Initialize the AlignmentLoss module.

Parameters:
  • alignment_type (str) – Distance used to compare the paired embeddings: ‘l1’, ‘l2’, or ‘cosine’. Default is ‘l2’.

  • projection_dim (int, optional) – Dimension to project embeddings to before computing loss. If None, no projection is performed. Default is None.
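A minimal sketch of what projection_dim controls, assuming a linear map from the input embedding dimension to projection_dim is applied before the distance is computed. How AlignmentLoss actually builds this projection (for example, whether it is a learned layer, and whether both modalities share it) is not documented here. The fixed weight matrix W and the helpers project and maybe_project are hypothetical.

```python
from typing import List, Optional, Sequence

Matrix = List[List[float]]


def project(vec: Sequence[float], weight: Matrix) -> List[float]:
    """Multiply one embedding by a (projection_dim x input_dim) weight matrix."""
    for row in weight:
        if len(row) != len(vec):
            raise ValueError("weight rows must match the embedding dimension")
    # Each output coordinate is the dot product of one weight row with the input.
    return [sum(w * x for w, x in zip(row, vec)) for row in weight]


def maybe_project(batch: Sequence[Sequence[float]],
                  weight: Optional[Matrix]) -> List[List[float]]:
    """Project every embedding in a batch, or return copies unchanged."""
    if weight is None:
        # Mirrors projection_dim=None: no projection is performed.
        return [list(vec) for vec in batch]
    return [project(vec, weight) for vec in batch]


# Hypothetical 2x3 weight: maps 3-dimensional embeddings to projection_dim=2.
W = [[1.0, 0.0, 0.0],
     [0.0, 0.5, 0.5]]
print(maybe_project([[1.0, 2.0, 4.0]], W))     # [[1.0, 3.0]]
print(maybe_project([[1.0, 2.0, 4.0]], None))  # [[1.0, 2.0, 4.0]]
```

Projecting to a smaller shared dimension lets the loss compare only the components that the projection keeps, rather than every raw coordinate.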

forward(x1: Tensor, x2: Tensor) → Tensor

Compute the alignment loss between the embeddings x1 and x2.

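Parameters:
  • x1 (torch.Tensor) – Embeddings from the first modality.

  • x2 (torch.Tensor) – Embeddings from the second modality, paired with x1.

Returns: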

The alignment loss.

Return type:

torch.Tensor
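A minimal pure-Python sketch of the kind of computation forward performs, under stated assumptions: the three alignment_type values are treated as L1 distance, Euclidean distance, and one minus cosine similarity, and per-pair distances are averaged over the batch. The actual kaira implementation operates on torch.Tensor inputs and may differ, for example by using mean squared error for ‘l2’ or a different reduction. Projection is omitted here. The names alignment_loss, DISTANCES, and the distance helpers are hypothetical.

```python
import math
from typing import Callable, Dict, Sequence

Vector = Sequence[float]


def l1_distance(a: Vector, b: Vector) -> float:
    # Sum of absolute coordinate differences.
    return sum(abs(x - y) for x, y in zip(a, b))


def l2_distance(a: Vector, b: Vector) -> float:
    # Euclidean distance (assumption: the library may use squared error instead).
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def cosine_distance(a: Vector, b: Vector, eps: float = 1e-12) -> float:
    # 1 - cosine similarity: 0 for parallel vectors, 1 for orthogonal ones.
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    # eps guards against division by zero for all-zero embeddings.
    return 1.0 - dot / max(norm_a * norm_b, eps)


DISTANCES: Dict[str, Callable[[Vector, Vector], float]] = {
    "l1": l1_distance,
    "l2": l2_distance,
    "cosine": cosine_distance,
}


def alignment_loss(x1: Sequence[Vector], x2: Sequence[Vector],
                   alignment_type: str = "l2") -> float:
    """Average distance between paired embeddings x1[i] and x2[i]."""
    if alignment_type not in DISTANCES:
        raise ValueError(f"unknown alignment_type: {alignment_type!r}")
    if len(x1) != len(x2) or not x1:
        raise ValueError("x1 and x2 must be non-empty batches of equal length")
    distance = DISTANCES[alignment_type]
    # Mean over the batch (assumed reduction).
    return sum(distance(a, b) for a, b in zip(x1, x2)) / len(x1)


x1 = [[0.0, 0.0], [1.0, 0.0]]
x2 = [[3.0, 4.0], [1.0, 0.0]]
print(alignment_loss(x1, x2, "l2"))  # 2.5
print(alignment_loss(x1, x2, "l1"))  # 3.5
```

In the example the first pair is 5.0 apart in Euclidean distance (7.0 in L1) and the second pair is identical, so the batch means are 2.5 and 3.5.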