Source code for kaira.constraints.signal
"""Signal characteristic constraints for communication systems.
This module provides constraints related to signal characteristics such as amplitude limitations
and spectral properties. These constraints are essential for ensuring that transmitted signals
comply with hardware limitations and regulatory requirements :cite:`han2005overview` :cite:`armstrong2002peak`.
"""
import torch
from .base import BaseConstraint
from .registry import ConstraintRegistry
[docs]
@ConstraintRegistry.register_constraint()
class PeakAmplitudeConstraint(BaseConstraint):
"""Enforces maximum signal amplitude by clipping values that exceed threshold.
Limits the maximum amplitude of the signal to prevent clipping in digital-to-analog
converters (DACs) and power amplifiers. This constraint applies a hard clipping
operation to ensure signal values remain within the specified bounds.
Peak amplitude constraints are critical for practical communication systems as discussed
in :cite:`armstrong2002peak` and :cite:`jiang2008overview`.
Attributes:
max_amplitude (float): Maximum allowed amplitude value
"""
[docs]
def __init__(self, max_amplitude: float, *args, **kwargs) -> None:
"""Initialize the peak amplitude constraint.
Args:
max_amplitude (float): Maximum allowed amplitude. Signal values exceeding
this threshold (positive or negative) will be clipped.
*args: Variable length argument list.
**kwargs: Arbitrary keyword arguments.
"""
super().__init__(*args, **kwargs)
self.max_amplitude = max_amplitude
[docs]
def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
"""Apply peak amplitude constraint.
Clips the input signal to ensure all values fall within the range
[-max_amplitude, max_amplitude].
Args:
x (torch.Tensor): Input tensor of any shape
*args: Variable length argument list.
**kwargs: Arbitrary keyword arguments.
Returns:
torch.Tensor: Amplitude-constrained signal with the same shape as input
"""
# Simple clipping approach
return torch.clamp(x, -self.max_amplitude, self.max_amplitude)
[docs]
@ConstraintRegistry.register_constraint()
class SpectralMaskConstraint(BaseConstraint):
"""Restricts signal frequency components to comply with regulatory spectral masks.
Ensures the signal's spectrum complies with regulatory requirements by limiting
the power spectral density at specific frequencies. This is particularly important
for preventing interference with adjacent channels or frequency bands
:cite:`weiss2004spectrum` :cite:`fcc2002revision`.
The constraint works in the frequency domain by applying a scaling operation
to frequency components that exceed the mask.
Attributes:
mask (torch.Tensor): Spectral mask defining maximum power per frequency bin
"""
[docs]
def __init__(self, mask: torch.Tensor, *args, **kwargs) -> None:
"""Initialize the spectral mask constraint.
Args:
mask (torch.Tensor): Spectral mask tensor defining maximum power per frequency bin.
The shape of this tensor should match the last dimension of the input signal
after FFT transformation.
*args: Variable length argument list.
**kwargs: Arbitrary keyword arguments.
"""
super().__init__(*args, **kwargs)
self.register_buffer("mask", mask)
[docs]
def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
"""Apply spectral mask constraint.
Transforms the signal to the frequency domain, applies the spectral mask by
scaling frequency components that exceed the mask, then transforms back to
the time domain.
Args:
x (torch.Tensor): Input tensor in time domain
*args: Variable length argument list.
**kwargs: Arbitrary keyword arguments.
Returns:
torch.Tensor: Spectral mask constrained signal in time domain with the
same shape as the input
Note:
This operation preserves the signal phase while scaling the magnitude
to comply with the mask.
"""
x_freq = torch.fft.fft(x, dim=-1)
# Calculate power in frequency domain
power_spectrum = torch.abs(x_freq) ** 2
# Apply mask by scaling where needed
excess_indices = power_spectrum > self.mask.expand_as(power_spectrum)
if torch.any(excess_indices):
# Scale frequency components to meet the mask
scale_factor = torch.sqrt(self.mask / (power_spectrum + 1e-8))
scale_factor = torch.where(excess_indices, scale_factor, torch.ones_like(scale_factor))
x_freq = x_freq * scale_factor
# Convert back to time domain
return torch.fft.ifft(x_freq, dim=-1).real