"""Utility functions for Signal-to-Noise Ratio (SNR) calculations and conversions."""
from typing import Optional, Tuple, Union
import torch
# Public API of this module: dB/linear SNR conversions, noise-power/SNR
# conversions, SNR measurement, noise injection and power estimation.
__all__ = [
    "snr_db_to_linear",
    "snr_linear_to_db",
    "snr_to_noise_power",
    "noise_power_to_snr",
    "calculate_snr",
    "add_noise_for_snr",
    "estimate_signal_power",
]
def snr_db_to_linear(snr_db: Union[float, torch.Tensor]) -> torch.Tensor:
    """Convert Signal-to-Noise Ratio from decibel to linear scale.

    Computes ``10 ** (snr_db / 10)`` element-wise.

    Args:
        snr_db (Union[float, torch.Tensor]): Signal-to-Noise Ratio in decibels (dB).
            Non-tensor inputs (Python ints/floats, sequences, NumPy arrays) are
            converted with :func:`torch.as_tensor`; tensors are used as-is.

    Returns:
        torch.Tensor: Signal-to-Noise Ratio in linear scale.
    """
    # Convert every non-tensor input, not only ``float``: previously an ``int``
    # argument fell through and produced a plain Python float, not a tensor.
    if not isinstance(snr_db, torch.Tensor):
        snr_db = torch.as_tensor(snr_db)
    # True division promotes integer tensors to floating point before the power.
    return 10 ** (snr_db / 10.0)
def snr_linear_to_db(snr_linear: Union[float, torch.Tensor]) -> torch.Tensor:
    """Convert Signal-to-Noise Ratio from linear scale to decibels.

    Computes ``10 * log10(snr_linear)`` element-wise. Zero entries map to
    ``-inf`` dB, and the result keeps the input's shape.

    Args:
        snr_linear (Union[float, torch.Tensor]): SNR in linear scale. Non-tensor
            inputs (Python ints/floats, sequences, NumPy arrays) are converted
            with :func:`torch.as_tensor`.

    Returns:
        torch.Tensor: SNR in decibel (dB) scale.

    Raises:
        ValueError: If snr_linear contains negative values.
    """
    # Convert any non-tensor (previously only ``float``), so ``int`` input works.
    if not isinstance(snr_linear, torch.Tensor):
        snr_linear = torch.as_tensor(snr_linear)
    if torch.any(snr_linear < 0):
        raise ValueError("SNR in linear scale must be positive")
    # torch.log10(0) is already -inf, so zeros need no special casing. The former
    # eps-clamp distorted tiny positive values (e.g. 1e-10 -> ~-69 dB instead of
    # -100 dB) whenever a zero was present, and the scalar-zero shortcut dropped
    # the input's shape and dtype.
    return 10 * torch.log10(snr_linear)
def snr_to_noise_power(signal_power: Union[float, torch.Tensor], snr_db: Union[float, torch.Tensor]) -> torch.Tensor:
    """Convert SNR in dB to noise power given a signal power.

    Computes ``signal_power / 10 ** (snr_db / 10)``. The arithmetic is done in
    float64, and the result is returned as float32.

    Args:
        signal_power (Union[float, torch.Tensor]): Power of the signal. Python
            numbers, NumPy scalars/arrays (of any size) and tensors are accepted.
        snr_db (Union[float, torch.Tensor]): Signal-to-Noise Ratio in decibels (dB).
            Accepts the same input kinds as ``signal_power``.

    Returns:
        torch.Tensor: Corresponding noise power for the specified SNR (float32),
        with the broadcast shape of the two inputs.
    """

    def _as_float64(value: Union[float, torch.Tensor]) -> torch.Tensor:
        # Tensors keep their device and autograd history via ``.to``; everything
        # else goes through torch.as_tensor, which (unlike the former
        # ``float(value)`` path) also handles multi-element NumPy arrays.
        if isinstance(value, torch.Tensor):
            return value.to(torch.float64)
        return torch.as_tensor(value, dtype=torch.float64)

    signal_power = _as_float64(signal_power)
    snr_db = _as_float64(snr_db)
    snr_linear = snr_db_to_linear(snr_db)
    result = signal_power / snr_linear
    return result.to(torch.float32)
def noise_power_to_snr(signal_power: Union[float, torch.Tensor], noise_power: Union[float, torch.Tensor]) -> torch.Tensor:
    """Calculate SNR in dB given signal and noise power.

    Computes ``10 * log10(signal_power / noise_power)`` with standard PyTorch
    broadcasting between the two inputs.

    Args:
        signal_power (Union[float, torch.Tensor]): Power of the signal. Non-tensor
            inputs (Python ints/floats, NumPy arrays) are converted with
            :func:`torch.as_tensor`.
        noise_power (Union[float, torch.Tensor]): Power of the noise. Accepts the
            same input kinds as ``signal_power``.

    Returns:
        torch.Tensor: Signal-to-Noise Ratio in decibels (dB).

    Raises:
        ValueError: If noise_power contains zero or negative values.
    """
    # Convert any non-tensor (previously only ``float``; ``int`` crashed on .numel()).
    if not isinstance(signal_power, torch.Tensor):
        signal_power = torch.as_tensor(signal_power)
    if not isinstance(noise_power, torch.Tensor):
        noise_power = torch.as_tensor(noise_power)
    # The check works for any shape, and the division broadcasts. This makes the
    # former expand_as calls unnecessary; they also rejected valid broadcastable
    # pairs such as (3, 1) vs (1, 2).
    if torch.any(noise_power <= 0):
        raise ValueError("Noise power cannot be zero or negative")
    snr_linear = signal_power / noise_power
    return 10 * torch.log10(snr_linear)
def calculate_snr(
    original_signal: torch.Tensor,
    noisy_signal: torch.Tensor,
    dim: Optional[Union[int, Tuple[int, ...]]] = None,
    keepdim: bool = False,
) -> torch.Tensor:
    """Calculate the SNR between original and noisy signals.

    The noise is taken to be ``noisy_signal - original_signal``. Both powers are
    mean squared magnitudes over ``dim``, and the result is
    ``10 * log10(P_signal / P_noise)``.

    Args:
        original_signal (torch.Tensor): The original clean signal (real or complex).
            Integer tensors are promoted to the default floating dtype.
        noisy_signal (torch.Tensor): The noisy signal (original signal plus noise).
        dim (Optional[Union[int, Tuple[int, ...]]]): Dimensions to reduce when calculating power.
            If None, uses all dimensions.
        keepdim (bool): Whether to keep the reduced dimensions in the output. Default is False.

    Returns:
        torch.Tensor: SNR in decibels (dB). An exactly-zero noise power is floored
        at the smallest positive normal number of its dtype, so the result stays
        finite (and very large) instead of ``+inf``.

    Raises:
        ValueError: If original and noisy signals have different shapes.
    """
    # Ensure tensors have the same shape
    if original_signal.shape != noisy_signal.shape:
        raise ValueError("Original and noisy signals must have the same shape")

    def _to_inexact(x: torch.Tensor) -> torch.Tensor:
        # torch.mean rejects integer dtypes, so promote them first.
        if x.is_floating_point() or x.is_complex():
            return x
        return x.to(torch.get_default_dtype())

    def _power(x: torch.Tensor) -> torch.Tensor:
        # |x|**2 equals x**2 for real tensors and is the squared magnitude for
        # complex ones, so one expression covers both (and mixed real/complex).
        return torch.mean(torch.abs(x) ** 2, dim=dim, keepdim=keepdim)

    original_signal = _to_inexact(original_signal)
    noisy_signal = _to_inexact(noisy_signal)
    noise = noisy_signal - original_signal

    signal_power = _power(original_signal)
    noise_power = _power(noise)

    # Handle zero noise with the smallest *normal* number, not machine epsilon.
    # eps (~1.19e-7 in float32) is a relative precision, not an absolute floor;
    # clamping to it capped unit-power SNRs at ~69 dB and badly underestimated
    # the SNR of low-power signals.
    noise_power = torch.clamp(noise_power, min=torch.finfo(noise_power.dtype).tiny)
    # Subtract in the log domain: signal_power / tiny could overflow to inf.
    return 10 * (torch.log10(signal_power) - torch.log10(noise_power))
def add_noise_for_snr(
    signal: torch.Tensor,
    target_snr_db: Union[float, torch.Tensor],
    dim: Optional[Union[int, Tuple[int, ...]]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Add Gaussian noise to achieve a target Signal-to-Noise Ratio.

    The signal power is the mean squared magnitude over ``dim``. The noise power
    is derived from ``target_snr_db``, and zero-mean Gaussian noise with that
    power is added. For complex signals the noise power is split equally
    between the real and imaginary parts.

    Args:
        signal (torch.Tensor): The original clean signal.
        target_snr_db (Union[float, torch.Tensor]): Target Signal-to-Noise Ratio in decibels (dB).
        dim (Optional[Union[int, Tuple[int, ...]]]): Dimensions to reduce when calculating power.
            If None, uses all dimensions.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
            - Noisy signal (original signal with added noise)
            - The generated noise component
    """
    # Calculate signal power; reduced dims are kept (size 1) so the result
    # broadcasts back over ``signal``.
    if torch.is_complex(signal):
        signal_power = torch.mean(torch.abs(signal) ** 2, dim=dim, keepdim=True)
    else:
        signal_power = torch.mean(signal**2, dim=dim, keepdim=True)

    # Calculate required noise power. snr_to_noise_power always returns float32.
    # Cast back to the signal's (real) dtype: otherwise type promotion with the
    # keepdim-shaped noise_std silently upcast half/bfloat16 signals to float32.
    # NOTE(review): a tensor target_snr_db broadcasts against signal_power's
    # keepdim shape; a per-example SNR of shape (B,) against (B, 1) would give
    # (B, B) -- callers presumably pass a broadcast-compatible shape; confirm.
    noise_power = snr_to_noise_power(signal_power, target_snr_db).to(signal_power.dtype)

    # Generate noise with the right power
    if torch.is_complex(signal):
        # For complex signals, generate complex noise with half the power per component.
        noise_std = torch.sqrt(noise_power / 2)
        real_noise = torch.randn_like(signal.real) * noise_std
        imag_noise = torch.randn_like(signal.imag) * noise_std
        noise = torch.complex(real_noise, imag_noise)
    else:
        noise_std = torch.sqrt(noise_power)
        noise = torch.randn_like(signal) * noise_std

    noisy_signal = signal + noise
    return noisy_signal, noise
def estimate_signal_power(
    signal: torch.Tensor,
    dim: Optional[Union[int, Tuple[int, ...]]] = None,
    keepdim: bool = False,
) -> torch.Tensor:
    """Estimate the average power of a real or complex signal.

    The power is the mean of the squared sample magnitudes, taken over ``dim``.

    Args:
        signal (torch.Tensor): The input signal (real or complex).
        dim (Optional[Union[int, Tuple[int, ...]]]): Dimensions to reduce when calculating power.
            If None, uses all dimensions.
        keepdim (bool): Whether to keep the reduced dimensions in the output. Default is False.

    Returns:
        torch.Tensor: Signal power estimation.
    """
    # Squared magnitude: |x|**2 for complex samples, plain x**2 for real ones.
    squared = torch.abs(signal) ** 2 if torch.is_complex(signal) else signal**2
    return torch.mean(squared, dim=dim, keepdim=keepdim)