# Source code for kaira.constraints.composite
"""Composite constraint implementation for combining multiple constraints.
This module provides the CompositeConstraint class, which allows multiple constraints to be applied
sequentially as a single unified constraint. This enables modular constraint creation and
composition for complex signal requirements.
"""
from typing import Sequence
import torch
from torch import nn
from .base import BaseConstraint
# [docs]
class CompositeConstraint(BaseConstraint):
    """Applies multiple constraints in sequence.

    This constraint combines multiple independent constraints and applies them
    in sequence to the input tensor. This allows for more complex constraint
    compositions like applying both power and spectral constraints together.

    Attributes:
        constraints (nn.ModuleList): List of constraint modules to apply in sequence

    Example:
        >>> power_constraint = TotalPowerConstraint(1.0)
        >>> papr_constraint = PAPRConstraint(4.0)
        >>> combined = CompositeConstraint([power_constraint, papr_constraint])
        >>> constrained_signal = combined(input_signal)

    Note:
        When a composite constraint is applied, each component constraint is applied
        in the order they were provided. This ordering can significantly affect the
        final result, as constraints may interact with each other.
    """

    def __init__(self, constraints: Sequence[BaseConstraint] | nn.ModuleList, *args, **kwargs) -> None:
        """Initialize the composite constraint with a list of constraints.

        Args:
            constraints (Sequence[BaseConstraint] | nn.ModuleList): Constraint modules to
                apply in sequence. An ``nn.ModuleList`` is stored as-is (not copied);
                any other iterable, including a one-shot iterator or generator, is
                read exactly once and wrapped in a new ``nn.ModuleList``.
            *args: Variable length argument list, forwarded to the parent constructor.
            **kwargs: Arbitrary keyword arguments, forwarded to the parent constructor.

        Raises:
            TypeError: If any element in constraints is not a BaseConstraint
        """
        super().__init__(*args, **kwargs)

        # Materialize non-ModuleList input once, *before* validating. Validating
        # by iterating the caller's object directly would exhaust a generator and
        # leave nothing to register, silently producing an empty composite.
        if isinstance(constraints, nn.ModuleList):
            members = constraints
        else:
            members = list(constraints)

        # Validate that every entry is a BaseConstraint before registering any.
        for constraint in members:
            if not isinstance(constraint, BaseConstraint):
                raise TypeError(f"Expected BaseConstraint, got {type(constraint).__name__}")

        # Reuse a caller-supplied ModuleList unchanged; otherwise wrap the list so
        # the constraints are registered as submodules.
        self.constraints = members if isinstance(members, nn.ModuleList) else nn.ModuleList(members)

    def add_constraint(self, constraint: BaseConstraint) -> None:
        """Add a new constraint to the composite.

        The constraint is appended, so it is applied after all constraints
        already present.

        Args:
            constraint (BaseConstraint): New constraint to add to the sequence

        Raises:
            TypeError: If constraint is not a BaseConstraint
        """
        if not isinstance(constraint, BaseConstraint):
            raise TypeError(f"Expected BaseConstraint, got {type(constraint).__name__}")
        self.constraints.append(constraint)

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """Apply the composite constraint to the input signal.

        Each component constraint receives the output of the previous one. The
        same ``*args`` and ``**kwargs`` are passed unchanged to every component.
        With no component constraints, the input is returned as-is.

        Args:
            x: Input signal to constrain
            *args: Variable length argument list, forwarded to each constraint.
            **kwargs: Arbitrary keyword arguments, forwarded to each constraint.

        Returns:
            Constrained signal after applying all component constraints
        """
        for step in self.constraints:
            x = step(x, *args, **kwargs)
        return x