Source code for kaira.models.generic.lambda_model

"""Defines a model based on a user-provided lambda function."""

from typing import Any, Callable, Optional

import torch

from kaira.models.base import BaseModel
from kaira.models.registry import ModelRegistry


@ModelRegistry.register_model()
class LambdaModel(BaseModel):
    """Lambda Model.

    This model applies a user-provided function to the input tensor. It's useful for
    quickly implementing custom transformations without creating a new model class.

    The wrapped callable does not have to be a plain function: objects without a
    ``__name__`` attribute (for example ``functools.partial`` objects or instances
    of classes defining ``__call__``) are supported as well.

    Example:
        >>> # Apply a simple scaling function
        >>> model = LambdaModel(lambda x: 2.0 * x)
        >>> x = torch.ones(5, 10)
        >>> output = model(x)
        >>> assert torch.all(output == 2.0)
    """

    def __init__(self, func: Callable[[torch.Tensor], torch.Tensor], name: Optional[str] = None):
        """Initialize the Lambda model.

        Args:
            func (Callable[[torch.Tensor], torch.Tensor]): Function to apply to input
                tensors. Should take a torch.Tensor as input and return a torch.Tensor.
            name (Optional[str], optional): Name for the model. If None (or empty),
                uses the function's ``__name__``; when the callable has no
                ``__name__``, the name of its type is used instead. Defaults to None.
        """
        super().__init__()
        self.func = func
        if name:
            self.name = name
        else:
            # Not every callable carries __name__ (e.g. functools.partial objects),
            # so fall back to the type name instead of raising AttributeError.
            self.name = getattr(func, "__name__", type(func).__name__)

    def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Apply the lambda function to the input tensor.

        Args:
            x (torch.Tensor): Input tensor
            *args: Additional positional arguments passed to the function
            **kwargs: Additional keyword arguments passed to the function

        Returns:
            torch.Tensor: Result of applying the lambda function to the input
        """
        return self.func(x, *args, **kwargs)

    def __repr__(self) -> str:
        """Get string representation of the model.

        Lambdas are rendered as ``<lambda>`` rather than their default ``str()``
        form; every other callable is rendered with ``str()``.

        Returns:
            str: Description including model name and function
        """
        # Look attributes up defensively: callables such as functools.partial
        # objects have neither __name__ nor __code__, and repr() must not raise.
        func_name = getattr(self.func, "__name__", None)
        code_name = getattr(getattr(self.func, "__code__", None), "co_name", None)
        if func_name == "<lambda>" or code_name == "<lambda>":
            func_repr = "<lambda>"
        else:
            func_repr = str(self.func)
        return f"{self.__class__.__name__}(name={self.name}, func={func_repr})"