kaira.constraints.SpectralMaskConstraint

- class kaira.constraints.SpectralMaskConstraint(mask: Tensor, *args, **kwargs)
Bases: BaseConstraint
Restricts signal frequency components to comply with regulatory spectral masks.
Ensures the signal’s spectrum complies with regulatory requirements by limiting the power spectral density at specific frequencies. This is particularly important for preventing interference with adjacent channels or frequency bands [Weiss and Jondral, 2004] [Federal Communications Commission, 2002].
The constraint operates in the frequency domain: frequency components whose power exceeds the mask are scaled down so that they comply with it.
- mask
Spectral mask defining the maximum power per frequency bin
- Type:
torch.Tensor
Methods
__init__(mask, *args, **kwargs): Initialize the spectral mask constraint.
forward(x, *args, **kwargs): Apply spectral mask constraint.
get_dimensions(x[, exclude_batch]): Helper method to get all dimensions except batch for calculating norms/means.
- __init__(mask: Tensor, *args, **kwargs) → None
Initialize the spectral mask constraint.
- Parameters:
mask (torch.Tensor) – Spectral mask tensor defining maximum power per frequency bin. The shape of this tensor should match the last dimension of the input signal after FFT transformation.
*args – Variable length argument list.
**kwargs – Arbitrary keyword arguments.
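A mask of the required shape could be built as follows. This is a minimal sketch, assuming an FFT length of 128 over the last dimension and the standard torch.fft.fft bin order; the cutoff and the power levels are illustrative values, not taken from kaira.

```python
import torch

n_fft = 128  # assumed length of the signal's last dimension

# Normalized bin frequencies in cycles/sample, ordered like torch.fft.fft output
freqs = torch.fft.fftfreq(n_fft)

# Illustrative limits: tight power limit everywhere, relaxed inside the passband
mask = torch.full((n_fft,), 1e-2)
mask[freqs.abs() <= 0.1] = 10.0
```

The resulting tensor holds one power limit per FFT bin and is symmetric about DC, which suits real-valued signals.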
- forward(x: Tensor, *args, **kwargs) → Tensor
Apply spectral mask constraint.
Transforms the signal to the frequency domain, applies the spectral mask by scaling frequency components that exceed the mask, then transforms back to the time domain.
- Parameters:
x (torch.Tensor) – Input tensor in time domain
*args – Variable length argument list.
**kwargs – Arbitrary keyword arguments.
- Returns:
Spectral-mask-constrained signal in the time domain, with the same shape as the input
- Return type:
torch.Tensor
Note
This operation preserves the signal phase while scaling the magnitude to comply with the mask.
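The steps described for forward can be sketched as a standalone function. This is not kaira's implementation: apply_spectral_mask is a hypothetical name, and the sketch assumes that power per bin means the squared magnitude of an unnormalized FFT over the last dimension.

```python
import torch

def apply_spectral_mask(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch of a spectral mask constraint (not the library code)."""
    spectrum = torch.fft.fft(x, dim=-1)
    power = spectrum.abs() ** 2  # assumed definition of per-bin power

    # Real, non-negative scale factor: shrink only the bins that exceed the mask.
    # Multiplying by a real positive factor leaves each bin's phase unchanged.
    safe_power = torch.clamp(power, min=1e-12)
    scale = torch.where(power > mask,
                        torch.sqrt(mask / safe_power),
                        torch.ones_like(power))

    constrained = torch.fft.ifft(spectrum * scale, dim=-1)
    # Keep complex inputs complex; return the real part for real inputs
    return constrained if x.is_complex() else constrained.real

torch.manual_seed(0)
x = torch.randn(2, 64, dtype=torch.float64)
mask = torch.full((64,), 4.0, dtype=torch.float64)
y = apply_spectral_mask(x, mask)
```

For real inputs, taking the real part of the inverse FFT is exact only when the mask is symmetric about DC, as the constant mask here is.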
- static get_dimensions(x: Tensor, exclude_batch: bool = True) → Tuple[int, ...]
Helper method to get all dimensions except batch for calculating norms/means.
Utility function to generate dimension indices for reduction operations like mean or norm. Typically used to calculate signal properties across all dimensions except the batch dimension.
- Parameters:
x (torch.Tensor) – Input tensor
exclude_batch (bool, optional) – Whether to exclude the batch dimension (first dimension). Defaults to True.
- Returns:
Dimensions to use for reduction operations (e.g., mean, norm)
- Return type:
Tuple[int, …]
Example
>>> x = torch.randn(32, 4, 128)  # [batch, antennas, time]
>>> dims = BaseConstraint.get_dimensions(x)
>>> # dims will be (1, 2) for summing across antennas and time
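The returned tuple is intended to be passed as the dim argument of a reduction. The sketch below uses a hypothetical stand-in, reduction_dims, that mimics the documented behavior without importing kaira; returning every dimension index when exclude_batch is False is an assumption.

```python
from typing import Tuple

import torch

def reduction_dims(x: torch.Tensor, exclude_batch: bool = True) -> Tuple[int, ...]:
    """Hypothetical stand-in for get_dimensions (not the library code)."""
    start = 1 if exclude_batch else 0
    return tuple(range(start, x.dim()))

x = torch.randn(32, 4, 128)  # [batch, antennas, time]
dims = reduction_dims(x)

# Average power per batch element, reduced over antennas and time
power_per_sample = x.pow(2).mean(dim=dims)
```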