kaira.constraints.BaseConstraint

Inheritance diagram for BaseConstraint

class kaira.constraints.BaseConstraint(*args, **kwargs)

Bases: Module, ABC

Abstract base class for implementing signal constraints as PyTorch modules.

This is an abstract base class for defining constraints on transmitted signals. Subclasses must implement the forward method to apply their specific constraint logic.

Because every constraint inherits from both nn.Module and ABC (Abstract Base Class), it can be called and composed like any other PyTorch module, and a subclass that leaves forward unimplemented cannot be instantiated (Python raises TypeError).
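The practical effect of the two bases can be checked directly. The following is a minimal sketch, not part of kaira: it imports the real class when kaira is installed and otherwise defines a stand-in that reproduces only the documented bases and the abstract forward. MissingForward and Identity are hypothetical subclasses made up for illustration.

```python
from abc import ABC, abstractmethod

import torch
from torch import nn

try:
    from kaira.constraints import BaseConstraint
except ImportError:
    # Hypothetical stand-in: mirrors only the documented bases and abstract forward.
    class BaseConstraint(nn.Module, ABC):
        @abstractmethod
        def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
            raise NotImplementedError


class MissingForward(BaseConstraint):
    """Hypothetical subclass that forgets to implement forward()."""


class Identity(BaseConstraint):
    """Hypothetical no-op constraint: returns its input unchanged."""

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        return x


try:
    MissingForward()
except TypeError as err:
    # ABCMeta refuses to instantiate a class that still has abstract methods.
    print(f"TypeError: {err}")

constraint = Identity()
print(isinstance(constraint, nn.Module))  # True: usable like any other PyTorch module
y = constraint(torch.ones(2, 3))  # nn.Module.__call__ dispatches to forward()
print(y.shape)
```

Because the check happens when the object is constructed, a missing forward surfaces as a TypeError immediately rather than as an error on the first call.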

Methods

  • __init__ – Initialize the base constraint.

  • forward – Apply the constraint to the input tensor.

  • get_dimensions – Helper method to get all dimensions except batch for calculating norms/means.

__init__(*args, **kwargs) → None

Initialize the base constraint.

Parameters:
  • *args – Variable length argument list.

  • **kwargs – Arbitrary keyword arguments.

abstractmethod forward(x: Tensor, *args, **kwargs) → Tensor

Apply the constraint to the input tensor.

This abstract method must be implemented by all constraint classes. The implementation should apply the specific constraint logic to the input tensor while preserving its essential dimensions.

Parameters:
  • x (torch.Tensor) – Input tensor to apply the constraint to

  • *args – Variable length argument list.

  • **kwargs – Arbitrary keyword arguments.

Returns:

The constrained tensor with the same essential dimensions as the input

Return type:

torch.Tensor

Raises:

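NotImplementedError – If the base implementation is invoked (for example, via super().forward()). A subclass that does not override forward fails earlier: because forward is abstract, instantiating that subclass raises TypeError.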
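A concrete constraint only needs to implement forward. The sketch below uses a simple average-power normalization to show the shape of such a subclass; UnitAveragePower and its power and eps parameters are hypothetical, not a constraint provided by kaira. It calls get_dimensions so that each example in the batch is normalized over its own non-batch axes. When kaira is not installed it falls back to a stand-in whose get_dimensions behavior (all axes, optionally skipping axis 0) is an assumption based on the description of that method.

```python
from abc import ABC, abstractmethod
from typing import Tuple

import torch
from torch import nn

try:
    from kaira.constraints import BaseConstraint
except ImportError:
    # Hypothetical stand-in mirroring the documented interface of BaseConstraint.
    class BaseConstraint(nn.Module, ABC):
        @abstractmethod
        def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
            raise NotImplementedError

        @staticmethod
        def get_dimensions(x: torch.Tensor, exclude_batch: bool = True) -> Tuple[int, ...]:
            # Assumed behavior: every axis index, skipping axis 0 (batch) by default.
            start = 1 if exclude_batch else 0
            return tuple(range(start, x.dim()))


class UnitAveragePower(BaseConstraint):
    """Hypothetical constraint: rescale each sample to a target average power."""

    def __init__(self, power: float = 1.0, eps: float = 1e-12) -> None:
        super().__init__()
        self.power = power
        self.eps = eps  # guards against division by zero for all-zero samples

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        dims = self.get_dimensions(x)  # all non-batch axes
        # Mean squared magnitude per sample; keepdim=True so it broadcasts against x.
        current = x.abs().pow(2).mean(dim=dims, keepdim=True)
        return x * torch.sqrt(self.power / (current + self.eps))


constraint = UnitAveragePower(power=1.0)
x = torch.randn(8, 4, 64)  # [batch, antennas, time]
y = constraint(x)
print(y.shape)  # torch.Size([8, 4, 64]): shape is preserved
print(y.pow(2).mean(dim=(1, 2)))  # approximately 1.0 for every sample
```

Computing the power with keepdim=True leaves a singleton axis for every reduced dimension, so the scaling factor broadcasts against x and the output keeps the input's shape.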

static get_dimensions(x: Tensor, exclude_batch: bool = True) → Tuple[int, ...]

Helper method to get all dimensions except batch for calculating norms/means.

Utility function that generates dimension indices for reduction operations such as mean or norm. It is typically used to compute signal properties over all dimensions except the batch dimension.

Parameters:
  • x (torch.Tensor) – Input tensor

  • exclude_batch (bool, optional) – Whether to exclude the batch dimension (first dimension). Defaults to True.

Returns:

Dimensions to use for reduction operations (e.g., mean, norm)

Return type:

Tuple[int, …]

Example

>>> import torch
>>> from kaira.constraints import BaseConstraint
>>> x = torch.randn(32, 4, 128)  # [batch, antennas, time]
>>> dims = BaseConstraint.get_dimensions(x)
>>> dims  # reduce over antennas and time
(1, 2)
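The returned tuple can be passed directly as the dim argument of PyTorch reductions. The sketch below writes the indices out with plain torch, assuming that get_dimensions(x) yields (1, 2) as in the example and that get_dimensions(x, exclude_batch=False) yields every axis, (0, 1, 2); the variable names are illustrative.

```python
import torch

x = torch.randn(32, 4, 128)  # [batch, antennas, time]

# Assumed equivalents of BaseConstraint.get_dimensions(x) and
# BaseConstraint.get_dimensions(x, exclude_batch=False), written out explicitly.
per_sample_dims = tuple(range(1, x.dim()))  # (1, 2)
all_dims = tuple(range(x.dim()))  # (0, 1, 2)

# One L2 norm and one average power per example in the batch.
norms = torch.linalg.vector_norm(x, ord=2, dim=per_sample_dims)
power = x.pow(2).mean(dim=per_sample_dims)
print(norms.shape, power.shape)  # torch.Size([32]) torch.Size([32])

# Including the batch axis collapses everything into a single scalar.
total_power = x.pow(2).mean(dim=all_dims)
print(total_power.shape)  # torch.Size([])
```

Excluding the batch axis is what keeps one statistic per example, which is usually what a per-sample constraint needs.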