kaira.losses.multimodal.TripletLoss

Inheritance diagram of TripletLoss

class kaira.losses.multimodal.TripletLoss(margin=0.3, distance='cosine')[source]

Bases: BaseLoss

Triplet Loss Module for multimodal data.

This module implements triplet loss with hard negative mining.

Methods

  • __init__ – Initialize the TripletLoss module.

  • forward – Forward pass through the TripletLoss module.

__init__(margin=0.3, distance='cosine')[source]

Initialize the TripletLoss module.

Parameters:
  • margin (float) – Margin for triplet loss. Default is 0.3.

  • distance (str) – Distance metric (‘cosine’ or ‘euclidean’). Default is ‘cosine’.
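The way margin and distance interact can be illustrated with a framework-free sketch of the standard triplet formulation, max(0, d(anchor, positive) − d(anchor, negative) + margin). This is not Kaira's implementation: the function names below are hypothetical, plain Python lists stand in for tensors, and cosine distance is assumed to mean 1 − cosine similarity, which the source does not confirm.

```python
import math


def cosine_distance(u, v):
    """Assumed definition: 1 - cosine similarity (undefined for zero vectors)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return 1.0 - dot / (norm_u * norm_v)


def euclidean_distance(u, v):
    """Ordinary L2 distance between two equal-length vectors."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))


def triplet_loss_sketch(anchor, positive, negative, margin=0.3, distance="cosine"):
    """Hypothetical helper: standard triplet loss for a single triplet."""
    metrics = {"cosine": cosine_distance, "euclidean": euclidean_distance}
    if distance not in metrics:
        raise ValueError(f"unsupported distance: {distance!r}")
    dist = metrics[distance]
    # The loss is zero once the negative is at least `margin` farther
    # from the anchor than the positive is.
    return max(0.0, dist(anchor, positive) - dist(anchor, negative) + margin)
```

Under these assumptions, a negative that is orthogonal to the anchor contributes no loss when the positive coincides with the anchor, while a negative identical to the anchor is penalized by exactly the margin.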

forward(anchor: Tensor, positive: Tensor, negative: Tensor = None, labels: Tensor = None) → Tensor[source]

Forward pass through the TripletLoss module.

Parameters:
  • anchor (torch.Tensor) – Anchor embeddings.

  • positive (torch.Tensor) – Positive embeddings.

  • negative (torch.Tensor, optional) – Explicit negative embeddings. Default is None.

  • labels (torch.Tensor, optional) – Labels for online mining. Default is None.

Returns:

The triplet loss.

Return type:

torch.Tensor
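When no explicit negative is passed, the labels argument is described as driving online mining. The source does not say which mining scheme Kaira uses, so the following is only a sketch of one common choice, batch-hard mining: for each anchor, the closest positive embedding carrying a different label serves as the negative. The function names are hypothetical, Euclidean distance is assumed, and plain Python lists stand in for tensors.

```python
import math


def mine_hard_negatives(anchors, positives, labels):
    """For each anchor, return the index of the nearest positive with a
    different label, or None if the batch holds no such item."""
    hardest = []
    for i, anchor in enumerate(anchors):
        candidates = [j for j in range(len(positives)) if labels[j] != labels[i]]
        if not candidates:
            hardest.append(None)
            continue
        hardest.append(min(candidates, key=lambda j: math.dist(anchor, positives[j])))
    return hardest


def batch_hard_triplet_loss(anchors, positives, labels, margin=0.3):
    """Hypothetical helper: mean triplet loss over the mined hard negatives."""
    losses = []
    mined = mine_hard_negatives(anchors, positives, labels)
    for anchor, positive, j in zip(anchors, positives, mined):
        if j is None:
            continue  # no valid negative for this anchor
        d_pos = math.dist(anchor, positive)
        d_neg = math.dist(anchor, positives[j])
        losses.append(max(0.0, d_pos - d_neg + margin))
    return sum(losses) / len(losses) if losses else 0.0


anchors = [[0.0, 0.0], [10.0, 0.0], [0.5, 0.0]]
positives = [[0.0, 1.0], [10.0, 1.0], [1.0, 0.0]]
labels = [0, 1, 2]
mined_indices = mine_hard_negatives(anchors, positives, labels)
loss = batch_hard_triplet_loss(anchors, positives, labels)
```

Choosing the nearest differently-labeled item focuses the loss on the triplets that currently violate the margin most; other schemes, such as semi-hard mining, select negatives differently.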