kaira.constraints.AveragePowerConstraint

- class kaira.constraints.AveragePowerConstraint(average_power: float, *args, **kwargs)
Bases: BaseConstraint
Scales signal to achieve specified average power per sample.
This module constrains the average power of the input tensor by rescaling it so that the average power (power per sample) matches the specified target. Average power constraints are essential in communications systems for meeting regulatory requirements and optimizing signal-to-noise ratio [Goldsmith, 2005] [Proakis and Salehi, 2007].
Unlike TotalPowerConstraint, which constrains the sum of power across all samples, this constraint targets the average power per sample. It handles both real and complex signals automatically, applying the appropriate power scaling for complex signals.
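To make the distinction concrete, the following is a minimal sketch in plain Python (no PyTorch or kaira required) of the two quantities being contrasted. The helper names `total_power_of` and `average_power_of` are hypothetical and not part of kaira; power is taken as |x|^2 per sample, which covers real and complex values alike.

```python
def total_power_of(samples):
    """Sum of |x|^2 over all samples (hypothetical helper, not kaira API)."""
    return sum(abs(s) ** 2 for s in samples)


def average_power_of(samples):
    """Mean of |x|^2 over all samples, i.e. power per sample."""
    return total_power_of(samples) / len(samples)


real_signal = [1.0, -2.0, 3.0, -4.0]
complex_signal = [3 + 4j, 0j]

real_total = total_power_of(real_signal)            # 1 + 4 + 9 + 16 = 30
real_average = average_power_of(real_signal)        # 30 / 4 = 7.5
complex_average = average_power_of(complex_signal)  # (25 + 0) / 2 = 12.5
```

Because the total is the average multiplied by the number of samples, a fixed average-power target lets the total power grow with signal length, whereas a fixed total-power target does not.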
- power_avg_factor
Precomputed square root of average power for efficiency
- Type:
Methods
__init__ – Initialize the AveragePowerConstraint module.
forward – Apply the average power constraint to the input tensor.
get_dimensions – Helper method to get all dimensions except batch for calculating norms/means.
- __init__(average_power: float, *args, **kwargs) -> None
Initialize the AveragePowerConstraint module.
- Parameters:
average_power (float) – The target average power per sample in linear units (not dB). The constraint will scale the signal to achieve exactly this average power level for both real and complex signals.
*args – Variable length argument list.
**kwargs – Arbitrary keyword arguments.
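Since `average_power` is expected in linear units, a target specified in dB has to be converted first. The sketch below shows that conversion and the square root that the `power_avg_factor` attribute is described as holding. `db_to_linear` is a hypothetical helper introduced for illustration, not a kaira function, and the 3 dB target is an arbitrary example value.

```python
import math


def db_to_linear(power_db):
    """Convert a power ratio in dB to linear units (hypothetical helper)."""
    return 10.0 ** (power_db / 10.0)


target_db = 3.0                              # arbitrary example target
average_power = db_to_linear(target_db)      # roughly 1.995 in linear units
power_avg_factor = math.sqrt(average_power)  # amplitude scale for that power
```

The linear value is what would be passed as `average_power`; passing the dB figure directly would request a different power level than intended.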
- forward(x: Tensor, *args, **kwargs) -> Tensor
Apply the average power constraint to the input tensor.
Normalizes the input tensor to have exactly the specified average power. Automatically handles both real and complex-valued inputs.
- Parameters:
x (torch.Tensor) – The input tensor of any shape (real or complex)
*args – Variable length argument list.
**kwargs – Arbitrary keyword arguments.
- Returns:
The scaled tensor with the same shape as input, adjusted to have exactly the target average power
- Return type:
torch.Tensor
Note
The power is calculated across all dimensions. For complex signals, power is distributed between real and imaginary components. A small epsilon (1e-8) is added to the denominator to prevent division by zero.
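The normalization described in this note can be sketched in plain Python as follows. This is an illustration under stated assumptions, not the library's implementation: the function name `scale_to_average_power` is hypothetical, the input is a flat list standing in for a tensor, power is measured as the mean of |x|^2 over every sample, and epsilon is assumed to be added to the measured RMS value in the denominator (the library may place it differently).

```python
import math

EPS = 1e-8  # epsilon value quoted in the note


def scale_to_average_power(samples, average_power, eps=EPS):
    """Rescale samples so that mean(|x|^2) is approximately average_power.

    Assumption: eps is added to the measured RMS in the denominator.
    """
    power_avg_factor = math.sqrt(average_power)
    measured = sum(abs(s) ** 2 for s in samples) / len(samples)
    scale = power_avg_factor / (math.sqrt(measured) + eps)
    return [s * scale for s in samples]


real_out = scale_to_average_power([1.0, -2.0, 3.0, -4.0], average_power=1.0)
complex_out = scale_to_average_power([1 + 1j, 2 - 2j, -3 + 0.5j], average_power=2.0)
# An all-zero input stays zero instead of raising ZeroDivisionError.
zero_out = scale_to_average_power([0.0, 0.0], average_power=1.0)
```

Using |x|^2 means a complex sample contributes the combined power of its real and imaginary parts, so the same scaling rule serves both signal types. With a nonzero epsilon the result is marginally below the target rather than bit-exact.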
- static get_dimensions(x: Tensor, exclude_batch: bool = True) -> Tuple[int, ...]
Helper method to get all dimensions except batch for calculating norms/means.
Utility function to generate dimension indices for reduction operations like mean or norm. Typically used to calculate signal properties across all dimensions except the batch dimension.
- Parameters:
x (torch.Tensor) – Input tensor
exclude_batch (bool, optional) – Whether to exclude the batch dimension (first dimension). Defaults to True.
- Returns:
Dimensions to use for reduction operations (e.g., mean, norm)
- Return type:
Tuple[int, ...]
Example
>>> x = torch.randn(32, 4, 128)  # [batch, antennas, time]
>>> dims = BaseConstraint.get_dimensions(x)
>>> # dims will be (1, 2) for summing across antennas and time
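The index computation behind this helper can be sketched without PyTorch by working on a shape tuple directly. `reduction_dims` is a hypothetical stand-in, not the kaira method. The result for `exclude_batch=False` is an assumption, since the documentation only states the default case.

```python
def reduction_dims(shape, exclude_batch=True):
    """Return the dimension indices to reduce over for a given shape.

    Assumption: with exclude_batch=False every dimension is included.
    """
    start = 1 if exclude_batch else 0
    return tuple(range(start, len(shape)))


shape = (32, 4, 128)  # [batch, antennas, time], as in the example above
dims = reduction_dims(shape)                           # (1, 2)
all_dims = reduction_dims(shape, exclude_batch=False)  # (0, 1, 2)
```

A tuple like this is the kind of value that reduction operations such as a mean or a norm accept as their dimension argument.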