# Source code for kaira.metrics.base
"""Base class for defining evaluation metrics in Kaira."""
from abc import ABC, abstractmethod
from typing import Any, Optional, Tuple
import torch
from torch import nn
# Abstract base class from which every Kaira metric derives.
class BaseMetric(nn.Module, ABC):
    """Base Metric Module.

    Abstract base class for metrics that evaluate the performance of a
    communication system. Subclasses must implement :meth:`forward` to
    calculate the metric; this class supplies a display name, a mean/std
    helper and a string representation.

    Attributes:
        name (str): Display name of the metric. Falls back to the class name
            when no (or an empty) name is given.
    """

    def __init__(self, name: Optional[str] = None, *args: Any, **kwargs: Any):
        """Initialize the metric.

        Args:
            name (Optional[str]): Name of the metric. ``None`` or an empty
                string selects the class name instead.
            *args: Variable length argument list. Accepted for signature
                compatibility; not forwarded to ``nn.Module``.
            **kwargs: Arbitrary keyword arguments. Accepted for signature
                compatibility; not forwarded to ``nn.Module``.
        """
        # nn.Module is initialized without the extra arguments on purpose:
        # they are swallowed here rather than passed up the MRO.
        super().__init__()
        self.name = name or self.__class__.__name__

    @abstractmethod
    def forward(self, x: torch.Tensor, y: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Forward pass through the metric.

        Args:
            x (torch.Tensor): The first input tensor (typically predictions)
            y (torch.Tensor): The second input tensor (typically targets)
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.

        Returns:
            torch.Tensor: The calculated metric value
        """

    def compute_with_stats(self, x: torch.Tensor, y: torch.Tensor, *args: Any, **kwargs: Any) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute metric with mean and standard deviation.

        Args:
            x (torch.Tensor): The first input tensor (typically predictions)
            y (torch.Tensor): The second input tensor (typically targets)
            *args: Variable length argument list, forwarded to :meth:`forward`.
            **kwargs: Arbitrary keyword arguments, forwarded to :meth:`forward`.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Mean and standard deviation of
            the metric values returned by :meth:`forward`. When ``forward``
            yields exactly one value, the standard deviation is reported as
            zero.
        """
        values = self.forward(x, y, *args, **kwargs)
        mean = values.mean()
        if values.numel() == 1:
            # torch.std applies Bessel's correction by default, so a single
            # value gives NaN plus a degrees-of-freedom warning. A lone sample
            # has no spread, so report zero (same dtype/device as the mean).
            return mean, torch.zeros_like(mean)
        return mean, values.std()

    def __str__(self) -> str:
        """Return string representation of the metric."""
        return f"{self.name} Metric"