Source code for kaira.metrics.signal.snr
"""Signal-to-Noise Ratio (SNR) metric.
SNR is a fundamental measure for quantifying the quality of a signal in the presence of noise,
widely used in communications and signal processing :cite:`goldsmith2005wireless` :cite:`sklar2001digital`.
"""
from typing import Any, Optional, Tuple
import torch
from torch import Tensor
from ..base import BaseMetric
from ..registry import MetricRegistry
[docs]
@MetricRegistry.register_metric("snr")
class SignalToNoiseRatio(BaseMetric):
"""Signal-to-Noise Ratio (SNR) metric.
SNR measures the ratio of signal power to noise power, often expressed in decibels (dB).
Higher values indicate better signal quality. It's a fundamental metric in signal processing
and communications :cite:`goldsmith2005wireless` :cite:`sklar2001digital`.
Attributes:
mode (str): Output mode - "db" for decibels or "linear" for linear ratio.
"""
[docs]
def __init__(self, name: Optional[str] = None, mode: str = "db", *args: Any, **kwargs: Any):
"""Initialize the SNR metric.
Args:
name (Optional[str]): Optional name for the metric
mode (str): Output mode - "db" for decibels or "linear" for linear ratio
*args: Variable length argument list passed to the base class.
**kwargs: Arbitrary keyword arguments passed to the base class.
"""
super().__init__(name=name or "SNR") # Pass only name
self.mode = mode.lower()
if self.mode not in ["db", "linear"]:
raise ValueError("Mode must be either 'db' or 'linear'")
[docs]
def forward(self, x: Tensor, y: Tensor, *args: Any, **kwargs: Any) -> Tensor:
"""Compute the Signal-to-Noise Ratio (SNR).
Args:
x (Tensor): The original (clean) signal tensor.
y (Tensor): The noisy signal tensor.
*args: Variable length argument list (unused).
**kwargs: Arbitrary keyword arguments (unused).
Returns:
Tensor: The computed SNR value(s). If input is batched, returns SNR per batch element.
"""
# Ensure inputs are tensors
if not isinstance(x, Tensor) or not isinstance(y, Tensor):
raise TypeError(f"Inputs must be torch.Tensor, got {type(x)} and {type(y)}")
# Ensure inputs have the same shape
if x.shape != y.shape:
raise ValueError(f"Input shapes must match: {x.shape} vs {y.shape}")
# Calculate noise
noise = y - x
# Check for batch dimension (assuming dim > 1 implies batching)
is_batched = x.dim() > 1 and x.shape[0] > 1
if is_batched:
result = []
for i in range(x.size(0)):
# Handle complex signals
if torch.is_complex(x):
signal_power = torch.mean(torch.abs(x[i]) ** 2)
noise_power = torch.mean(torch.abs(noise[i]) ** 2)
else:
# Calculate power of signal and noise
signal_power = torch.mean(x[i] ** 2)
noise_power = torch.mean(noise[i] ** 2)
# Avoid division by zero
eps = torch.finfo(torch.float32).eps
# For perfect signal (no noise), return very high value approaching infinity
if noise_power < eps:
result.append(torch.tensor(float("inf")))
else:
# Calculate SNR
snr_linear = signal_power / (noise_power + eps)
if self.mode == "db":
# Convert to dB: 10 * log10(signal_power / noise_power)
snr = 10 * torch.log10(snr_linear)
else:
# Return linear ratio
snr = snr_linear
result.append(snr)
return torch.stack(result)
else:
# Handle complex signals
if torch.is_complex(x):
signal_power = torch.mean(torch.abs(x) ** 2)
noise_power = torch.mean(torch.abs(noise) ** 2)
else:
# Calculate power of signal and noise
signal_power = torch.mean(x**2)
noise_power = torch.mean(noise**2)
# Avoid division by zero
eps = torch.finfo(torch.float32).eps
# For perfect signal (no noise), return very high value approaching infinity
if noise_power < eps:
return torch.tensor(float("inf"))
# Calculate SNR in linear form
snr_linear = signal_power / (noise_power + eps)
# Convert to dB if needed
if self.mode == "db":
snr = 10 * torch.log10(snr_linear)
else:
snr = snr_linear
# Return scalar tensor
return snr.squeeze()
[docs]
def compute_with_stats(self, x: Tensor, y: Tensor, *args: Any, **kwargs: Any) -> Tuple[Tensor, Tensor]:
"""Compute SNR with mean and standard deviation across batches.
Args:
x (Tensor): The original (clean) signal tensor (batched).
y (Tensor): The noisy signal tensor (batched).
*args: Variable length argument list (unused).
**kwargs: Arbitrary keyword arguments (unused).
Returns:
Tuple[Tensor, Tensor]: Mean and standard deviation of the SNR values across the batch.
"""
values = self.forward(x, y, *args, **kwargs)
# Handle potential inf values before calculating stats
values = values[torch.isfinite(values)]
if values.numel() == 0:
# Return NaN if all values were inf or input was empty
return torch.tensor(float("nan")), torch.tensor(float("nan"))
return values.mean(), values.std()
[docs]
def reset(self) -> None:
"""Reset accumulated statistics.
For SNR, there are no accumulated statistics to reset as it's a direct computation.
"""
pass
# Alias for backward compatibility
SNR = SignalToNoiseRatio