kaira.constraints.CompositeConstraint

- class kaira.constraints.CompositeConstraint(constraints: Sequence[BaseConstraint] | ModuleList, *args, **kwargs)
Bases: BaseConstraint
Applies multiple constraints in sequence.
This constraint combines multiple independent constraints and applies them one after another to the input tensor, with each constraint operating on the output of the previous one. This enables more complex compositions, such as applying a power constraint and a spectral constraint together.
- constraints
List of constraint modules to apply in sequence
- Type:
nn.ModuleList
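Holding the components in an nn.ModuleList rather than a plain Python list means PyTorch registers each one as a submodule, so parameters(), state_dict() and .to(device) reach whatever parameters or buffers the constraints own. The sketch below illustrates this with hypothetical stand-in classes (ScaleConstraint, WithModuleList, WithPlainList); they are illustrative only and not part of kaira.

```python
import torch
from torch import nn


class ScaleConstraint(nn.Module):
    """Hypothetical constraint with one learnable parameter (illustration only)."""

    def __init__(self):
        super().__init__()
        self.gain = nn.Parameter(torch.ones(1))

    def forward(self, x):
        return self.gain * x


class WithModuleList(nn.Module):
    def __init__(self, constraints):
        super().__init__()
        # nn.ModuleList registers every element as a child module.
        self.constraints = nn.ModuleList(constraints)


class WithPlainList(nn.Module):
    def __init__(self, constraints):
        super().__init__()
        # A plain list is stored as an ordinary attribute; children are NOT registered.
        self.constraints = list(constraints)


registered = WithModuleList([ScaleConstraint(), ScaleConstraint()])
unregistered = WithPlainList([ScaleConstraint(), ScaleConstraint()])

print(len(list(registered.parameters())))    # 2
print(len(list(unregistered.parameters())))  # 0
```

With the plain list, moving the container to a GPU or saving its state_dict() would silently skip the component constraints.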
Example
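>>> import torch
>>> from kaira.constraints import CompositeConstraint, PAPRConstraint, TotalPowerConstraint
>>> input_signal = torch.randn(32, 128)  # e.g. a batch of 32 signals
>>> power_constraint = TotalPowerConstraint(1.0)
>>> papr_constraint = PAPRConstraint(4.0)
>>> combined = CompositeConstraint([power_constraint, papr_constraint])
>>> constrained_signal = combined(input_signal)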
Note
When a composite constraint is applied, the component constraints are applied in the order in which they were provided. This ordering can significantly affect the final result, because a later constraint may alter a property that an earlier one enforced.
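To see why the order matters, here is a sketch with two hypothetical plain-function stand-ins (normalize_total_power and clip_amplitude, not kaira's TotalPowerConstraint or PAPRConstraint): a total-power normalization and a hard amplitude clip. Normalizing first and clipping second removes energy, so the power target is missed; clipping first and normalizing last meets it.

```python
import torch


def normalize_total_power(x: torch.Tensor, total_power: float = 1.0) -> torch.Tensor:
    # Rescale each row so that the sum of squared samples equals total_power.
    power = x.pow(2).sum(dim=-1, keepdim=True)
    return x * torch.sqrt(total_power / power)


def clip_amplitude(x: torch.Tensor, max_amp: float = 0.6) -> torch.Tensor:
    # Hard-limit every sample to the range [-max_amp, max_amp].
    return x.clamp(-max_amp, max_amp)


x = torch.tensor([[2.0, 1.0, 0.5, 0.25]])

power_then_clip = clip_amplitude(normalize_total_power(x))
clip_then_power = normalize_total_power(clip_amplitude(x))

print(round(power_then_clip.pow(2).sum().item(), 3))  # 0.607: clipping removed energy
print(round(clip_then_power.pow(2).sum().item(), 3))  # 1.0: power target met
```

Placing the constraint that must hold exactly last is a common heuristic, not a guarantee: here the normalization happens to shrink the clipped signal, but a signal that had to be scaled up could be pushed back past the amplitude limit.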
Methods
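__init__(constraints, *args, **kwargs) – Initialize the composite constraint with a list of constraints.
add_constraint(constraint) – Add a new constraint to the composite.
forward(x, *args, **kwargs) – Apply the composite constraint to the input signal.
get_dimensions(x[, exclude_batch]) – Helper method to get all dimensions except batch for calculating norms/means.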
- __init__(constraints: Sequence[BaseConstraint] | ModuleList, *args, **kwargs) → None
Initialize the composite constraint with a list of constraints.
- Parameters:
constraints (Sequence[BaseConstraint] | nn.ModuleList) – List of constraint modules to apply in sequence
*args – Variable length argument list.
**kwargs – Arbitrary keyword arguments.
- Raises:
TypeError – If any element in constraints is not a BaseConstraint
- add_constraint(constraint: BaseConstraint) → None
Add a new constraint to the composite.
- Parameters:
constraint (BaseConstraint) – New constraint to add to the sequence
- forward(x, *args, **kwargs)
Apply the composite constraint to the input signal.
- Parameters:
x – Input signal to constrain
*args – Variable length argument list.
**kwargs – Arbitrary keyword arguments.
- Returns:
Constrained signal after applying all component constraints
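The behaviour documented for __init__, add_constraint and forward can be sketched as below. This is a minimal illustration under stated assumptions, not kaira's source: BaseConstraintStub, CompositeConstraintSketch, Scale and Clip are hypothetical names, the sketch ignores the constructor's *args/**kwargs, and it assumes forward passes its extra arguments through to every component.

```python
from typing import Sequence, Union

import torch
from torch import nn


class BaseConstraintStub(nn.Module):
    """Hypothetical stand-in for kaira's BaseConstraint (illustration only)."""

    def forward(self, x, *args, **kwargs):
        raise NotImplementedError


class CompositeConstraintSketch(BaseConstraintStub):
    """Minimal sketch: apply component constraints one after another."""

    def __init__(self, constraints: Union[Sequence[BaseConstraintStub], nn.ModuleList], *args, **kwargs):
        # *args/**kwargs mirror the documented signature; this sketch ignores them.
        super().__init__()
        for constraint in constraints:
            if not isinstance(constraint, BaseConstraintStub):
                raise TypeError(f"Expected a constraint module, got {type(constraint).__name__}")
        # ModuleList registers every component as a submodule.
        self.constraints = nn.ModuleList(constraints)

    def add_constraint(self, constraint: BaseConstraintStub) -> None:
        # Appended constraints run after all existing ones.
        self.constraints.append(constraint)

    def forward(self, x, *args, **kwargs):
        # Assumption: extra arguments are passed through to every component.
        for constraint in self.constraints:
            x = constraint(x, *args, **kwargs)
        return x


class Scale(BaseConstraintStub):
    def __init__(self, factor: float):
        super().__init__()
        self.factor = factor

    def forward(self, x, *args, **kwargs):
        return x * self.factor


class Clip(BaseConstraintStub):
    def __init__(self, limit: float):
        super().__init__()
        self.limit = limit

    def forward(self, x, *args, **kwargs):
        return x.clamp(-self.limit, self.limit)


combined = CompositeConstraintSketch([Scale(2.0), Clip(1.0)])
combined.add_constraint(Scale(0.5))
y = combined(torch.tensor([0.25, 0.75, -1.0]))
print(y.tolist())  # [0.25, 0.5, -0.5]
```

Tracing the example: scaling by 2 gives [0.5, 1.5, -2.0], clipping to [-1, 1] gives [0.5, 1.0, -1.0], and the appended scale by 0.5 gives the printed result.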
- static get_dimensions(x: Tensor, exclude_batch: bool = True) → Tuple[int, ...]
Helper method to get all dimensions except batch for calculating norms/means.
Utility function to generate dimension indices for reduction operations like mean or norm. Typically used to calculate signal properties across all dimensions except the batch dimension.
- Parameters:
x (torch.Tensor) – Input tensor
exclude_batch (bool, optional) – Whether to exclude the batch dimension (first dimension). Defaults to True.
- Returns:
Dimensions to use for reduction operations (e.g., mean, norm)
- Return type:
Tuple[int, …]
Example
>>> import torch
>>> from kaira.constraints import BaseConstraint
>>> x = torch.randn(32, 4, 128)  # [batch, antennas, time]
>>> dims = BaseConstraint.get_dimensions(x)
>>> # dims will be (1, 2) for reducing across antennas and time
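The documented behaviour can be reproduced with a range over the tensor's dimension indices. This is a sketch inferred from the description above, not kaira's source; the free function get_dimensions is a hypothetical stand-in for the static method, and the per-example power calculation is only an illustrative use of the returned tuple.

```python
from typing import Tuple

import torch


def get_dimensions(x: torch.Tensor, exclude_batch: bool = True) -> Tuple[int, ...]:
    # Hypothetical stand-in for the static method: every dimension index,
    # optionally skipping dimension 0 (the batch dimension).
    start = 1 if exclude_batch else 0
    return tuple(range(start, x.dim()))


x = torch.randn(32, 4, 128)  # [batch, antennas, time]
dims = get_dimensions(x)
print(dims)  # (1, 2)
print(get_dimensions(x, exclude_batch=False))  # (0, 1, 2)

# Illustrative use: mean power per batch element, reduced over antennas and time.
per_example_power = x.pow(2).mean(dim=dims)
print(per_example_power.shape)  # torch.Size([32])
```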