# Source code for kaira.losses.registry

"""Loss registry for Kaira."""

from typing import Callable, Dict, Optional, Type

from .base import BaseLoss


class LossRegistry:
    """Central name-to-class registry for Kaira loss functions.

    Loss classes are stored in a single class-level mapping so that they can
    be looked up and instantiated by name, with constructor arguments
    supplied at creation time.
    """

    # Maps each registered name to its loss class; lives on the class itself,
    # so every access through ``cls`` sees the same mapping.
    _losses: Dict[str, Type[BaseLoss]] = {}

    @classmethod
    def register(cls, name: str, loss_class: Type[BaseLoss]) -> None:
        """Store a loss class in the registry under the given name.

        An entry already stored under ``name`` is replaced.

        Args:
            name (str): Key under which the loss class is stored.
            loss_class (Type[BaseLoss]): The loss class to store.
        """
        cls._losses[name] = loss_class

    @classmethod
    def register_loss(cls, name: Optional[str] = None) -> Callable:
        """Build a class decorator that adds the decorated loss to the registry.

        Args:
            name (Optional[str], optional): Key to register the class under.
                When None, the lowercased class name is used instead.

        Returns:
            callable: A decorator that registers the class it receives and
            returns that same class unchanged.
        """

        def decorator(loss_class):
            if name is not None:
                registered_as = name
            else:
                registered_as = loss_class.__name__.lower()
            cls.register(registered_as, loss_class)
            return loss_class

        return decorator

    @classmethod
    def get(cls, name: str) -> Type[BaseLoss]:
        """Look up a registered loss class.

        Args:
            name (str): Key of the loss class to retrieve.

        Returns:
            Type[BaseLoss]: The class stored under ``name``.

        Raises:
            KeyError: If nothing is registered under ``name``; the message
                lists the names that are currently registered.
        """
        if name in cls._losses:
            return cls._losses[name]
        raise KeyError(f"Loss '{name}' not found in registry. Available losses: {list(cls._losses.keys())}")

    @classmethod
    def create(cls, name: str, **kwargs) -> BaseLoss:
        """Instantiate a registered loss.

        Args:
            name (str): Key of the loss class to instantiate.
            **kwargs: Keyword arguments forwarded to the loss constructor.

        Returns:
            BaseLoss: A new instance of the class stored under ``name``.
        """
        return cls.get(name)(**kwargs)

    @classmethod
    def list_losses(cls) -> list:
        """Return the names of every registered loss.

        Returns:
            list: The registered names, in the registry's iteration order.
        """
        return list(cls._losses)