kaira.metrics.BaseMetric

class kaira.metrics.BaseMetric(name: str | None = None, *args: Any, **kwargs: Any)

Bases: Module, ABC

Base metric module.

Abstract base class for metrics that evaluate the performance of a communication system. Because forward is an abstract method, subclasses must implement it to compute the metric; the base class itself cannot be instantiated.
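A minimal subclassing sketch follows. So that it runs without kaira installed, it uses a hypothetical stand-in, `BaseMetricStandIn`, that mirrors the signatures documented on this page; with kaira available you would subclass `kaira.metrics.BaseMetric` instead. `PerSampleMSE` is a hypothetical example metric, and storing `name` on the instance as `self.name` is an assumption, not documented behavior.

```python
from abc import ABC, abstractmethod
from typing import Any, Optional

import torch
from torch import Tensor, nn


class BaseMetricStandIn(nn.Module, ABC):
    """Hypothetical stand-in mirroring the documented BaseMetric signatures."""

    def __init__(self, name: Optional[str] = None, *args: Any, **kwargs: Any) -> None:
        super().__init__()
        # Assumption: the metric name is kept on the instance as `self.name`.
        self.name = name

    @abstractmethod
    def forward(self, x: Tensor, y: Tensor, *args: Any, **kwargs: Any) -> Tensor:
        """Subclasses return the metric computed from x (predictions) and y (targets)."""


class PerSampleMSE(BaseMetricStandIn):
    """Hypothetical metric: mean squared error computed separately for each sample."""

    def forward(self, x: Tensor, y: Tensor, *args: Any, **kwargs: Any) -> Tensor:
        # Average the squared error over every dimension except the batch dimension (dim 0).
        return ((x - y) ** 2).flatten(start_dim=1).mean(dim=1)


metric = PerSampleMSE(name="mse")
x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
y = torch.tensor([[1.0, 0.0], [3.0, 4.0]])
print(metric(x, y))  # tensor([2., 0.])
```

Because forward is declared abstract, calling `BaseMetricStandIn()` directly raises `TypeError`; only subclasses that override forward can be instantiated.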

Methods

  • __init__ – Initialize the metric.

  • compute_with_stats – Compute the metric and return its mean and standard deviation.

  • forward – Forward pass through the metric.

__init__(name: str | None = None, *args: Any, **kwargs: Any)

Initialize the metric.

Parameters:
  • name (Optional[str]) – Name of the metric.

  • *args – Variable length argument list.

  • **kwargs – Arbitrary keyword arguments.

abstractmethod forward(x: Tensor, y: Tensor, *args: Any, **kwargs: Any) -> Tensor

Forward pass through the metric: compute the metric value for x against y.

Parameters:
  • x (torch.Tensor) – The first input tensor (typically predictions).

  • y (torch.Tensor) – The second input tensor (typically targets).

  • *args – Variable length argument list.

  • **kwargs – Arbitrary keyword arguments.

Returns:

The calculated metric value.

Return type:

torch.Tensor
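The *args and **kwargs in this signature let callers pass extra inputs through to forward. Since BaseMetric is a torch `nn.Module`, calling `metric(x, y, ...)` dispatches to forward through `nn.Module.__call__`. The sketch below is illustrative only: `MetricBase` is a hypothetical minimal stand-in for BaseMetric (so the example runs without kaira), and the `MaskedMAE` metric with its `mask` keyword argument is likewise hypothetical.

```python
from abc import ABC, abstractmethod
from typing import Any, Optional

import torch
from torch import Tensor, nn


class MetricBase(nn.Module, ABC):
    """Hypothetical stand-in for BaseMetric: only the abstract forward is mirrored."""

    @abstractmethod
    def forward(self, x: Tensor, y: Tensor, *args: Any, **kwargs: Any) -> Tensor:
        """Return the metric computed from x (predictions) and y (targets)."""


class MaskedMAE(MetricBase):
    """Hypothetical metric: per-sample mean absolute error over masked-in positions."""

    def forward(
        self, x: Tensor, y: Tensor, *args: Any, mask: Optional[Tensor] = None, **kwargs: Any
    ) -> Tensor:
        err = (x - y).abs()
        # Without a mask, every position counts.
        if mask is None:
            mask = torch.ones_like(err)
        mask = mask.to(err.dtype)
        # Sum errors and mask counts per sample (dim 0 is the batch dimension).
        num = (err * mask).flatten(start_dim=1).sum(dim=1)
        den = mask.flatten(start_dim=1).sum(dim=1).clamp_min(1.0)
        return num / den


metric = MaskedMAE()
x = torch.tensor([[1.0, 5.0], [2.0, 2.0]])
y = torch.tensor([[0.0, 0.0], [2.0, 4.0]])
mask = torch.tensor([[1, 0], [1, 1]])
# nn.Module.__call__ forwards both positional and keyword arguments to forward.
print(metric(x, y))             # tensor([3., 1.])
print(metric(x, y, mask=mask))  # tensor([1., 1.])
```

Declaring `mask` as keyword-only (after `*args`) keeps the `(x, y)` positional contract of the base signature intact while still accepting the extra input.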

compute_with_stats(x: Tensor, y: Tensor, *args: Any, **kwargs: Any) -> Tuple[Tensor, Tensor]

Compute the metric and return its mean and standard deviation.

Parameters:
  • x (torch.Tensor) – The first input tensor (typically predictions).

  • y (torch.Tensor) – The second input tensor (typically targets).

  • *args – Variable length argument list.

  • **kwargs – Arbitrary keyword arguments.

Returns:

A (mean, std) tuple of the metric values.

Return type:

Tuple[torch.Tensor, torch.Tensor]
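Assuming forward returns one value per sample, computing the statistics amounts to reducing those per-sample values to their mean and standard deviation, as sketched below. The helper `mean_and_std` is hypothetical, and this page does not say whether kaira's implementation uses the unbiased (Bessel-corrected) standard deviation; the sketch uses PyTorch's default, which is unbiased.

```python
from typing import Tuple

import torch
from torch import Tensor


def mean_and_std(values: Tensor) -> Tuple[Tensor, Tensor]:
    """Hypothetical helper: summarize per-sample metric values as (mean, std)."""
    # Tensor.std() applies Bessel's correction (divides by N - 1) by default.
    return values.mean(), values.std()


per_sample = torch.tensor([1.0, 2.0, 3.0, 4.0])  # e.g. the output of a per-sample forward
mean, std = mean_and_std(per_sample)
print(mean.item(), round(std.item(), 4))  # 2.5 1.291
```

With a batch of a single sample the unbiased estimate is undefined and PyTorch returns `nan`; `values.std(unbiased=False)` gives the population standard deviation instead.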