kaira.losses.multimodal.TripletLoss

- class kaira.losses.multimodal.TripletLoss(margin=0.3, distance='cosine')
Bases: BaseLoss
Triplet loss module for multimodal data.
This module implements triplet loss with hard negative mining.
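The hard negative mining described above can be illustrated with a small plain-Python sketch. It is not kaira's implementation, and several details are assumptions. It fixes the distance to cosine distance (1 - cosine similarity) and reuses the default margin=0.3 from the signature. When only labels are given, it treats the batch's other positive embeddings as the negative pool (a common cross-modal convention, assumed here) and picks the closest one with a different label as the hard negative. Finally, it averages the hinge terms over the batch. The names `triplet_loss_sketch` and `cosine_distance` are hypothetical, and plain lists stand in for the torch.Tensor batches the real module accepts.

```python
import math
from typing import List, Optional, Sequence

Vector = Sequence[float]


def cosine_distance(a: Vector, b: Vector, eps: float = 1e-12) -> float:
    """Return 1 - cosine similarity between two vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    # eps guards against division by zero for all-zero vectors
    return 1.0 - dot / max(norm_a * norm_b, eps)


def triplet_loss_sketch(
    anchor: Sequence[Vector],
    positive: Sequence[Vector],
    negative: Optional[Sequence[Vector]] = None,
    labels: Optional[Sequence[int]] = None,
    margin: float = 0.3,
) -> float:
    """Hypothetical sketch: mean of max(0, d(a, p) - d(a, n) + margin)."""
    if negative is None and labels is None:
        raise ValueError("provide either explicit negatives or labels")
    hinge_terms: List[float] = []
    for i, a in enumerate(anchor):
        d_pos = cosine_distance(a, positive[i])
        if negative is not None:
            # Explicit triplet: use the supplied negative for this anchor.
            d_neg = cosine_distance(a, negative[i])
        else:
            # Online hard negative mining (assumed scheme): among the batch's
            # positive embeddings whose label differs from anchor i, take the
            # one closest to the anchor.
            candidates = [
                cosine_distance(a, p)
                for p, label in zip(positive, labels)
                if label != labels[i]
            ]
            if not candidates:
                continue  # no negative available for this anchor
            d_neg = min(candidates)
        hinge_terms.append(max(0.0, d_pos - d_neg + margin))
    # Average over the anchors that contributed a triplet.
    return sum(hinge_terms) / len(hinge_terms) if hinge_terms else 0.0


loss = triplet_loss_sketch(
    anchor=[[1.0, 0.0], [1.0, 0.0]],
    positive=[[0.0, 1.0], [1.0, 0.0]],
    labels=[0, 1],
)
```

In this usage, the first anchor's hardest negative is identical to it, so its hinge term is 1.0 - 0.0 + 0.3 = 1.3. The second anchor already satisfies the margin and contributes 0, giving a mean of 0.65. Choosing the closest differently labelled sample, rather than a random one, is what makes the mining "hard": it focuses the gradient on the triplets that currently violate the margin most.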
Methods
__init__([margin, distance]) – Initialize the TripletLoss module.
forward(anchor, positive[, negative, labels]) – Forward pass through the TripletLoss module.
- forward(anchor: Tensor, positive: Tensor, negative: Tensor = None, labels: Tensor = None) → Tensor
Forward pass through the TripletLoss module.
- Parameters:
anchor (torch.Tensor) – Anchor embeddings.
positive (torch.Tensor) – Positive embeddings.
negative (torch.Tensor, optional) – Explicit negative embeddings. Default is None.
labels (torch.Tensor, optional) – Labels for online mining. Default is None.
- Returns:
The triplet loss.
- Return type: