kaira.constraints.LambdaConstraint

- class kaira.constraints.LambdaConstraint(function: Callable[[Tensor], Tensor], *args, **kwargs)
Bases: BaseConstraint
Constraint that applies a user-defined function to the signal.
This constraint allows users to pass any function that operates on tensors to be used as a constraint, providing flexibility without requiring new class implementations for simple constraints.
- function
The function to apply to the input tensor
- Type:
Callable
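The wrapper pattern described above can be sketched without kaira or torch. This is a minimal illustration, not the library's implementation: `LambdaConstraintSketch` and `clip_to_unit` are hypothetical names, plain Python lists stand in for tensors, and making the instance callable is an assumption.

```python
from typing import Callable, List

Signal = List[float]  # stand-in for torch.Tensor in this sketch


class LambdaConstraintSketch:
    """Hypothetical stand-in for the wrapper pattern; not the kaira class."""

    def __init__(self, function: Callable[[Signal], Signal]) -> None:
        # Store the user-supplied callable, mirroring the `function` attribute.
        self.function = function

    def forward(self, x: Signal) -> Signal:
        # Delegate to the stored callable and return its result unchanged.
        return self.function(x)

    # Assumption: instances are callable, so they can be used like a function.
    __call__ = forward


def clip_to_unit(x: Signal) -> Signal:
    # Example constraint: limit every sample to the range [-1, 1].
    return [max(-1.0, min(1.0, v)) for v in x]


clip = LambdaConstraintSketch(clip_to_unit)
clipped = clip([0.5, -2.0, 3.0])
```

The point of the design is that a simple constraint needs only a function, not a new subclass. The function should return a signal of the same length it received.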
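Methods
- __init__(function, *args, **kwargs): Initialize with a user-defined constraint function.
- forward(x, *args, **kwargs): Apply the user-defined function to the input tensor.
- get_dimensions(x[, exclude_batch]): Helper method to get all dimensions except batch for calculating norms/means.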
- __init__(function: Callable[[Tensor], Tensor], *args, **kwargs)
Initialize with a user-defined constraint function.
- Parameters:
function (Callable[[torch.Tensor], torch.Tensor]) – A function that takes a torch.Tensor as input and returns a torch.Tensor as output. The function should preserve the shape of the input tensor.
*args – Variable length argument list.
**kwargs – Arbitrary keyword arguments.
- forward(x: Tensor, *args, **kwargs) → Tensor
Apply the user-defined function to the input tensor.
- Parameters:
x (torch.Tensor) – The input signal tensor
*args – Variable length argument list.
**kwargs – Arbitrary keyword arguments.
- Returns:
The result of applying the function to x
- Return type:
torch.Tensor
- static get_dimensions(x: Tensor, exclude_batch: bool = True) → Tuple[int, ...]
Helper method to get all dimensions except batch for calculating norms/means.
Utility function to generate dimension indices for reduction operations like mean or norm. Typically used to calculate signal properties across all dimensions except the batch dimension.
- Parameters:
x (torch.Tensor) – Input tensor
exclude_batch (bool, optional) – Whether to exclude the batch dimension (first dimension). Defaults to True.
- Returns:
Dimensions to use for reduction operations (e.g., mean, norm)
- Return type:
Tuple[int, …]
Example
>>> x = torch.randn(32, 4, 128)  # [batch, antennas, time]
>>> dims = BaseConstraint.get_dimensions(x)
>>> # dims will be (1, 2) for summing across antennas and time
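The index computation behind the example above can be sketched in plain Python. This page does not show the actual implementation, so the sketch assumes the helper simply enumerates axis indices up to the tensor rank, starting at 1 when the batch axis is excluded. `reduction_dims` is a hypothetical name, and it takes the number of dimensions rather than a tensor so that it runs without torch.

```python
from typing import Tuple


def reduction_dims(ndim: int, exclude_batch: bool = True) -> Tuple[int, ...]:
    """Return axis indices to reduce over (hypothetical sketch, not kaira's code)."""
    # Skip axis 0 (the batch axis) when exclude_batch is set.
    start = 1 if exclude_batch else 0
    return tuple(range(start, ndim))


# A [batch, antennas, time] tensor has three dimensions.
dims = reduction_dims(3)
all_dims = reduction_dims(3, exclude_batch=False)
```

Here `dims` is `(1, 2)`, matching the documented example, and `all_dims` is `(0, 1, 2)`. A tuple like this can be passed as the `dim` argument of a reduction such as `torch.mean` to get one value per batch element.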