kaira.constraints.AveragePowerConstraint

Inheritance diagram for AveragePowerConstraint

class kaira.constraints.AveragePowerConstraint(average_power: float, *args, **kwargs)

Bases: BaseConstraint

Scales a signal to achieve a specified average power per sample.

This module applies a constraint on the average power of the input tensor: it rescales the input so that its average power (power per sample) equals a specified target. Average power constraints are essential in communications systems for meeting regulatory requirements and optimizing the signal-to-noise ratio [Goldsmith, 2005] [Proakis and Salehi, 2007].

Unlike TotalPowerConstraint, which constrains the sum of the power across all samples, this constraint targets the average power per sample. It handles both real and complex signals automatically, applying the appropriate power scaling for complex signals.
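The difference between the two quantities is a single normalization: average power is total power divided by the number of samples. The plain-PyTorch sketch below computes both for a small complex batch. It does not call the Kaira classes, and the tensor shape and the choice to reduce over the non-batch dimension are illustrative assumptions.

```python
import torch

# Illustrative batch: 2 examples, each with 8 complex samples.
x = torch.randn(2, 8, dtype=torch.complex64)

# Per-sample power is the squared magnitude |x_i|^2 (valid for real and complex).
sample_power = x.abs() ** 2

# Total power: the sum over the sample dimension (what a total-power limit targets).
total_power = sample_power.sum(dim=1)

# Average power: the same sum divided by the number of samples per example.
avg_power = sample_power.mean(dim=1)

print(total_power / x.shape[1])
print(avg_power)
```

Both printed tensors hold the same values, one per example.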

average_power

The target average power per sample

Type:

float

power_avg_factor

Precomputed square root of the average power, cached for efficiency

Type:

torch.Tensor
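A square root is cached because power scales with the square of amplitude: multiplying a signal of unit average power by sqrt(P) yields average power P. The check below assumes the factor is applied as a plain multiplicative amplitude gain. That mirrors the attribute description and is not taken from Kaira's source.

```python
import torch

average_power = 4.0  # illustrative target in linear units
power_avg_factor = torch.sqrt(torch.tensor(average_power))  # amplitude gain of 2.0

# A signal with unit average power: every sample has magnitude 1.
unit = torch.ones(3, 16)

scaled = unit * power_avg_factor
# Power grows with the square of the gain, so the mean power equals the target.
print(scaled.pow(2).mean())
```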

Methods

__init__

Initialize the AveragePowerConstraint module.

forward

Apply the average power constraint to the input tensor.

get_dimensions

Helper method to get all dimensions except batch for calculating norms/means.

__init__(average_power: float, *args, **kwargs) → None

Initialize the AveragePowerConstraint module.

Parameters:
  • average_power (float) – The target average power per sample in linear units (not dB). The constraint will scale the signal to achieve exactly this average power level for both real and complex signals.

  • *args – Variable length argument list.

  • **kwargs – Arbitrary keyword arguments.
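Because `average_power` is expected in linear units, a target specified in decibels must be converted first using P = 10^(dB/10). `db_to_linear` below is a hypothetical helper, not part of Kaira.

```python
def db_to_linear(power_db: float) -> float:
    """Convert a power level in dB to linear units: P = 10 ** (dB / 10)."""
    return 10.0 ** (power_db / 10.0)


print(db_to_linear(0.0))   # 0 dB corresponds to 1.0
print(db_to_linear(10.0))  # 10 dB corresponds to 10.0
print(db_to_linear(3.0))   # 3 dB is roughly 1.995
```

The result is what you would pass to the documented constructor, for example `AveragePowerConstraint(average_power=db_to_linear(3.0))`.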

forward(x: Tensor, *args, **kwargs) → Tensor

Apply the average power constraint to the input tensor.

Normalizes the input tensor to have exactly the specified average power. Automatically handles both real and complex-valued inputs.

Parameters:
  • x (torch.Tensor) – The input tensor of any shape (real or complex)

  • *args – Variable length argument list.

  • **kwargs – Arbitrary keyword arguments.

Returns:

The scaled tensor, with the same shape as the input, adjusted to have exactly the target average power

Return type:

torch.Tensor

Note

The power is calculated across all dimensions. For complex signals, the power of each sample accounts for both the real and imaginary components, so the target average power is shared between them. A small epsilon (1e-8) is added to the denominator to prevent division by zero.
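The normalization described above can be sketched in plain PyTorch. `normalize_average_power` is a hypothetical function, not Kaira's implementation. It assumes the power is reduced over every dimension except the batch dimension (the default of `get_dimensions`) and that epsilon sits inside the square root's denominator. If the library reduces over every dimension, as the Note's wording suggests, replace `dims` with `tuple(range(x.dim()))`.

```python
import torch


def normalize_average_power(
    x: torch.Tensor, average_power: float, eps: float = 1e-8
) -> torch.Tensor:
    """Hypothetical sketch: rescale each batch example to `average_power`."""
    # Assumption: reduce over every dimension except dim 0 (the batch).
    dims = tuple(range(1, x.dim()))
    # x.abs() ** 2 equals x**2 for real inputs and re**2 + im**2 for complex ones,
    # so one expression covers both cases.
    current = torch.mean(x.abs() ** 2, dim=dims, keepdim=True)
    # eps keeps the denominator positive, so an all-zero input stays finite.
    scale = torch.sqrt(average_power / (current + eps))
    # A real-valued scale preserves the phase of complex samples.
    return x * scale


real = torch.randn(4, 2, 64)
cplx = torch.randn(4, 64, dtype=torch.complex64)

y_real = normalize_average_power(real, 2.0)
y_cplx = normalize_average_power(cplx, 2.0)

print(y_real.pow(2).mean(dim=(1, 2)))  # each entry close to 2.0
print((y_cplx.abs() ** 2).mean(dim=1))  # each entry close to 2.0
print(normalize_average_power(torch.zeros(1, 8), 1.0))  # zeros, no NaN
```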

static get_dimensions(x: Tensor, exclude_batch: bool = True) → Tuple[int, ...]

Helper method to get all dimensions except batch for calculating norms/means.

Utility function to generate dimension indices for reduction operations like mean or norm. Typically used to calculate signal properties across all dimensions except the batch dimension.

Parameters:
  • x (torch.Tensor) – Input tensor

  • exclude_batch (bool, optional) – Whether to exclude the batch dimension (first dimension). Defaults to True.

Returns:

Dimensions to use for reduction operations (e.g., mean, norm)

Return type:

Tuple[int, …]

Example

>>> import torch
>>> from kaira.constraints import AveragePowerConstraint
>>> x = torch.randn(32, 4, 128)  # [batch, antennas, time]
>>> dims = AveragePowerConstraint.get_dimensions(x)
>>> # dims will be (1, 2), for reducing across antennas and time
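The behavior described above fits in a couple of lines. The hypothetical stand-alone `get_dimensions` below is a sketch consistent with the documented signature and example output, not Kaira's source. It also shows the tuple being passed straight to a reduction.

```python
from typing import Tuple

import torch


def get_dimensions(x: torch.Tensor, exclude_batch: bool = True) -> Tuple[int, ...]:
    """Hypothetical re-implementation of the documented helper."""
    start = 1 if exclude_batch else 0  # skip dim 0 when it is the batch dimension
    return tuple(range(start, x.dim()))


x = torch.randn(32, 4, 128)  # [batch, antennas, time]
print(get_dimensions(x))                       # (1, 2)
print(get_dimensions(x, exclude_batch=False))  # (0, 1, 2)

# The tuple plugs directly into reductions such as Tensor.mean.
per_example_power = x.pow(2).mean(dim=get_dimensions(x))
print(per_example_power.shape)  # torch.Size([32])
```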