kaira.models.generic.BranchingModel

Inheritance diagram for BranchingModel

class kaira.models.generic.BranchingModel(condition: Callable[[Any], bool] | None = None, true_branch: BaseModel | None = None, false_branch: BaseModel | None = None)

Bases: BaseModel

Model that routes inputs through different paths based on conditions.

This model enables conditional processing by maintaining a collection of branches, where each branch consists of:

  • A condition function that determines if the branch should be taken

  • A model that processes the input when the branch is taken

The model also supports a default branch that is taken when no other branch conditions are met.

Key features:

  • Dynamic branch selection based on input or state

  • Multiple independent processing paths

  • Optional default path for unmatched conditions

  • Branch conditions can be any callable returning a boolean

  • Branch models can be any BaseModel instance
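Because a condition is just a callable, it can consult external state as well as the input. The sketch below illustrates that idea with an ordered dictionary of name → (condition, model) pairs and first-match routing with an optional default. It is a minimal pure-Python sketch, not kaira's implementation: route, state, and the branch names are hypothetical, and plain functions stand in for BaseModel instances.

```python
from typing import Any, Callable, Dict, Optional, Tuple

# Hypothetical branch registry: name -> (condition, model).
# Plain functions stand in for BaseModel instances in this sketch.
Branches = Dict[str, Tuple[Callable[[Any], bool], Callable[[Any], Any]]]


def route(x: Any, branches: Branches, default: Optional[Callable[[Any], Any]] = None) -> Any:
    """Pass x to the first branch whose condition holds, else to the default."""
    for condition, model in branches.values():  # dicts preserve insertion order
        if condition(x):
            return model(x)
    if default is None:
        raise RuntimeError("no matching branch and no default branch")
    return default(x)


# Conditions may read external state, not just the input itself.
state = {"low_power": False}

branches: Branches = {
    # Keep only the first two items while in low-power mode.
    "cheap": (lambda x: state["low_power"], lambda x: x[:2]),
    # Otherwise sort inputs longer than three items.
    "long": (lambda x: len(x) > 3, sorted),
}

print(route([3, 1, 2, 5], branches, default=list))  # [1, 2, 3, 5]
state["low_power"] = True
print(route([3, 1, 2, 5], branches, default=list))  # [3, 1]
```

Flipping the flag changes which branch handles the same input, without re-registering anything.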

Example

>>> from kaira.models.generic import BranchingModel
>>> # small_processor, large_processor (BaseModel instances) and input_tensor
>>> # are assumed to be defined elsewhere
>>> model = BranchingModel()
>>> # Add branch for small inputs
>>> model.add_branch("small",
...                  condition=lambda x: x.shape[-1] < 64,
...                  model=small_processor)
>>> # Add branch for large inputs
>>> model.add_branch("large",
...                  condition=lambda x: x.shape[-1] >= 64,
...                  model=large_processor)
>>> # Process input: the first branch whose condition holds is selected
>>> output = model(input_tensor)

Methods

  • __init__ – Initialize a branching model.

  • add_branch – Add a new conditional branch.

  • forward – Process input through the appropriate branch.

  • get_branch – Get a branch's condition and model.

  • remove_branch – Remove a branch by name.

  • set_default_branch – Set the default branch model.

__init__(condition: Callable[[Any], bool] | None = None, true_branch: BaseModel | None = None, false_branch: BaseModel | None = None)

Initialize a branching model.

Parameters:
  • condition – Optional condition function for simple true/false branching

  • true_branch – Model to use when condition is True

  • false_branch – Model to use when condition is False
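Judging from the parameter descriptions, the three constructor arguments describe a two-way dispatch. The sketch below spells out that behavior with plain functions in place of BaseModel instances. binary_branch and the helper functions are hypothetical and not part of kaira, and how these arguments interact with branches registered later through add_branch is not covered by this page.

```python
from typing import Any, Callable


def binary_branch(
    x: Any,
    condition: Callable[[Any], bool],
    true_branch: Callable[[Any], Any],
    false_branch: Callable[[Any], Any],
) -> Any:
    """Apply true_branch when condition(x) holds, false_branch otherwise."""
    return true_branch(x) if condition(x) else false_branch(x)


def is_negative(x: float) -> bool:
    return x < 0


def negate(x: float) -> float:
    return -x


def keep(x: float) -> float:
    return x


# Together these compute an absolute value.
print(binary_branch(-3.0, is_negative, negate, keep))  # 3.0
print(binary_branch(2.0, is_negative, negate, keep))   # 2.0
```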

add_branch(name: str, condition: Callable[[Any], bool], model: BaseModel) → None

Add a new conditional branch.

Parameters:
  • name – Unique identifier for the branch

  • condition – Function that determines whether the branch should be taken. It should accept the same input as the model and return a bool.

  • model – Model to use when branch is taken

Raises:

ValueError – If a branch with this name already exists
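The registration contract (unique names, branches kept in the order they are added) can be sketched with an ordinary dictionary. BranchRegistry is a hypothetical illustration of the documented behavior, not kaira's class; plain functions stand in for BaseModel instances.

```python
from typing import Any, Callable, Dict, Tuple

Branch = Tuple[Callable[[Any], bool], Callable[[Any], Any]]


class BranchRegistry:
    """Hypothetical ordered registry mirroring add_branch's documented contract."""

    def __init__(self) -> None:
        # Dict insertion order doubles as registration order.
        self.branches: Dict[str, Branch] = {}

    def add_branch(self, name: str, condition: Callable[[Any], bool],
                   model: Callable[[Any], Any]) -> None:
        if name in self.branches:
            raise ValueError(f"Branch '{name}' already exists")
        self.branches[name] = (condition, model)


registry = BranchRegistry()
registry.add_branch("small", lambda x: x < 64, lambda x: x * 2)
registry.add_branch("large", lambda x: x >= 64, lambda x: x // 2)
try:
    registry.add_branch("small", lambda x: True, lambda x: x)
except ValueError as err:
    print(err)  # Branch 'small' already exists
```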

set_default_branch(model: BaseModel) → None

Set the default branch model.

The default branch is used when no other branch conditions are met.

Parameters:

model – Model to use as default branch
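The default branch acts purely as a fallback: it is consulted only after every registered condition has failed. DefaultingRouter below is a hypothetical sketch of that behavior, not kaira's implementation. In this sketch a second call to set_default_branch simply replaces the earlier default; whether kaira does the same is an assumption.

```python
from typing import Any, Callable, List, Optional, Tuple


class DefaultingRouter:
    """Hypothetical router showing how a default branch acts as a fallback."""

    def __init__(self) -> None:
        self.branches: List[Tuple[Callable[[Any], bool], Callable[[Any], Any]]] = []
        self.default: Optional[Callable[[Any], Any]] = None

    def set_default_branch(self, model: Callable[[Any], Any]) -> None:
        # In this sketch, a later call overwrites any earlier default.
        self.default = model

    def __call__(self, x: Any) -> Any:
        for condition, model in self.branches:
            if condition(x):
                return model(x)
        if self.default is None:
            raise RuntimeError("no matching branch and no default branch")
        return self.default(x)


router = DefaultingRouter()
router.branches.append((lambda x: x > 0, lambda x: "positive"))
router.set_default_branch(lambda x: "fallback")
print(router(5))   # positive
print(router(-1))  # fallback
```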

remove_branch(name: str) → None

Remove a branch by name.

Parameters:

name – Name of branch to remove

Raises:

KeyError – If no branch with this name exists
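Removal is the inverse of registration: the named entry disappears from the ordered collection, and asking for an unknown name is an error. The module-level branches dictionary and remove_branch function below are hypothetical stand-ins used only to illustrate that contract.

```python
from typing import Any, Callable, Dict, Tuple

# Hypothetical registry: name -> (condition, model).
branches: Dict[str, Tuple[Callable[[Any], bool], Callable[[Any], Any]]] = {
    "small": (lambda x: x < 64, lambda x: x),
    "large": (lambda x: x >= 64, lambda x: x),
}


def remove_branch(name: str) -> None:
    """Drop a branch, raising KeyError if it is not registered."""
    if name not in branches:
        raise KeyError(f"Branch '{name}' does not exist")
    del branches[name]


remove_branch("small")
print(list(branches))  # ['large']
```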

forward(x: Any, return_branch: bool = False, *args: Any, **kwargs: Any) → Any

Process input through the appropriate branch.

Evaluates branch conditions in registration order and passes the input to the model of the first branch whose condition returns True. If no condition matches and a default branch has been set, the input is processed by the default branch.

Parameters:
  • x – Input to process

  • return_branch – If True, return a tuple of (output, branch_name)

  • *args – Additional positional arguments passed to branch models

  • **kwargs – Additional keyword arguments passed to branch models

Returns:

  • If return_branch=False: output from the selected branch

  • If return_branch=True: tuple of (output, branch_name)

Return type:

Any

Raises:

RuntimeError – If no branch condition matches and no default branch is set
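The dispatch rules above (first match wins in registration order, optional (output, branch_name) result, RuntimeError when nothing applies) can be sketched as a plain function. dispatch is a hypothetical illustration, not kaira's forward: it omits the *args/**kwargs pass-through, and the "default" label returned for the default branch is an assumption, since this page does not say what branch_name is in that case.

```python
from typing import Any, Callable, Dict, Optional, Tuple

Branches = Dict[str, Tuple[Callable[[Any], bool], Callable[[Any], Any]]]


def dispatch(
    x: Any,
    branches: Branches,
    default: Optional[Callable[[Any], Any]] = None,
    return_branch: bool = False,
) -> Any:
    """First-match dispatch over branches in registration (insertion) order."""
    for name, (condition, model) in branches.items():
        if condition(x):
            out = model(x)
            return (out, name) if return_branch else out
    if default is None:
        raise RuntimeError("No matching branch and no default branch")
    out = default(x)
    # Assumption of this sketch: the default branch is reported as "default".
    return (out, "default") if return_branch else out


# Overlapping conditions: "even" is registered first, so it wins for 4.
branches: Branches = {
    "even": (lambda n: n % 2 == 0, lambda n: n // 2),
    "big": (lambda n: n > 2, lambda n: n - 1),
}

print(dispatch(4, branches))                        # 2
print(dispatch(5, branches, return_branch=True))    # (4, 'big')
print(dispatch(1, branches, default=lambda n: n * 10, return_branch=True))  # (10, 'default')
```

Because the documented signature places return_branch before *args, the first extra positional argument passed to forward would bind to return_branch; passing it by keyword (return_branch=True) avoids that.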

get_branch(name: str) → Tuple[Callable[[Any], bool], BaseModel]

Get a branch’s condition and model.

Parameters:

name – Name of branch to retrieve

Returns:

Tuple of (condition_function, model)

Raises:

KeyError – If no branch with this name exists
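Since the result is a (condition_function, model) tuple, it can be unpacked to test a condition on its own without running the branch model. The branches dictionary and get_branch function below are hypothetical stand-ins for illustration, with a string method playing the role of the model.

```python
from typing import Any, Callable, Dict, Tuple

Branch = Tuple[Callable[[Any], bool], Callable[[Any], Any]]

# Hypothetical registry with a single branch.
branches: Dict[str, Branch] = {
    "small": (lambda x: len(x) < 64, lambda x: x.upper()),
}


def get_branch(name: str) -> Branch:
    """Return (condition_function, model), raising KeyError for unknown names."""
    if name not in branches:
        raise KeyError(f"Branch '{name}' does not exist")
    return branches[name]


# Unpack the tuple to inspect the condition separately from the model.
condition, model = get_branch("small")
print(condition("abc"), model("abc"))  # True ABC
```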