kaira.losses.multimodal.CMCLoss

Inheritance diagram for CMCLoss

class kaira.losses.multimodal.CMCLoss(lambda_cmc=1.0)

Bases: BaseLoss

Cross-Modal Consistency Loss Module.

This module implements a loss that encourages features from two modalities, each mapped through its own projection head, to be consistent with each other.
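The docstring does not say which formula enforces the consistency. One common choice for this kind of objective is a symmetric contrastive (InfoNCE) loss: the i-th sample of one modality should be more similar to the i-th sample of the other modality than to any other sample in the batch. The plain-Python sketch below illustrates that idea only. It is an assumption, not kaira's implementation, and the names `symmetric_info_nce` and `temperature` are hypothetical.

```python
import math

# Hypothetical sketch: a symmetric InfoNCE loss, NOT necessarily what CMCLoss computes.

def _normalize(v):
    # Scale a vector to unit L2 norm (guarding against a zero vector).
    norm = math.sqrt(sum(x * x for x in v)) or 1.0
    return [x / norm for x in v]

def _cosine_matrix(z1, z2):
    # Pairwise cosine similarities between every row of z1 and every row of z2.
    a = [_normalize(v) for v in z1]
    b = [_normalize(v) for v in z2]
    return [[sum(p * q for p, q in zip(u, w)) for w in b] for u in a]

def _cross_entropy_diag(logits):
    # Mean cross-entropy where the correct "class" for row i is column i.
    total = 0.0
    for i, row in enumerate(logits):
        m = max(row)  # subtract the max for a numerically stable log-sum-exp
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_sum_exp - row[i]
    return total / len(logits)

def symmetric_info_nce(z1, z2, temperature=0.1):
    """Average the modality-1 -> modality-2 and modality-2 -> modality-1 losses."""
    sim = _cosine_matrix(z1, z2)
    logits = [[s / temperature for s in row] for row in sim]
    logits_t = [list(col) for col in zip(*logits)]  # transpose: modality 2 -> 1
    return 0.5 * (_cross_entropy_diag(logits) + _cross_entropy_diag(logits_t))

# Correctly paired embeddings give a loss close to 0 ...
aligned = symmetric_info_nce([[1.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]])
# ... while swapping the pairing gives a large loss (about 10 here).
swapped = symmetric_info_nce([[1.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]])
```

The temperature controls how sharply mismatched pairs are penalized; a contrastive form is shown because it both pulls paired samples together and pushes unpaired ones apart.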

Methods

__init__ – Initialize the CMCLoss module.

forward – Forward pass through the CMCLoss module.

__init__(lambda_cmc=1.0)

Initialize the CMCLoss module.

Parameters:

lambda_cmc (float) – Weight for the CMC loss. Default is 1.0.

forward(x1: Tensor, x2: Tensor, proj1: BaseLoss, proj2: BaseLoss) → Tensor

Forward pass through the CMCLoss module.

Parameters:
  • x1 (torch.Tensor) – Features from the first modality.

  • x2 (torch.Tensor) – Features from the second modality.

  • proj1 (BaseLoss) – Projection head for the first modality.

  • proj2 (BaseLoss) – Projection head for the second modality.

Returns:

The cross-modal consistency loss.

Return type:

torch.Tensor
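The parameter list suggests a data flow: each modality passes through its own projection head before the two are compared, and the result is weighted by `lambda_cmc`. The plain-Python sketch below mirrors that flow under stated assumptions: the projection heads are ordinary callables standing in for modules, and consistency is measured as the mean of (1 - cosine similarity) over paired samples, which may differ from the formula kaira actually uses. `cmc_forward_sketch`, `cosine_consistency`, `head1` and `head2` are hypothetical names.

```python
import math

def cosine_consistency(z1, z2):
    # Hypothetical consistency term: mean (1 - cosine similarity) over paired rows.
    total = 0.0
    for u, w in zip(z1, z2):
        nu = math.sqrt(sum(x * x for x in u)) or 1.0
        nw = math.sqrt(sum(x * x for x in w)) or 1.0
        total += 1.0 - sum(a * b for a, b in zip(u, w)) / (nu * nw)
    return total / len(z1)

def cmc_forward_sketch(x1, x2, proj1, proj2, lambda_cmc=1.0):
    # 1. Map each modality into a shared space with its own projection head.
    z1 = [proj1(row) for row in x1]
    z2 = [proj2(row) for row in x2]
    # 2. Penalize disagreement between paired projections,
    # 3. then scale by the weight given to __init__ (assumed behaviour).
    return lambda_cmc * cosine_consistency(z1, z2)

# Hypothetical projection heads: modality 1 has 3 features, modality 2 has 2.
def head1(v):
    return [v[0] + v[1], v[2]]  # 3 features -> 2-dimensional shared space

def head2(v):
    return [v[0], v[1]]  # already 2-dimensional: identity

x1 = [[1.0, 1.0, 0.0], [0.0, 0.0, 3.0]]
consistent = cmc_forward_sketch(x1, [[2.0, 0.0], [0.0, 1.0]], head1, head2, lambda_cmc=0.5)
inconsistent = cmc_forward_sketch(x1, [[0.0, 1.0], [1.0, 0.0]], head1, head2, lambda_cmc=0.5)
```

Here the first call yields 0.0 because each pair of projections points in the same direction, while the second yields 0.5 (orthogonal pairs give a consistency term of 1.0, halved by `lambda_cmc=0.5`).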