Source code for kaira.models.generic.sequential
"""Defines a sequential model container."""
from typing import Any, Callable, Optional, Sequence
from ..base import ConfigurableModel
from ..registry import ModelRegistry
[docs]
@ModelRegistry.register_model()
class SequentialModel(ConfigurableModel):
"""A model that processes steps sequentially.
Each step receives the output of the previous step as its input.
"""
[docs]
def __init__(self, steps: Optional[Sequence[Callable]] = None, *args: Any, **kwargs: Any):
"""Initialize the sequential model.
Args:
steps: Optional initial list of processing steps
*args: Variable positional arguments passed to the base class.
**kwargs: Variable keyword arguments passed to the base class.
"""
super().__init__(*args, **kwargs)
if steps:
# Ensure all initial steps are callable
for step in steps:
if not callable(step):
raise TypeError(f"All initial steps must be callable, got {type(step)}")
self.steps = list(steps)
[docs]
def forward(self, input_data: Any, *args: Any, **kwargs: Any) -> Any:
"""Execute the model sequentially on the input data.
Args:
input_data: The initial data to process
*args: Additional positional arguments passed to each step.
**kwargs: Additional keyword arguments passed to each step.
Returns:
The final result after passing through all steps
"""
result = input_data
for step in self.steps:
result = step(result, *args, **kwargs) # Pass *args and **kwargs to each step
return result