kaira.losses.multimodal.CMCLoss

- class kaira.losses.multimodal.CMCLoss(lambda_cmc=1.0)
Bases: BaseLoss
Cross-Modal Consistency Loss Module.
This module implements a loss that encourages the projected features of two modalities to be consistent with each other.
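A minimal sketch of how a loss with this call signature might be used. It does not import kaira: `SketchCMCLoss` is a hypothetical stand-in, and its formula (2 minus twice the cosine similarity of the projected features, scaled by `lambda_cmc`) is an assumption made for illustration, not necessarily what CMCLoss computes. The projection heads are plain `torch.nn.Linear` layers here, also an assumption, although the signature documents them as `BaseLoss`. The feature dimensions and batch size are arbitrary.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SketchCMCLoss(nn.Module):
    """Hypothetical stand-in that mirrors the documented CMCLoss signature."""

    def __init__(self, lambda_cmc: float = 1.0):
        super().__init__()
        self.lambda_cmc = lambda_cmc  # weight applied to the loss value

    def forward(self, x1, x2, proj1, proj2):
        # Project each modality into a shared space and L2-normalize.
        z1 = F.normalize(proj1(x1), dim=1)
        z2 = F.normalize(proj2(x2), dim=1)
        # Assumed formula: 2 - 2 * cosine similarity, averaged over the batch.
        loss = (2.0 - 2.0 * (z1 * z2).sum(dim=1)).mean()
        return self.lambda_cmc * loss


torch.manual_seed(0)
x1 = torch.randn(8, 32)    # features from the first modality (batch of 8)
x2 = torch.randn(8, 48)    # features from the second modality
proj1 = nn.Linear(32, 16)  # projection head for the first modality
proj2 = nn.Linear(48, 16)  # projection head for the second modality

criterion = SketchCMCLoss(lambda_cmc=0.5)
loss = criterion(x1, x2, proj1, proj2)  # scalar tensor
loss.backward()                         # gradients reach both projection heads
```

Under this assumed formula the unweighted value lies between 0 (projections aligned) and 4 (projections opposed), so `lambda_cmc` simply rescales that range when the term is combined with other losses.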
Methods
__init__([lambda_cmc]) – Initialize the CMCLoss module.
forward(x1, x2, proj1, proj2) – Forward pass through the CMCLoss module.
- __init__(lambda_cmc=1.0)
Initialize the CMCLoss module.
- Parameters:
lambda_cmc (float) – Weight for the CMC loss. Default is 1.0.
- forward(x1: Tensor, x2: Tensor, proj1: BaseLoss, proj2: BaseLoss) → Tensor
Forward pass through the CMCLoss module.
- Parameters:
x1 (torch.Tensor) – Features from the first modality.
x2 (torch.Tensor) – Features from the second modality.
proj1 (BaseLoss) – Projection head for the first modality.
proj2 (BaseLoss) – Projection head for the second modality.
- Returns:
The cross-modal consistency loss.
- Return type: