kaira.metrics.BaseMetric

Inheritance diagram for BaseMetric
- class kaira.metrics.BaseMetric(name: str | None = None, *args: Any, **kwargs: Any)
Base Metric Module.
This is an abstract base class for defining metrics that evaluate the performance of a communication system. Subclasses must implement the abstract forward method, which computes the metric.
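Methods
- __init__(name=None, *args, **kwargs) – Initialize the metric.
- compute_with_stats(x, y, *args, **kwargs) – Compute metric with mean and standard deviation.
- forward(x, y, *args, **kwargs) – Forward pass through the metric.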
- __init__(name: str | None = None, *args: Any, **kwargs: Any)
Initialize the metric.
- Parameters:
name (Optional[str]) – Name of the metric
*args – Variable length argument list.
**kwargs – Arbitrary keyword arguments.
- abstractmethod forward(x: Tensor, y: Tensor, *args: Any, **kwargs: Any) → Tensor
Forward pass through the metric.
- Parameters:
x (torch.Tensor) – The first input tensor (typically predictions)
y (torch.Tensor) – The second input tensor (typically targets)
*args – Variable length argument list.
**kwargs – Arbitrary keyword arguments.
- Returns:
The calculated metric value
- Return type:
torch.Tensor
- compute_with_stats(x: Tensor, y: Tensor, *args: Any, **kwargs: Any) → Tuple[Tensor, Tensor]
Compute metric with mean and standard deviation.
- Parameters:
x (torch.Tensor) – The first input tensor (typically predictions)
y (torch.Tensor) – The second input tensor (typically targets)
*args – Variable length argument list.
**kwargs – Arbitrary keyword arguments.
- Returns:
Mean and standard deviation of the metric
- Return type:
Tuple[torch.Tensor, torch.Tensor]