kaira.models.generic.LambdaModel

- class kaira.models.generic.LambdaModel(func: Callable[[Tensor], Tensor], name: str | None = None)
Bases: BaseModel
Lambda Model.
This model applies a user-provided function to the input tensor. It’s useful for quickly implementing custom transformations without creating a new model class.
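As a rough sketch of the behaviour described above, the wrapper can be thought of as a thin module that stores the function and calls it in `forward`. The class `SketchLambdaModel` below is hypothetical, not kaira's source: it assumes `BaseModel` behaves like `torch.nn.Module`, and the name fallback for callables without `__name__` is also an assumption.

```python
import torch
from torch import nn
from typing import Any, Callable, Optional


class SketchLambdaModel(nn.Module):
    """Hypothetical stand-in for LambdaModel (assumes BaseModel ~ nn.Module)."""

    def __init__(self, func: Callable[..., torch.Tensor], name: Optional[str] = None):
        super().__init__()
        self.func = func
        # Use the function's own name when none is given.
        # Assumption: fall back to the class name for callables lacking __name__.
        self.name = name if name is not None else getattr(func, "__name__", type(func).__name__)

    def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
        # Delegate to the wrapped function, forwarding any extra arguments.
        return self.func(x, *args, **kwargs)


# Drop a one-off transformation into a larger pipeline without a dedicated class.
pipeline = nn.Sequential(
    nn.Linear(10, 4),
    SketchLambdaModel(torch.tanh),
)
out = pipeline(torch.randn(2, 10))
```

Because the wrapper is itself a module, it can sit inside containers such as `nn.Sequential`; whether the real class supports this depends on `BaseModel` being an `nn.Module` subclass.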
Example
>>> import torch
>>> from kaira.models.generic import LambdaModel
>>> # Apply a simple scaling function
>>> model = LambdaModel(lambda x: 2.0 * x)
>>> x = torch.ones(5, 10)
>>> output = model(x)
>>> assert torch.all(output == 2.0)
Methods
__init__(func[, name]) – Initialize the Lambda model.
forward(x, *args, **kwargs) – Apply the lambda function to the input tensor.
- __init__(func: Callable[[Tensor], Tensor], name: str | None = None)
Initialize the Lambda model.
- Parameters:
func (Callable[[torch.Tensor], torch.Tensor]) – Function to apply to input tensors. Should take a torch.Tensor as input and return a torch.Tensor.
name (Optional[str], optional) – Name for the model. If None, uses the function’s name. Defaults to None.
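The `name` default can be illustrated with a small hypothetical helper, `resolve_model_name`, that mirrors the documented rule. Reading "the function’s name" from `__name__`, and falling back to the class name for callables without one (such as `functools.partial` objects), is an assumption rather than confirmed kaira behaviour.

```python
import functools
from typing import Callable, Optional


def resolve_model_name(func: Callable, name: Optional[str] = None) -> str:
    """Hypothetical helper mirroring the documented default for ``name``."""
    if name is not None:
        return name
    # Functions and lambdas expose __name__; assumed fallback is the class name.
    return getattr(func, "__name__", type(func).__name__)


def double(x):
    return 2 * x


print(resolve_model_name(double))                     # double
print(resolve_model_name(lambda x: x))                # <lambda>
print(resolve_model_name(double, "scaler"))           # scaler
print(resolve_model_name(functools.partial(double)))  # partial
```

Every lambda is named `<lambda>`, so passing an explicit `name` is worthwhile when several LambdaModel instances wrap anonymous functions.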
- forward(x: Tensor, *args: Any, **kwargs: Any) → Tensor
Apply the lambda function to the input tensor.
- Parameters:
x (torch.Tensor) – Input tensor.
*args – Additional positional arguments passed to the function.
**kwargs – Additional keyword arguments passed to the function.
- Returns:
Result of applying the lambda function to the input.
- Return type:
torch.Tensor
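Because `forward` passes `*args` and `**kwargs` through to the wrapped function, a single model can expose per-call options. The sketch below uses a hypothetical stand-in class (modelling only this forwarding behaviour) and a hypothetical `add_noise` function; with the real class, `add_noise` would be passed to `LambdaModel` instead.

```python
import torch
from torch import nn


class SketchLambdaModel(nn.Module):
    # Hypothetical stand-in: only the argument forwarding documented above is modelled.
    def __init__(self, func):
        super().__init__()
        self.func = func

    def forward(self, x, *args, **kwargs):
        return self.func(x, *args, **kwargs)


def add_noise(x: torch.Tensor, std: float = 0.0) -> torch.Tensor:
    """Additive Gaussian noise with a per-call standard deviation (hypothetical)."""
    return x + std * torch.randn_like(x)


model = SketchLambdaModel(add_noise)
x = torch.zeros(3, 4)
clean = model(x)           # std defaults to 0.0, so the output equals the input
noisy = model(x, std=0.5)  # keyword argument is forwarded to add_noise
```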