kaira.constraints.IdentityConstraint

Inheritance diagram for IdentityConstraint
- class kaira.constraints.IdentityConstraint(*args, **kwargs)
Bases: BaseConstraint
Identity constraint that returns the input signal unchanged.
This is a simple passthrough constraint that does not modify the input signal. It can be used when a constraint is expected in an interface but no actual constraint should be applied.
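As a rough illustration of the "constraint expected in an interface" case, the sketch below wires a passthrough module into a toy transmitter whose constructor always takes a constraint. `PassthroughConstraint` and `ToyTransmitter` are hypothetical stand-ins written for this example, not kaira classes. The sketch also assumes that constraints are `torch.nn.Module` subclasses invoked as `constraint(x)`, which this page does not state.

```python
import torch
from torch import nn


class PassthroughConstraint(nn.Module):
    """Hypothetical stand-in mimicking the documented IdentityConstraint behavior."""

    def __init__(self, *args, **kwargs) -> None:
        super().__init__()  # any extra constructor arguments are ignored

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        return x  # the signal passes through untouched


class ToyTransmitter(nn.Module):
    """Hypothetical module whose interface always requires a constraint."""

    def __init__(self, constraint: nn.Module) -> None:
        super().__init__()
        self.encoder = nn.Linear(8, 4)
        self.constraint = constraint

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The constraint is always applied, so no `if constraint is None` branch is needed.
        return self.constraint(self.encoder(x))


tx = ToyTransmitter(PassthroughConstraint())
inp = torch.randn(16, 8)
out = tx(inp)  # same values as tx.encoder(inp), since the constraint is a no-op
```

Swapping the passthrough for a real constraint changes what the transmitter emits without changing its code, which is the kind of drop-in role described above.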
Methods
__init__(*args, **kwargs) – Initialize the identity constraint.
forward(x, *args, **kwargs) – Forward pass that returns the input tensor unchanged.
get_dimensions(x[, exclude_batch]) – Helper method to get all dimensions except batch for calculating norms/means.
- __init__(*args, **kwargs) → None
Initialize the identity constraint.
- Parameters:
*args – Variable length argument list.
**kwargs – Arbitrary keyword arguments.
- forward(x: Tensor, *args, **kwargs) → Tensor
Forward pass that returns the input tensor unchanged.
- Parameters:
x (torch.Tensor) – The input signal tensor
*args – Variable length argument list.
**kwargs – Arbitrary keyword arguments.
- Returns:
The same input tensor x (unchanged)
- Return type:
torch.Tensor
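Because `forward` is documented as returning the same input tensor, callers should not assume they receive a copy. The sketch below uses a hypothetical function, `identity_forward`, that mimics the documented behavior (it is not kaira's source). It shows that extra positional and keyword arguments are ignored and that an in-place edit of the result is visible through the input.

```python
import torch


def identity_forward(x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
    # Hypothetical stand-in for IdentityConstraint.forward: return x itself.
    return x


x = torch.zeros(2, 3)
y = identity_forward(x, 0.5, power=1.0)  # extra arguments are accepted and ignored

y.add_(1.0)  # in-place update of the returned tensor ...
# ... also changes x, because y and x are the same tensor object
```

If an independent tensor is needed downstream, calling `y.clone()` on the result avoids this aliasing.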
- static get_dimensions(x: Tensor, exclude_batch: bool = True) → Tuple[int, ...]
Helper method to get all dimensions except batch for calculating norms/means.
Utility function to generate dimension indices for reduction operations like mean or norm. Typically used to calculate signal properties across all dimensions except the batch dimension.
- Parameters:
x (torch.Tensor) – Input tensor
exclude_batch (bool, optional) – Whether to exclude the batch dimension (first dimension). Defaults to True.
- Returns:
Dimensions to use for reduction operations (e.g., mean, norm)
- Return type:
Tuple[int, …]
Example
>>> import torch
>>> x = torch.randn(32, 4, 128)  # [batch, antennas, time]
>>> dims = BaseConstraint.get_dimensions(x)
>>> # dims will be (1, 2) for summing across antennas and time
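The sketch below shows how the returned tuple is typically passed to reduction calls. `get_dimensions` here is a hypothetical re-implementation consistent with the example above (all axis indices, optionally skipping axis 0), not kaira's actual code; with `exclude_batch=False` it is assumed to include axis 0. Per-example average power and L2 norm are then computed over antennas and time only, leaving one value per batch element.

```python
from typing import Tuple

import torch


def get_dimensions(x: torch.Tensor, exclude_batch: bool = True) -> Tuple[int, ...]:
    # Hypothetical re-implementation: all axis indices, skipping axis 0 (batch) if requested.
    start = 1 if exclude_batch else 0
    return tuple(range(start, x.dim()))


x = torch.randn(32, 4, 128)  # [batch, antennas, time]
dims = get_dimensions(x)  # (1, 2)

# Per-example average power: mean of x**2 over antennas and time.
power = x.pow(2).mean(dim=dims)  # shape [32]
# Per-example L2 norm over the same axes.
norm = torch.linalg.vector_norm(x, ord=2, dim=dims)  # shape [32]
```

Both quantities describe the same signal energy: the squared norm divided by the number of reduced elements (4 × 128) equals the average power.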