kaira.constraints.TotalPowerConstraint

Inheritance diagram for TotalPowerConstraint

class kaira.constraints.TotalPowerConstraint(total_power: float, *args, **kwargs)

Bases: BaseConstraint

Normalizes signal to achieve exact total power regardless of input signal power.

This module constrains the total power of the input tensor by rescaling it to a specified target [Wunder et al., 2013].

The output power matches the target exactly rather than merely staying below it, whatever the power of the input. Complex signals are detected automatically and scaled so that the target power is shared between the real and imaginary components.

total_power

The target total power, in linear units (not dB)

Type:

float

total_power_factor

Precomputed square root of total power for efficiency

Type:

torch.Tensor

Methods

__init__

Initialize the TotalPowerConstraint module.

forward

Apply the total power constraint to the input tensor.

get_dimensions

Helper method to get all dimensions except batch for calculating norms/means.

__init__(total_power: float, *args, **kwargs) → None

Initialize the TotalPowerConstraint module.

Parameters:
  • total_power (float) – The target total power for the signal in linear units (not dB). The constraint will scale the signal to achieve exactly this power level for both real and complex signals.

  • *args – Variable length argument list.

  • **kwargs – Arbitrary keyword arguments.
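Because total_power is given in linear units, a target specified in dB has to be converted first. The snippet below is a minimal sketch of that conversion; db_to_linear is a hypothetical helper written for illustration and is not part of kaira.

```python
def db_to_linear(power_db: float) -> float:
    """Convert a power level in dB to linear units (hypothetical helper, not part of kaira)."""
    return 10.0 ** (power_db / 10.0)


target_db = 3.0
total_power = db_to_linear(target_db)  # roughly twice the 0 dB reference

# The linear value is what would be passed to the constructor, e.g.
# TotalPowerConstraint(total_power=total_power)
```

Passing the dB figure directly would silently request a different power level, so converting at the call site keeps the unit explicit.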

forward(x: Tensor, *args, **kwargs) → Tensor

Apply the total power constraint to the input tensor.

Normalizes the input tensor to have exactly the specified total power. Automatically handles both real and complex-valued inputs.

Parameters:
  • x (torch.Tensor) – The input tensor of any shape (real or complex)

  • *args – Variable length argument list.

  • **kwargs – Arbitrary keyword arguments.

Returns:

The scaled tensor with the same shape as input, adjusted to have exactly the target total power

Return type:

torch.Tensor

Note

The power is calculated across all dimensions except the batch dimension. For complex signals, power is distributed between real and imaginary components. A small epsilon (1e-8) is added to the denominator to prevent division by zero.
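The behaviour described in this note can be sketched in plain PyTorch. This is an illustration and not kaira's actual code: scale_to_total_power is a hypothetical function, and it assumes that "total power" means the sum of squared magnitudes over the non-batch dimensions and that the epsilon is added to the measured power before the square root. The library may differ on both points.

```python
import torch


def scale_to_total_power(x: torch.Tensor, total_power: float, eps: float = 1e-8) -> torch.Tensor:
    """Hypothetical stand-in for the constraint's forward pass (not kaira's code)."""
    # Reduce over every dimension except the batch dimension (dim 0).
    # Assumes x has a batch dimension plus at least one more dimension.
    dims = tuple(range(1, x.dim()))
    # abs() returns the magnitude for complex input and the absolute value
    # for real input, so one expression covers both cases.
    current_power = x.abs().pow(2).sum(dim=dims, keepdim=True)
    # Assumption: eps guards the division when the input is all zeros.
    scale = (total_power / (current_power + eps)).sqrt()
    return x * scale


x_real = torch.randn(4, 2, 64)
x_cplx = torch.randn(4, 2, 64, dtype=torch.cfloat)

y_real = scale_to_total_power(x_real, total_power=2.0)
y_cplx = scale_to_total_power(x_cplx, total_power=2.0)

# Per-sample power after scaling, summed over the non-batch dimensions.
power_real = y_real.pow(2).sum(dim=(1, 2))
power_cplx = y_cplx.abs().pow(2).sum(dim=(1, 2))
```

Under these assumptions each sample in the batch is scaled independently, so every entry of power_real and power_cplx lands at the target value up to floating-point error.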

static get_dimensions(x: Tensor, exclude_batch: bool = True) → Tuple[int, ...]

Helper method to get all dimensions except batch for calculating norms/means.

Utility function to generate dimension indices for reduction operations like mean or norm. Typically used to calculate signal properties across all dimensions except the batch dimension.

Parameters:
  • x (torch.Tensor) – Input tensor

  • exclude_batch (bool, optional) – Whether to exclude the batch dimension (first dimension). Defaults to True.

Returns:

Dimensions to use for reduction operations (e.g., mean, norm)

Return type:

Tuple[int, ...]

Example

>>> import torch
>>> from kaira.constraints import BaseConstraint
>>> x = torch.randn(32, 4, 128)  # [batch, antennas, time]
>>> dims = BaseConstraint.get_dimensions(x)
>>> # dims will be (1, 2) for summing across antennas and time
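To show how the returned tuple is typically consumed, here is a minimal sketch that feeds it into a reduction. reduction_dims is a hypothetical re-implementation written for illustration; in particular, its behaviour for exclude_batch=False (returning every dimension) is an assumption and not something this page confirms.

```python
import torch


def reduction_dims(x: torch.Tensor, exclude_batch: bool = True) -> tuple:
    """Hypothetical re-implementation for illustration (not kaira's code)."""
    # Skip dim 0 when the batch dimension should be kept out of the reduction.
    start = 1 if exclude_batch else 0
    return tuple(range(start, x.dim()))


x = torch.randn(32, 4, 128)  # [batch, antennas, time]
dims = reduction_dims(x)

# Summing squared values over dims leaves one power value per batch element.
per_sample_power = x.pow(2).sum(dim=dims)
```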