kaira.constraints.CompositeConstraint

class kaira.constraints.CompositeConstraint(constraints: Sequence[BaseConstraint] | ModuleList, *args, **kwargs)

Bases: BaseConstraint

Applies multiple constraints in sequence.

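This constraint combines several constraints and applies them one after another to the input tensor, with each constraint receiving the output of the previous one. This makes it possible to build more complex requirements from simple parts, for example enforcing both a power constraint and a spectral constraint on the same signal.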

constraints

List of constraint modules to apply in sequence

Type:

nn.ModuleList

Example

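>>> import torch
>>> from kaira.constraints import CompositeConstraint, PAPRConstraint, TotalPowerConstraint
>>> input_signal = torch.randn(8, 64)  # [batch, symbols]
>>> power_constraint = TotalPowerConstraint(1.0)
>>> papr_constraint = PAPRConstraint(4.0)
>>> combined = CompositeConstraint([power_constraint, papr_constraint])
>>> constrained_signal = combined(input_signal)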

Note

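When a composite constraint is applied, its component constraints run in the order in which they were provided. This ordering can significantly affect the final result: a later constraint may partially undo the effect of an earlier one, so the output is not guaranteed to satisfy every component constraint simultaneously.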
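To illustrate why order matters, here is a small NumPy example. It uses two hypothetical stand-in constraints, not the kaira implementations: `normalize_power` rescales each sample to unit average power, and `clip_amplitude` clips every value to a fixed magnitude. Running them in the two possible orders yields different signals, and each order exactly satisfies only the constraint applied last.

```python
import numpy as np

def normalize_power(x, target=1.0):
    # Hypothetical stand-in: rescale each row (sample) to the target average power.
    power = np.mean(x ** 2, axis=1, keepdims=True)
    return x * np.sqrt(target / power)

def clip_amplitude(x, max_amp=1.0):
    # Hypothetical stand-in: clip every value to [-max_amp, max_amp].
    return np.clip(x, -max_amp, max_amp)

x = np.array([[3.0, 0.5, -0.5, 0.1]])  # one sample with a large peak

power_then_clip = clip_amplitude(normalize_power(x))
clip_then_power = normalize_power(clip_amplitude(x))

# Clipping last keeps the peaks bounded but drops the average power below 1 ...
print(np.mean(power_then_clip ** 2), np.max(np.abs(power_then_clip)))
# ... while normalizing last restores unit power but lets the peak exceed 1 again.
print(np.mean(clip_then_power ** 2), np.max(np.abs(clip_then_power)))
```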

Methods

__init__

Initialize the composite constraint with a list of constraints.

add_constraint

Add a new constraint to the composite.

forward

Apply the composite constraint to the input signal.

get_dimensions

Helper method to get all dimensions except batch for calculating norms/means.

__init__(constraints: Sequence[BaseConstraint] | ModuleList, *args, **kwargs) → None

Initialize the composite constraint with a list of constraints.

Parameters:
  • constraints (Sequence[BaseConstraint] | nn.ModuleList) – List of constraint modules to apply in sequence

  • *args – Variable length argument list.

  • **kwargs – Arbitrary keyword arguments.

Raises:

TypeError – If any element in constraints is not a BaseConstraint

add_constraint(constraint: BaseConstraint) → None

Add a new constraint to the composite.

Parameters:

constraint (BaseConstraint) – New constraint to add to the sequence
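The constructor's type check and `add_constraint` can be sketched framework-free as follows. `MiniComposite` and `BaseConstraintStub` are hypothetical names, not the kaira source: the stub stands in for `BaseConstraint`, the constraints live in a plain list rather than an `nn.ModuleList`, and the sketch assumes that `add_constraint` appends to the end of the sequence and performs the same check as `__init__` (the documentation above does not say whether it validates its argument).

```python
from typing import Iterable, List

class BaseConstraintStub:
    """Hypothetical placeholder for kaira's BaseConstraint."""

    def __call__(self, x):
        return x

class MiniComposite:
    """Hypothetical sketch of CompositeConstraint's bookkeeping."""

    def __init__(self, constraints: Iterable[BaseConstraintStub]):
        self.constraints: List[BaseConstraintStub] = []
        for c in constraints:
            self.add_constraint(c)  # reuse one validation path

    def add_constraint(self, constraint: BaseConstraintStub) -> None:
        # Reject non-constraints, mirroring the documented TypeError.
        if not isinstance(constraint, BaseConstraintStub):
            raise TypeError(
                f"Expected a BaseConstraint, got {type(constraint).__name__}"
            )
        # Assumed behavior: a newly added constraint runs after the existing ones.
        self.constraints.append(constraint)
```

Routing the constructor through `add_constraint` keeps a single place where the type check lives, so `MiniComposite([BaseConstraintStub(), 42])` fails with `TypeError` just as a later `add_constraint(42)` would.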

forward(x, *args, **kwargs)

Apply the composite constraint to the input signal.

Parameters:
  • x – Input signal to constrain

  • *args – Variable length argument list.

  • **kwargs – Arbitrary keyword arguments.

Returns:

Constrained signal after applying all component constraints
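Conceptually, `forward` threads the signal through each component in order, which makes the composite equivalent to function composition. Below is a minimal sketch. `apply_in_sequence`, `scale`, and `cap` are hypothetical names. The sketch assumes each component is a callable accepting `(x, *args, **kwargs)` and that the extra arguments are passed unchanged to every component; the documentation above does not state how `*args` and `**kwargs` are used.

```python
from functools import reduce
from typing import Any, Callable, Sequence

def apply_in_sequence(x: Any, constraints: Sequence[Callable[..., Any]], *args, **kwargs) -> Any:
    """Hypothetical equivalent of CompositeConstraint.forward."""
    # Each constraint receives the output of the previous one.
    return reduce(lambda acc, c: c(acc, *args, **kwargs), constraints, x)

# Toy "constraints" on plain floats, for illustration only.
def scale(x, *args, **kwargs):
    return 2.0 * x

def cap(x, *args, **kwargs):
    return min(x, 3.0)

print(apply_in_sequence(2.0, [scale, cap]))  # cap(scale(2.0))
print(apply_in_sequence(2.0, [cap, scale]))  # scale(cap(2.0))
```

With an empty sequence this sketch returns the input unchanged; whether the real class accepts an empty list is not stated in this documentation.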

static get_dimensions(x: Tensor, exclude_batch: bool = True) → Tuple[int, ...]

Helper method to get all dimensions except batch for calculating norms/means.

Utility function to generate dimension indices for reduction operations like mean or norm. Typically used to calculate signal properties across all dimensions except the batch dimension.

Parameters:
  • x (torch.Tensor) – Input tensor

  • exclude_batch (bool, optional) – Whether to exclude the batch dimension (first dimension). Defaults to True.

Returns:

Dimensions to use for reduction operations (e.g., mean, norm)

Return type:

Tuple[int, …]

Example

>>> import torch
>>> from kaira.constraints import BaseConstraint
>>> x = torch.randn(32, 4, 128)  # [batch, antennas, time]
>>> BaseConstraint.get_dimensions(x)  # reduce over antennas and time
(1, 2)
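A typical use is computing per-sample statistics such as average power. The NumPy sketch below uses `get_dimensions_np`, a hypothetical stand-in rather than the kaira implementation. It assumes the method returns `tuple(range(1, x.ndim))` when `exclude_batch=True` (consistent with the `(1, 2)` result documented above) and all dimensions otherwise.

```python
import numpy as np

def get_dimensions_np(x: np.ndarray, exclude_batch: bool = True) -> tuple:
    # Hypothetical stand-in for BaseConstraint.get_dimensions:
    # skip axis 0 (batch) when exclude_batch is True.
    start = 1 if exclude_batch else 0
    return tuple(range(start, x.ndim))

x = np.random.default_rng(0).standard_normal((32, 4, 128))  # [batch, antennas, time]
dims = get_dimensions_np(x)

# Average power of each batch element, reduced over antennas and time.
per_sample_power = np.mean(np.abs(x) ** 2, axis=dims)
print(dims, per_sample_power.shape)
```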