Source code for kaira.models.generic.identity

"""Defines an identity model that passes the input through unchanged."""

from typing import Any

import torch

from kaira.models.base import BaseModel
from kaira.models.registry import ModelRegistry


@ModelRegistry.register_model()
class IdentityModel(BaseModel):
    """Pass-through model that hands back its input untouched.

    Handy as a no-op baseline, or as a stand-in wherever a model pipeline
    expects a model object but no transformation should be applied.

    Example:
        >>> model = IdentityModel()
        >>> x = torch.randn(5, 10)
        >>> output = model(x)
        >>> assert torch.allclose(x, output)
    """

    def __init__(self):
        """Build the model.

        This class defines no configuration or attributes of its own; all
        setup is delegated to ``BaseModel``.
        """
        super().__init__()

    def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Return ``x`` exactly as received.

        Args:
            x (torch.Tensor): Input tensor.
            *args: Extra positional arguments; accepted and ignored.
            **kwargs: Extra keyword arguments; accepted and ignored.

        Returns:
            torch.Tensor: The same object that was passed in as ``x``.
        """
        # Deliberately no clone/copy: the caller gets back the identical
        # tensor object, so in-place edits on the output also affect ``x``.
        return x

    def __repr__(self) -> str:
        """Render as ``"<ClassName>()"``, e.g. ``"IdentityModel()"``."""
        return f"{type(self).__name__}()"