kaira.constraints.SpectralMaskConstraint

Inheritance diagram for SpectralMaskConstraint

class kaira.constraints.SpectralMaskConstraint(mask: Tensor, *args, **kwargs)

Bases: BaseConstraint

Restricts signal frequency components to comply with regulatory spectral masks.

Ensures the signal’s spectrum complies with regulatory requirements by limiting the power spectral density at specific frequencies. This is particularly important for preventing interference with adjacent channels or frequency bands [Weiss and Jondral, 2004] [Federal Communications Commission, 2002].

The constraint works in the frequency domain by applying a scaling operation to frequency components that exceed the mask.
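One plausible reading of this scaling, written per FFT bin k of an N-sample signal x, is sketched below. The symbols X, P, g and M are introduced here for illustration only, and the sketch assumes the mask bounds the power of the unnormalized DFT with a square-root amplitude gain; the library's exact normalization may differ. A bin that respects the mask is left untouched, and a bin that exceeds it is scaled so its power lands exactly on the mask.

```latex
\begin{aligned}
X[k] &= \sum_{n=0}^{N-1} x[n]\, e^{-j 2\pi k n / N}, \qquad P[k] = |X[k]|^2, \\
g[k] &= \begin{cases} 1, & P[k] \le M[k], \\ \sqrt{M[k] / P[k]}, & P[k] > M[k], \end{cases}
\qquad X'[k] = g[k]\, X[k], \\
|X'[k]|^2 &= g[k]^2\, P[k] = \min\big(P[k], M[k]\big) \le M[k], \qquad
x'[n] = \frac{1}{N} \sum_{k=0}^{N-1} X'[k]\, e^{j 2\pi k n / N}.
\end{aligned}
```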

mask

Spectral mask defining the maximum power per frequency bin.

Type:

torch.Tensor

Methods

  • __init__ – Initialize the spectral mask constraint.

  • forward – Apply spectral mask constraint.

  • get_dimensions – Helper method to get all dimensions except batch for calculating norms/means.

__init__(mask: Tensor, *args, **kwargs) -> None

Initialize the spectral mask constraint.

Parameters:
  • mask (torch.Tensor) – Spectral mask tensor defining the maximum power allowed in each frequency bin. Its shape should match the last dimension of the FFT-transformed input signal.

  • *args – Variable length argument list.

  • **kwargs – Arbitrary keyword arguments.
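If the mask is applied to the full FFT of a real-valued signal, keeping it symmetric across positive and negative frequencies (M[k] = M[N - k]) lets the scaled spectrum stay conjugate-symmetric, so the inverse FFT is real. As a minimal sketch, assuming one mask entry per full-FFT bin, the hypothetical helper `make_band_mask` below builds a two-level mask with `torch.fft.fftfreq`; the band edge and power levels are illustrative values, not regulatory limits.

```python
import torch


def make_band_mask(n_fft: int, passband: float, in_band_power: float, out_band_power: float) -> torch.Tensor:
    """Hypothetical helper (not part of kaira): one maximum-power value per FFT bin.

    Bins whose normalized frequency satisfies |f| <= passband get `in_band_power`,
    all others get `out_band_power`. Using |f| keeps the mask symmetric,
    i.e. mask[k] == mask[n_fft - k], which suits real-valued signals.
    """
    freqs = torch.fft.fftfreq(n_fft)  # bin k -> normalized frequency in [-0.5, 0.5)
    in_band = freqs.abs() <= passband
    return torch.where(
        in_band,
        torch.full((n_fft,), in_band_power),
        torch.full((n_fft,), out_band_power),
    )


# The mask length must equal the last dimension of the signal, e.g. 128 time samples.
mask = make_band_mask(n_fft=128, passband=0.25, in_band_power=4.0, out_band_power=0.01)
```

Per the constructor signature above, the resulting tensor would then be passed as `SpectralMaskConstraint(mask)` for inputs whose last dimension is 128.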

forward(x: Tensor, *args, **kwargs) -> Tensor

Apply spectral mask constraint.

Transforms the signal to the frequency domain, applies the spectral mask by scaling frequency components that exceed the mask, then transforms back to the time domain.
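The three steps above can be sketched in plain PyTorch as follows. This is a minimal illustration, not kaira's implementation: it assumes `mask` bounds |X[k]|² of the unnormalized `torch.fft.fft` output along the last dimension, and the function name `apply_spectral_mask` and the `eps` guard are hypothetical.

```python
import torch


def apply_spectral_mask(x: torch.Tensor, mask: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Hypothetical sketch of a spectral-mask constraint (not kaira's code).

    Assumes `mask` holds the maximum |X[k]|**2 allowed in each bin of the
    unnormalized FFT taken along the last dimension of `x`.
    """
    # Time domain -> frequency domain along the last dimension.
    x_freq = torch.fft.fft(x, dim=-1)
    power = x_freq.abs() ** 2

    # Per-bin real gain: 1 where the mask holds, sqrt(mask / power) where it is
    # exceeded, so each clipped bin ends up with power equal to the mask value.
    gain = torch.where(
        power > mask,
        torch.sqrt(mask / power.clamp_min(eps)),
        torch.ones_like(power),
    )

    # Frequency domain -> time domain. For real inputs keep the real part, which
    # is exact only when the mask is symmetric (mask[k] == mask[N - k]).
    y = torch.fft.ifft(x_freq * gain, dim=-1)
    return y if torch.is_complex(x) else y.real


torch.manual_seed(0)
x = torch.randn(8, 64)            # [batch, time]
mask = torch.full((64,), 10.0)    # flat (hence symmetric) mask, one value per bin
y = apply_spectral_mask(x, mask)  # same shape as x, real-valued
```

Clipping each bin independently is the simplest way to satisfy the mask; because the gain is real and non-negative, only bin magnitudes change.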

Parameters:
  • x (torch.Tensor) – Input tensor in the time domain.

  • *args – Variable length argument list.

  • **kwargs – Arbitrary keyword arguments.

Returns:

Spectral-mask-constrained signal in the time domain, with the same shape as the input.

Return type:

torch.Tensor

Note

This operation preserves the signal phase while scaling the magnitude to comply with the mask.
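The note follows from multiplying each complex bin by a real, non-negative gain. The small check below illustrates this in PyTorch; the random bins and gains are placeholders, not kaira internals.

```python
import torch

# Illustrative check, not kaira code: scaling complex FFT bins by real,
# strictly positive gains changes their magnitudes but not their phases.
torch.manual_seed(0)
x_freq = torch.fft.fft(torch.randn(4, 32, dtype=torch.complex64), dim=-1)
gain = torch.rand(4, 32) + 0.1  # hypothetical per-bin gains in [0.1, 1.1)
scaled = x_freq * gain

phase_unchanged = torch.allclose(torch.angle(scaled), torch.angle(x_freq), atol=1e-5)
magnitude_scaled = torch.allclose(scaled.abs(), x_freq.abs() * gain, rtol=1e-5)
print(phase_unchanged, magnitude_scaled)  # True True
```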

static get_dimensions(x: Tensor, exclude_batch: bool = True) -> Tuple[int, ...]

Helper method to get all dimensions except batch for calculating norms/means.

Utility function to generate dimension indices for reduction operations like mean or norm. Typically used to calculate signal properties across all dimensions except the batch dimension.

Parameters:
  • x (torch.Tensor) – Input tensor

  • exclude_batch (bool, optional) – Whether to exclude the batch dimension (first dimension). Defaults to True.

Returns:

Dimensions to use for reduction operations (e.g., mean, norm)

Return type:

Tuple[int, …]

Example

>>> import torch
>>> from kaira.constraints import BaseConstraint
>>> x = torch.randn(32, 4, 128)  # [batch, antennas, time]
>>> BaseConstraint.get_dimensions(x)  # reduce across antennas and time
(1, 2)
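The documented behaviour amounts to listing every dimension index of the tensor, optionally skipping index 0. Below is a minimal re-implementation sketch under that reading; the name `get_dimensions_sketch` is hypothetical and this is not kaira's code. It is followed by the kind of per-example reduction the result typically feeds.

```python
from typing import Tuple

import torch


def get_dimensions_sketch(x: torch.Tensor, exclude_batch: bool = True) -> Tuple[int, ...]:
    """Hypothetical re-implementation of the documented behaviour (not kaira's code).

    Returns every dimension index of `x`, skipping dimension 0 (batch) when
    `exclude_batch` is True.
    """
    start = 1 if exclude_batch else 0
    return tuple(range(start, x.dim()))


x = torch.randn(32, 4, 128)       # [batch, antennas, time]
dims = get_dimensions_sketch(x)   # (1, 2)
# Typical use: average power per example, one value per batch element.
avg_power = x.pow(2).mean(dim=dims)  # shape: (32,)
```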