kaira.constraints.LambdaConstraint

Inheritance diagram for LambdaConstraint

class kaira.constraints.LambdaConstraint(function: Callable[[Tensor], Tensor], *args, **kwargs)

Bases: BaseConstraint

Constraint that applies a user-defined function to the signal.

This constraint lets users supply any function that maps a tensor to a tensor and use it as a constraint. Simple constraints can therefore be expressed without implementing a new constraint class.
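The wrapping pattern can be sketched in plain PyTorch. The class below, `LambdaConstraintSketch`, is a hypothetical stand-in written for illustration and is not Kaira's implementation. It assumes only that a constraint is a `torch.nn.Module` whose `forward` returns `function(x)`.

```python
from typing import Callable

import torch
from torch import nn


class LambdaConstraintSketch(nn.Module):
    """Hypothetical stand-in for a lambda constraint (not Kaira's code)."""

    def __init__(self, function: Callable[[torch.Tensor], torch.Tensor]) -> None:
        super().__init__()
        # Store the user-supplied tensor -> tensor function.
        self.function = function

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The constraint is simply the wrapped function applied to the signal.
        return self.function(x)


# Example: a peak-amplitude constraint that clips every sample to [-1, 1].
clip = LambdaConstraintSketch(lambda x: torch.clamp(x, min=-1.0, max=1.0))
signal = torch.tensor([[0.5, -2.0, 3.0], [1.5, 0.0, -0.25]])
constrained = clip(signal)
```

With Kaira itself, the same constraint would be built by passing the clipping function as the documented `function` argument, e.g. `LambdaConstraint(lambda x: torch.clamp(x, -1.0, 1.0))`.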

function

The function to apply to the input tensor

Type:

Callable

Methods

  • __init__ – Initialize with a user-defined constraint function.

  • forward – Apply the user-defined function to the input tensor.

  • get_dimensions – Helper method to get all dimensions except batch for calculating norms/means.

__init__(function: Callable[[Tensor], Tensor], *args, **kwargs)

Initialize with a user-defined constraint function.

Parameters:
  • function (Callable[[torch.Tensor], torch.Tensor]) – A function that takes a torch.Tensor as input and returns a torch.Tensor. The returned tensor should have the same shape as the input.

  • *args – Variable length argument list.

  • **kwargs – Arbitrary keyword arguments.
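The returned tensor should keep the input's shape, so reductions inside the function typically use `keepdim=True` to stay broadcastable over the signal. The sketch below assumes a `[batch, ...]` layout and normalizes each sample to unit average power. The helper name `unit_power` and its small epsilon are illustrative choices, not part of Kaira.

```python
import torch


def unit_power(x: torch.Tensor) -> torch.Tensor:
    """Scale each sample in the batch to unit average power (hypothetical helper)."""
    # Reduce over every dimension except the batch dimension (dim 0).
    dims = tuple(range(1, x.dim()))
    # keepdim=True keeps a broadcastable [batch, 1, ..., 1] shape.
    power = x.pow(2).mean(dim=dims, keepdim=True)
    # A small epsilon guards against division by zero for all-zero samples.
    return x / torch.sqrt(power + 1e-12)


x = torch.randn(8, 4, 16)
y = unit_power(x)
```

Because `y` has the same shape as `x`, a function like this is a suitable `function` argument, e.g. `LambdaConstraint(unit_power)`.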

forward(x: Tensor, *args, **kwargs) → Tensor

Apply the user-defined function to the input tensor.

Parameters:
  • x (torch.Tensor) – The input signal tensor

  • *args – Variable length argument list.

  • **kwargs – Arbitrary keyword arguments.

Returns:

The result of applying the function to x

Return type:

torch.Tensor

static get_dimensions(x: Tensor, exclude_batch: bool = True) → Tuple[int, ...]

Helper method to get all dimensions except batch for calculating norms/means.

Generates the dimension indices to pass to reduction operations such as mean or norm. It is typically used to compute signal properties across every dimension except the batch dimension.

Parameters:
  • x (torch.Tensor) – Input tensor

  • exclude_batch (bool, optional) – Whether to exclude the batch dimension (first dimension). Defaults to True.

Returns:

Dimensions to use for reduction operations (e.g., mean, norm)

Return type:

Tuple[int, …]

Example

>>> import torch
>>> from kaira.constraints import BaseConstraint
>>> x = torch.randn(32, 4, 128)  # [batch, antennas, time]
>>> dims = BaseConstraint.get_dimensions(x)
>>> # dims will be (1, 2) for reducing across antennas and time
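The documented behaviour of `get_dimensions` can be mirrored in a few lines, then used to compute per-sample signal power. `get_dimensions_sketch` below is a hypothetical re-implementation for illustration, not Kaira's source. Its `exclude_batch=False` branch assumes that all dimensions, including the batch dimension, are returned.

```python
from typing import Tuple

import torch


def get_dimensions_sketch(x: torch.Tensor, exclude_batch: bool = True) -> Tuple[int, ...]:
    """Hypothetical mirror of get_dimensions (not Kaira's source)."""
    # Skip dimension 0 (the batch) when exclude_batch is True.
    start = 1 if exclude_batch else 0
    return tuple(range(start, x.dim()))


x = torch.randn(32, 4, 128)  # [batch, antennas, time]
dims = get_dimensions_sketch(x)
# Average power of each batch element across antennas and time.
per_sample_power = x.pow(2).mean(dim=dims)
```

Reducing over `dims` leaves one value per batch element, so `per_sample_power` has shape `(32,)`.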