kaira.models.generic.BranchingModel

- class kaira.models.generic.BranchingModel(condition: Callable[[Any], bool] | None = None, true_branch: BaseModel | None = None, false_branch: BaseModel | None = None)
Bases: BaseModel
Model that routes inputs through different paths based on conditions.
This model enables conditional processing by maintaining a collection of branches, where each branch consists of:
- A condition function that determines whether the branch should be taken
- A model that processes the input when the branch is taken
The model also supports a default branch that is taken when no other branch conditions are met.
Key features:
- Dynamic branch selection based on the input or on external state
- Multiple independent processing paths
- Optional default path for inputs that match no branch condition
- Branch conditions can be any callable returning a boolean
- Branch models can be any BaseModel instance
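Branch selection can depend on the input, on external state, or on both. The sketch below shows this with plain Python callables and does not use kaira. The names `state`, `snr_db`, and `matching` are hypothetical. The helper lists every condition that holds, whereas the model itself routes only to the first match.

```python
from typing import Any, Callable, Dict, List, Tuple

# Hypothetical shared state, e.g. a channel SNR estimate (in dB) updated elsewhere.
state: Dict[str, float] = {"snr_db": 5.0}

# Each condition is an ordinary callable returning bool. The first reads only
# external state; the second reads only the input.
conditions: List[Tuple[str, Callable[[Any], bool]]] = [
    ("low_snr", lambda x: state["snr_db"] < 10.0),
    ("long_input", lambda x: len(x) >= 64),
]


def matching(x: Any) -> List[str]:
    """Return the names of all conditions that hold for x, in registration order."""
    return [name for name, condition in conditions if condition(x)]


short_input = [0.0] * 8
print(matching(short_input))  # ['low_snr']
state["snr_db"] = 20.0        # same input, different state
print(matching(short_input))  # []
print(matching([0.0] * 100))  # ['long_input']
```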
Example
>>> model = BranchingModel()
>>> # Add branch for small inputs
>>> model.add_branch("small",
...                  condition=lambda x: x.shape[-1] < 64,
...                  model=small_processor)
>>> # Add branch for large inputs
>>> model.add_branch("large",
...                  condition=lambda x: x.shape[-1] >= 64,
...                  model=large_processor)
>>> # Process input; the matching branch is selected automatically
>>> output = model(input_tensor)
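In the example, `small_processor`, `large_processor`, and `input_tensor` are placeholders. The two conditions are complementary: every last-dimension size satisfies exactly one of them, so the model never reaches the no-match case and needs no default branch. The check below runs without kaira or PyTorch. It uses a hypothetical `fake_tensor` stand-in, an object that carries only the `.shape` attribute the conditions read.

```python
from types import SimpleNamespace


def fake_tensor(width: int) -> SimpleNamespace:
    """Hypothetical stand-in for a tensor: only .shape is used by the conditions."""
    return SimpleNamespace(shape=(1, width))


# The two conditions from the example above.
def is_small(x: SimpleNamespace) -> bool:
    return x.shape[-1] < 64


def is_large(x: SimpleNamespace) -> bool:
    return x.shape[-1] >= 64


# Exactly one condition holds for every width tried, so no input goes unmatched.
assert all(is_small(fake_tensor(w)) != is_large(fake_tensor(w)) for w in range(1, 257))

# The boundary value 64 belongs to the "large" branch.
print(is_small(fake_tensor(63)), is_large(fake_tensor(64)))  # True True
```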
Methods
- __init__ – Initialize a branching model.
- add_branch – Add a new conditional branch.
- forward – Process input through the appropriate branch.
- Get a branch's condition and model.
- remove_branch – Remove a branch by name.
- set_default_branch – Set the default branch model.
- __init__(condition: Callable[[Any], bool] | None = None, true_branch: BaseModel | None = None, false_branch: BaseModel | None = None)
Initialize a branching model.
- Parameters:
condition – Optional condition function for simple true/false branching
true_branch – Model to use when condition is True
false_branch – Model to use when condition is False
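The three constructor arguments describe a plain if/else. The sketch below uses ordinary functions and assumes what the parameter descriptions state: a True result from `condition(x)` selects `true_branch`, and False selects `false_branch`. `make_if_else` is a hypothetical helper. It says nothing about how `BranchingModel` stores these arguments internally, for example whether `false_branch` is kept as the default branch.

```python
from typing import Any, Callable


def make_if_else(
    condition: Callable[[Any], bool],
    true_branch: Callable[[Any], Any],
    false_branch: Callable[[Any], Any],
) -> Callable[[Any], Any]:
    """Return a function applying true_branch when condition(x) holds, else false_branch."""

    def run(x: Any) -> Any:
        return true_branch(x) if condition(x) else false_branch(x)

    return run


# Hypothetical branches: negate negative inputs, double everything else.
route = make_if_else(lambda x: x < 0, true_branch=lambda x: -x, false_branch=lambda x: 2 * x)
print(route(-3), route(4))  # 3 8
```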
- add_branch(name: str, condition: Callable[[Any], bool], model: BaseModel) → None
Add a new conditional branch.
- Parameters:
name – Unique identifier for the branch
condition – Function that determines whether the branch should be taken. It should accept the same input as the model and return a bool.
model – Model to use when the branch is taken
- Raises:
ValueError – If a branch with this name already exists
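The uniqueness rule, and the registration order that `forward` relies on, can be sketched with an ordered mapping from branch name to `(condition, model)`. This is an assumption about the bookkeeping, not kaira's implementation. `BranchRegistry` is a hypothetical class, and plain callables stand in for `BaseModel` instances.

```python
from typing import Any, Callable, Dict, Tuple

Condition = Callable[[Any], bool]
Model = Callable[[Any], Any]


class BranchRegistry:
    """Hypothetical stand-in for the branch bookkeeping behind add_branch."""

    def __init__(self) -> None:
        # Python dicts keep insertion order, which gives "registration order".
        self.branches: Dict[str, Tuple[Condition, Model]] = {}

    def add_branch(self, name: str, condition: Condition, model: Model) -> None:
        if name in self.branches:
            raise ValueError(f"Branch '{name}' already exists")
        self.branches[name] = (condition, model)


registry = BranchRegistry()
registry.add_branch("small", lambda x: len(x) < 64, lambda x: "small")
registry.add_branch("large", lambda x: len(x) >= 64, lambda x: "large")
try:
    registry.add_branch("small", lambda x: True, lambda x: "duplicate")
except ValueError as err:
    print(err)  # Branch 'small' already exists
print(list(registry.branches))  # ['small', 'large']
```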
- set_default_branch(model: BaseModel) → None
Set the default branch model.
The default branch is used when no other branch conditions are met.
- Parameters:
model – Model to use as the default branch
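Going by the routing rule documented for `forward`, the default branch is consulted only after every registered condition has failed. In terms of which model runs, it therefore behaves like a last-registered branch whose condition always returns True, although the branch name reported with `return_branch=True` may differ. The sketch below checks that equivalence on two inputs. It uses a hypothetical `route` helper on plain callables, not kaira itself.

```python
from typing import Any, Callable, List, Optional, Tuple

Branch = Tuple[Callable[[Any], bool], Callable[[Any], Any]]


def route(branches: List[Branch], default: Optional[Callable[[Any], Any]], x: Any) -> Any:
    """First matching condition wins; the default is consulted only if none match."""
    for condition, model in branches:
        if condition(x):
            return model(x)
    if default is None:
        raise RuntimeError("no matching branch and no default branch")
    return default(x)


def big(x: Any) -> str:
    return "big"


def other(x: Any) -> str:
    return "other"


branches: List[Branch] = [(lambda x: x > 100, big)]
inputs = (5, 500)

with_default = [route(branches, other, x) for x in inputs]
# Same routing with no default, but an always-true branch registered last.
with_catch_all = [route(branches + [(lambda x: True, other)], None, x) for x in inputs]
print(with_default, with_catch_all)  # ['other', 'big'] ['other', 'big']
```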
- remove_branch(name: str) → None
Remove a branch by name.
- Parameters:
name – Name of the branch to remove
- Raises:
KeyError – If no branch with this name exists
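Removal is the inverse of registration. The sketch below assumes branches are stored in a dictionary keyed by name. That storage is an assumption, and `remove_branch` here is a hypothetical free function rather than the method.

```python
from typing import Any, Callable, Dict, Tuple

branches: Dict[str, Tuple[Callable[[Any], bool], Callable[[Any], Any]]] = {
    "small": (lambda x: x < 64, lambda x: "small"),
    "large": (lambda x: x >= 64, lambda x: "large"),
}


def remove_branch(name: str) -> None:
    """Hypothetical sketch: delete a branch, raising KeyError if it is unknown."""
    if name not in branches:
        raise KeyError(f"Branch '{name}' does not exist")
    del branches[name]


remove_branch("large")
print(list(branches))  # ['small']
try:
    remove_branch("large")  # already removed, so this raises KeyError
except KeyError as err:
    print(err)
```

Removing a branch can leave inputs that no remaining condition matches. In the sketch, an input of 100 now matches nothing. Per the `forward` documentation, such an input raises `RuntimeError` unless a default branch is set.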
- forward(x: Any, return_branch: bool = False, *args: Any, **kwargs: Any) → Any
Process input through the appropriate branch.
Evaluates branch conditions in registration order and processes the input through the first matching branch. If no branch matches and a default branch is set, processes the input through the default branch.
- Parameters:
x – Input to process
return_branch – If True, return a tuple of (output, branch_name)
*args – Additional positional arguments passed to branch models
**kwargs – Additional keyword arguments passed to branch models
- Returns:
If return_branch=False: Output from the selected branch
If return_branch=True: Tuple of (output, branch_name)
- Return type:
Any
- Raises:
RuntimeError – If no branch condition matches and no default branch is set
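The documented control flow can be sketched without kaira. Conditions are checked in registration order and the first match wins; if nothing matches, the default branch runs; if there is no default, a `RuntimeError` is raised. The sketch makes three assumptions. Branches are stored in an ordered dict of `name -> (condition, model)`. `*args`/`**kwargs` are forwarded to the selected model after `x`. The `"default"` label returned for the fallback is a hypothetical placeholder, because this reference does not give the name kaira reports for the default branch.

```python
from typing import Any, Callable, Dict, Optional, Tuple

Condition = Callable[[Any], bool]
Model = Callable[..., Any]


def forward(
    branches: Dict[str, Tuple[Condition, Model]],
    default: Optional[Model],
    x: Any,
    return_branch: bool = False,
    *args: Any,
    **kwargs: Any,
) -> Any:
    # Conditions are checked in insertion (registration) order; the first match wins.
    for name, (condition, model) in branches.items():
        if condition(x):
            output = model(x, *args, **kwargs)
            return (output, name) if return_branch else output
    if default is not None:
        output = default(x, *args, **kwargs)
        # Hypothetical label for the default branch.
        return (output, "default") if return_branch else output
    raise RuntimeError("No matching branch found and no default branch set")


def scale(x: float, factor: float = 1) -> float:
    return x * factor


def negate_scale(x: float, factor: float = 1) -> float:
    return -x * factor


branches = {
    "negative": (lambda x: x < 0, negate_scale),
    "positive": (lambda x: x > 0, scale),
}
print(forward(branches, None, -2))                        # 2
print(forward(branches, None, 3, True, factor=10))        # (30, 'positive')
print(forward(branches, lambda x, factor=1: 0, 0, True))  # (0, 'default')
```

One point applies to the documented signature as well as to this sketch. Because `return_branch` comes before `*args`, a second positional argument binds to `return_branch`. Extra arguments meant for the branch models are therefore safest passed by keyword.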