kaira.models.generic.LambdaModel

class kaira.models.generic.LambdaModel(func: Callable[[Tensor], Tensor], name: str | None = None)

Bases: BaseModel

Lambda Model.

This model applies a user-provided function to the input tensor. It is useful for turning a quick custom transformation into a model without defining a new model class.

Example

>>> import torch
>>> from kaira.models.generic import LambdaModel
>>> # Apply a simple scaling function
>>> model = LambdaModel(lambda x: 2.0 * x)
>>> x = torch.ones(5, 10)
>>> output = model(x)
>>> assert torch.all(output == 2.0)
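
To illustrate the idea, here is a minimal, dependency-free sketch of wrapping a plain function as a callable model object. `FunctionModel` is a hypothetical name; it does not subclass `BaseModel` or `torch.nn.Module` and is not kaira's implementation. It only mimics the way `torch.nn.Module.__call__` dispatches to `forward()`.

```python
from typing import Any, Callable


class FunctionModel:
    """Hypothetical, dependency-free stand-in for the LambdaModel pattern.

    Not kaira's implementation: it does not subclass BaseModel or
    torch.nn.Module. It only shows how a plain function becomes a
    callable "model" object.
    """

    def __init__(self, func: Callable[..., Any]) -> None:
        self.func = func  # the user-provided transformation

    def forward(self, x: Any, *args: Any, **kwargs: Any) -> Any:
        # The whole model is just a call to the wrapped function.
        return self.func(x, *args, **kwargs)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # torch.nn.Module.__call__ dispatches to forward(); mimic that here.
        return self.forward(*args, **kwargs)


double = FunctionModel(lambda x: 2.0 * x)
print(double(3.0))  # 6.0
```

Keeping the transformation in a function means the wrapper class never changes; only `func` differs from one model to the next.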

Methods

  • __init__ – Initialize the Lambda model.

  • forward – Apply the lambda function to the input tensor.

__init__(func: Callable[[Tensor], Tensor], name: str | None = None)

Initialize the Lambda model.

Parameters:
  • func (Callable[[torch.Tensor], torch.Tensor]) – Function to apply to input tensors. It should take a torch.Tensor as input and return a torch.Tensor; if extra positional or keyword arguments are passed to forward, it must accept those as well.

  • name (Optional[str], optional) – Name for the model. If None, the function’s name is used. Defaults to None.
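
The documented default for name can be sketched as follows, assuming the function's name is read from its `__name__` attribute. `resolve_model_name` is a hypothetical helper, and the fallback for callables without `__name__` is this sketch's own choice, not documented kaira behavior. Note that a lambda's `__name__` is `"<lambda>"`, so passing name explicitly gives more readable model names.

```python
from typing import Any, Callable, Optional


def resolve_model_name(func: Callable[..., Any], name: Optional[str] = None) -> str:
    """Hypothetical helper mirroring the documented default for ``name``."""
    if name is not None:
        return name  # an explicit name always wins
    # Assumption: the default comes from func.__name__. Callables without
    # one fall back to their class name (this sketch's choice only).
    return getattr(func, "__name__", type(func).__name__)


def square(x):
    return x * x


print(resolve_model_name(square))             # square
print(resolve_model_name(lambda x: x))        # <lambda>
print(resolve_model_name(square, name="sq"))  # sq
```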

forward(x: Tensor, *args: Any, **kwargs: Any) → Tensor

Apply the lambda function to the input tensor.

Parameters:
  • x (torch.Tensor) – Input tensor.

  • *args – Additional positional arguments passed to the function.

  • **kwargs – Additional keyword arguments passed to the function.

Returns:

Result of applying the function to the input, i.e. func(x, *args, **kwargs).

Return type:

torch.Tensor
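
Because forward passes *args and **kwargs through to the wrapped function, func may take parameters beyond the input tensor. The sketch below uses a hypothetical helper, `apply_like_forward`, and plain floats instead of tensors to show that contract without depending on kaira or PyTorch.

```python
from typing import Any, Callable


def apply_like_forward(func: Callable[..., Any], x: Any, *args: Any, **kwargs: Any) -> Any:
    """Hypothetical helper following forward()'s documented contract:
    the input and any extra arguments are passed straight to func."""
    return func(x, *args, **kwargs)


def scale_and_shift(x: float, scale: float, shift: float = 0.0) -> float:
    # Example transformation that needs arguments beyond the input.
    return scale * x + shift


print(apply_like_forward(scale_and_shift, 2.0, 3.0))             # 6.0
print(apply_like_forward(scale_and_shift, 2.0, 3.0, shift=1.0))  # 7.0
```

Assuming BaseModel dispatches calls to forward as torch.nn.Module does, a call such as model(x, 3.0, shift=1.0) only works when func accepts those extra parameters; a one-argument lambda would raise a TypeError.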