Source code for kaira.constraints.registry

"""Constraint registry for Kaira."""

from typing import Callable, ClassVar, Dict, Optional, Type

from .base import BaseConstraint


class ConstraintRegistry:
    """A registry for constraints in Kaira.

    This class provides a centralized registry for all constraints, making it
    easier to instantiate them by name with appropriate parameters.

    All state lives in the class-level ``_constraints`` mapping and every
    method is a classmethod, so there is a single registry shared by all
    callers (and by any subclass that does not redefine ``_constraints``).
    """

    # Registration name -> constraint class. Declared as ClassVar because it is
    # intentionally shared class state, not a per-instance attribute.
    _constraints: ClassVar[Dict[str, Type[BaseConstraint]]] = {}

    @classmethod
    def register(cls, name: str, constraint_class: Type[BaseConstraint]) -> None:
        """Register a new constraint in the registry.

        If ``name`` is already registered, the previous entry is replaced
        without warning.

        Args:
            name (str): The name to register the constraint under.
            constraint_class (Type[BaseConstraint]): The constraint class to register.
        """
        cls._constraints[name] = constraint_class

    @classmethod
    def register_constraint(
        cls, name: Optional[str] = None
    ) -> Callable[[Type[BaseConstraint]], Type[BaseConstraint]]:
        """Decorator to register a constraint class in the registry.

        Args:
            name (Optional[str], optional): The name to register the constraint under.
                If None, the class name will be used (converted to lowercase).

        Returns:
            Callable: A decorator that registers the decorated constraint class and
            returns that same class unchanged, so it can be used with ``@``.
        """

        def decorator(constraint_class: Type[BaseConstraint]) -> Type[BaseConstraint]:
            # Only ``None`` triggers the fallback; an explicit name (even "") is kept.
            constraint_name = name if name is not None else constraint_class.__name__.lower()
            cls.register(constraint_name, constraint_class)
            return constraint_class

        return decorator

    @classmethod
    def get(cls, name: str) -> Type[BaseConstraint]:
        """Get a constraint class by name.

        Args:
            name (str): The name of the constraint to get.

        Returns:
            Type[BaseConstraint]: The constraint class.

        Raises:
            KeyError: If the constraint is not registered. The message lists the
                currently registered names.
        """
        try:
            return cls._constraints[name]
        except KeyError:
            # ``from None`` hides the bare internal KeyError so the traceback shows
            # only the descriptive error below.
            raise KeyError(
                f"Constraint '{name}' not found in registry. "
                f"Available constraints: {list(cls._constraints.keys())}"
            ) from None

    @classmethod
    def create(cls, name: str, **kwargs) -> BaseConstraint:
        """Create a constraint instance by name.

        Args:
            name (str): The name of the constraint to create.
            **kwargs: Additional arguments to pass to the constraint constructor.

        Returns:
            BaseConstraint: The instantiated constraint.

        Raises:
            KeyError: If the constraint is not registered (raised by ``get``).
        """
        constraint_class = cls.get(name)
        return constraint_class(**kwargs)

    @classmethod
    def list_constraints(cls) -> list:
        """List all available constraints in the registry.

        Returns:
            list: A new list of the registered constraint names, in registration
            order. Mutating it does not affect the registry.
        """
        return list(cls._constraints.keys())