Source code for kaira.models.generic.parallel

"""Defines a model that applies multiple modules in parallel to the input."""

from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Callable, List, Optional, Tuple

from ..base import ConfigurableModel
from ..registry import ModelRegistry


[docs] @ModelRegistry.register_model() class ParallelModel(ConfigurableModel): """A model that processes steps in parallel. All steps receive the same input data and process independently. """ step_configs: List[Tuple[str, Callable]] # Renamed from 'steps'
[docs] def __init__(self, max_workers: Optional[int] = None, steps: Optional[List[Tuple[str, Callable]]] = None, branches: Optional[List[Callable]] = None, aggregator: Optional[Callable] = None): """Initialize the parallel model. Args: max_workers: Maximum number of worker threads (None uses default ThreadPoolExecutor behavior) steps: Optional initial list of named processing steps as (name, step) tuples branches: Alternative way to specify processing steps as a list of callables aggregator: Optional function to aggregate results (if None, returns dict of outputs) """ super().__init__() self.max_workers = max_workers self.aggregator = aggregator self._step_counter = 0 # Counter for auto-naming steps # Initialize step_configs list if steps: self.step_configs = steps # Use new attribute name else: self.step_configs = [] # Use new attribute name # Add branches if provided if branches: for i, branch in enumerate(branches): self.add_step(branch, f"branch_{i}")
[docs] def add_step(self, step: Callable, name: Optional[str] = None): """Add a processing step to the model with an optional name. Args: step: A callable function or object that processes input data name: Optional name for the step (auto-generated if None) Returns: The model instance for method chaining Raises: TypeError: If step is not callable """ if not callable(step): raise TypeError("Step must be callable") if name is None: name = f"step_{self._step_counter}" self._step_counter += 1 self.step_configs.append((name, step)) # Use new attribute name return self
[docs] def remove_step(self, index: int): """Remove a processing step from the model. Args: index: The index of the step to remove Returns: The model instance for method chaining Raises: IndexError: If the index is out of range """ if not 0 <= index < len(self.step_configs): # Use new attribute name raise IndexError(f"Step index {index} out of range (0-{len(self.step_configs)-1})") # Use new attribute name self.step_configs.pop(index) # Use new attribute name return self
[docs] def forward(self, input_data: Any, *args: Any, **kwargs: Any) -> Any: """Execute the model in parallel on the input data. Args: input_data: The data to process *args: Additional positional arguments passed to each step. **kwargs: Additional keyword arguments passed to each step. Returns: Dictionary mapping step names to their respective outputs or aggregated results if an aggregator is provided """ if not self.step_configs: # Use new attribute name return {} results = {} with ThreadPoolExecutor(max_workers=self.max_workers) as executor: # Pass *args and **kwargs to each step submitted to the executor future_to_step = {executor.submit(step_func, input_data, *args, **kwargs): name for name, step_func in self.step_configs} # Use new attribute name and step_func for future in as_completed(future_to_step): step_name = future_to_step[future] try: results[step_name] = future.result() except Exception as exc: results[step_name] = f"Error: {exc}" # Apply aggregator if provided if self.aggregator: # Convert dictionary of results to a list of values for the aggregator return self.aggregator(list(results.values())) return results