Source code for kaira.metrics.registry

"""Metrics registry and factory module.

This module provides three main functionalities:
1. Registration of metrics and discovery of available metrics
2. Creation of metric instances from registered classes with proper parameters
3. Convenient factory functions for creating common metric collections and combinations

The registry pattern enables dynamic discovery of metrics and simplifies the creation
of configurable evaluation pipelines that can select metrics at runtime.

Examples:
    Register a custom metric:

    >>> from kaira.metrics.registry import MetricRegistry
    >>> from kaira.metrics.base import BaseMetric
    >>>
    >>> @MetricRegistry.register_metric()  # Uses the lowercase class name as the registration key
    ... class MyCustomMetric(BaseMetric):
    ...     def __init__(self, param1=1.0):
    ...         super().__init__(name="my_custom_metric")
    ...         self.param1 = param1
    ...
    ...     def forward(self, x, y):
    ...         # Implement metric computation
    ...         return result

    Register with custom name:

    >>> @MetricRegistry.register_metric("awesome_metric")
    ... class ComplicatedMetricWithLongName(BaseMetric):
    ...     # implementation

    Create a metric from registry:

    >>> from kaira.metrics.registry import MetricRegistry
    >>>
    >>> # Create instance using registered name
    >>> metric = MetricRegistry.create("mycustommetric", param1=2.0)

    Use factory functions for common metric suites:

    >>> from kaira.metrics.registry import MetricRegistry
    >>>
    >>> # Create standard image metrics
    >>> metrics_dict = MetricRegistry.create_image_quality_metrics(data_range=2.0)
    >>>
    >>> # Create weighted combination
    >>> weights = {"psnr": 0.6, "ssim": 0.4}
    >>> combined = MetricRegistry.create_composite_metric(metrics_dict, weights)
"""

import inspect
from typing import Any, Callable, Dict, List, Literal, Optional, Type

import torch

from .base import BaseMetric
from .composite import CompositeMetric


class MetricRegistry:
    """A registry for metrics in Kaira.

    This class provides a centralized registry for all metrics, making it easier to
    instantiate them by name with appropriate parameters.

    All state lives in the class-level ``_metrics`` mapping and every method is a
    classmethod, so the registry is used directly on the class without instantiation.
    """

    # Maps registry key -> metric class. It is defined on the class, so it is shared
    # by every subclass that does not define its own ``_metrics`` attribute.
    _metrics: Dict[str, Type[BaseMetric]] = {}

    @classmethod
    def register(cls, name: str, metric_class: Type[BaseMetric]) -> None:
        """Register a new metric in the registry.

        Args:
            name (str): The name to register the metric under.
            metric_class (Type[BaseMetric]): The metric class to register.

        Raises:
            ValueError: If a metric is already registered under ``name``.
        """
        if name in cls._metrics:
            raise ValueError(f"Metric with name '{name}' already registered")
        cls._metrics[name] = metric_class

    @classmethod
    def register_metric(cls, name: Optional[str] = None) -> Callable[[Type[BaseMetric]], Type[BaseMetric]]:
        """Decorator to register a metric class in the registry.

        This makes the metric discoverable and instantiable through the registry
        system. Registered metrics are expected to inherit from BaseMetric; this is
        not enforced at registration time.

        Args:
            name (Optional[str]): Optional custom name for the metric. If not provided
                (or empty), the lowercase class name is used as the registration key.
                Custom names are helpful for shorter keys or when the class name is
                not descriptive enough.

        Returns:
            Callable: Decorator that registers the metric class and returns it unchanged.

        Raises:
            ValueError: Raised when the decorator is applied, if the resulting key is
                already registered.

        Example:
            >>> @MetricRegistry.register_metric()  # Uses lowercase class name as key
            ... class MyMetric(BaseMetric):
            ...     pass  # implementation
            ...
            >>> @MetricRegistry.register_metric("better_name")  # Uses custom name as key
            ... class GenericNameThatNeedsBetterRegistryKey(BaseMetric):
            ...     pass  # implementation
        """

        def decorator(metric_class: Type[BaseMetric]) -> Type[BaseMetric]:
            metric_name = name or metric_class.__name__.lower()
            # Dispatch through the registry class the decorator was obtained from
            # (not a hard-coded MetricRegistry) so subclass registries are honoured,
            # consistent with register(), create() and list_metrics().
            cls.register(metric_name, metric_class)
            return metric_class

        return decorator

    @classmethod
    def create(cls, name: str, *args: Any, **kwargs: Any) -> BaseMetric:
        """Create a metric instance from the registry with the specified parameters.

        This instantiates a registered metric class with the provided parameters,
        allowing for flexible creation of metrics at runtime based on configuration.

        Args:
            name (str): Name of the metric to create (case-sensitive registry key).
            *args: Positional arguments to pass to the metric constructor.
            **kwargs: Keyword arguments to pass to the metric constructor. These
                should match the parameters expected by the metric's __init__ method.

        Returns:
            BaseMetric: Instantiated metric object.

        Raises:
            KeyError: If the metric name is not found in the registry. The message
                lists the currently registered keys.
            TypeError: If the provided args/kwargs don't match the metric's expected
                parameters (raised by the metric constructor itself).

        Example:
            >>> # Create a PSNR metric with custom parameters
            >>> psnr = MetricRegistry.create("psnr", data_range=255.0)
            >>>
            >>> # Create a custom registered metric with positional arguments
            >>> my_metric = MetricRegistry.create("mycustommetric", 10, param2="value")
        """
        if name not in cls._metrics:
            raise KeyError(f"Metric '{name}' not found in registry. Available metrics: {list(cls._metrics.keys())}")
        return cls._metrics[name](*args, **kwargs)

    @classmethod
    def list_metrics(cls) -> List[str]:
        """List all registered metrics available for creation.

        Returns the names of all metrics that have been registered and can be
        instantiated using :meth:`create`.

        Returns:
            List[str]: Names (registry keys) of all registered metrics, as a new list.

        Example:
            >>> available_metrics = MetricRegistry.list_metrics()
            >>> print(f"Available metrics: {available_metrics}")
            >>>
            >>> # Check if a specific metric is available
            >>> if "lpips" in MetricRegistry.list_metrics():
            ...     metric = MetricRegistry.create("lpips")
        """
        return list(cls._metrics.keys())

    @classmethod
    def get_metric_info(cls, name: str) -> Dict[str, Any]:
        """Get detailed information about a registered metric.

        This provides introspection capabilities to examine a metric's constructor
        parameters, documentation, and other metadata without instantiating it.

        Args:
            name (str): Name of the metric to inspect.

        Returns:
            Dict[str, Any]: Dictionary containing:
                - name: Registry key of the metric
                - class: Original class name
                - module: Module where the class is defined
                - docstring: Documentation string of the class
                - parameters: Dictionary of ``__init__`` parameter names and default
                  values. Parameters without a default are reported as ``None``, so a
                  required parameter is indistinguishable from one defaulting to None.

        Raises:
            KeyError: If the metric name is not found in the registry.

        Example:
            >>> # Get information about the PSNR metric
            >>> psnr_info = MetricRegistry.get_metric_info("psnr")
            >>> print(f"PSNR parameters: {psnr_info['parameters']}")
            >>> print(f"Documentation: {psnr_info['docstring']}")
        """
        if name not in cls._metrics:
            raise KeyError(f"Metric '{name}' not found in registry")

        metric_class = cls._metrics[name]
        signature = inspect.signature(metric_class.__init__)
        # The first parameter of the unbound __init__ is 'self'; skip it.
        init_params = list(signature.parameters.items())[1:]
        params = {
            param_name: (None if param.default is inspect.Parameter.empty else param.default)
            for param_name, param in init_params
        }

        return {
            "name": name,
            "class": metric_class.__name__,
            "module": metric_class.__module__,
            "docstring": inspect.getdoc(metric_class),
            "parameters": params,
        }

    @classmethod
    def create_image_quality_metrics(
        cls,
        data_range: float = 1.0,
        lpips_net_type: Literal["vgg", "alex", "squeeze"] = "alex",
        device: Optional[torch.device] = None,
    ) -> Dict[str, BaseMetric]:
        """Create a standard suite of image quality assessment metrics.

        This factory creates a collection of commonly used image quality metrics with
        consistent parameters, making it easy to evaluate images across multiple
        metrics. The metric classes are constructed directly; the registry mapping is
        not consulted.

        The returned metrics include:
            - PSNR (Peak Signal-to-Noise Ratio): A pixel-level fidelity metric
            - SSIM (Structural Similarity Index): A perceptual metric focusing on structure
            - MS-SSIM (Multi-Scale SSIM): A multi-scale version of SSIM
            - LPIPS (Learned Perceptual Image Patch Similarity): A learned perceptual metric

        Args:
            data_range (float): The data range of the images, passed to PSNR, SSIM and
                MS-SSIM. Use 1.0 for normalized images in range [0,1] or 255.0 for
                uint8 images in range [0,255].
            lpips_net_type (Literal['vgg', 'alex', 'squeeze']): The backbone network
                for LPIPS. Options are:
                - 'alex': AlexNet (faster, less accurate)
                - 'vgg': VGG network (slower, more accurate)
                - 'squeeze': SqueezeNet (fastest, least accurate)
            device (Optional[torch.device]): Device to place the metrics on. If None,
                the metrics are left wherever their constructors placed them.

        Returns:
            Dict[str, BaseMetric]: Dictionary with keys ``"psnr"``, ``"ssim"``,
            ``"ms_ssim"`` and ``"lpips"`` mapping to the initialized metrics.

        Example:
            >>> import torch
            >>>
            >>> # Create metrics for normalized images [0,1]
            >>> metrics = MetricRegistry.create_image_quality_metrics(data_range=1.0, device=torch.device('cuda'))
            >>>
            >>> # Generate some test images
            >>> pred = torch.rand(1, 3, 256, 256).cuda()  # Batch of random RGB images
            >>> target = torch.rand(1, 3, 256, 256).cuda()
            >>>
            >>> # Compute metrics individually
            >>> psnr_value = metrics['psnr'](pred, target)
            >>> ssim_value = metrics['ssim'](pred, target)
            >>>
            >>> # Or create a composite metric
            >>> composite = MetricRegistry.create_composite_metric(metrics, weights={'psnr': 0.5, 'ssim': 0.5})
            >>> score = composite(pred, target)
        """
        # Imported lazily inside the function rather than at module level.
        from .image import LPIPS, PSNR, SSIM, MultiScaleSSIM

        metrics = {
            "psnr": PSNR(data_range=data_range),
            "ssim": SSIM(data_range=data_range),
            "ms_ssim": MultiScaleSSIM(data_range=data_range),
            "lpips": LPIPS(net_type=lpips_net_type),
        }

        if device is not None:
            for metric in metrics.values():
                # NOTE(review): relies on .to() moving the metric in place (as
                # torch.nn.Module.to does); the return value is discarded.
                metric.to(device)

        return metrics

    @classmethod
    def create_composite_metric(cls, metrics: Dict[str, BaseMetric], weights: Optional[Dict[str, float]] = None) -> BaseMetric:
        """Create a composite metric that combines multiple metrics with weights.

        This factory is a thin wrapper that returns ``CompositeMetric(metrics, weights)``.
        A composite is useful for:
            - Creating custom evaluation criteria that balance multiple aspects
            - Combining complementary metrics (e.g., pixel accuracy and perceptual quality)
            - Building task-specific evaluation metrics that focus on relevant properties

        Args:
            metrics (Dict[str, BaseMetric]): Dictionary mapping metric names to metric
                objects. All provided metrics should follow the BaseMetric interface.
            weights (Optional[Dict[str, float]]): Optional dictionary mapping metric
                names to their relative weights. ``None`` is forwarded unchanged to
                CompositeMetric, which decides the default weighting. Use negative
                weights for metrics where lower values are better (like LPIPS) when
                combining with metrics where higher values are better (like PSNR/SSIM).

        Returns:
            BaseMetric: A composite metric built from the provided metrics and weights.

        Example:
            >>> from kaira.metrics import PSNR, SSIM
            >>> from kaira.metrics.image import LPIPS
            >>> from kaira.metrics.registry import MetricRegistry
            >>>
            >>> # Create individual metrics
            >>> psnr = PSNR(data_range=1.0)
            >>> ssim = SSIM(data_range=1.0)
            >>> lpips = LPIPS(net_type='alex')  # Lower values are better
            >>>
            >>> # Create a balanced composite metric (higher values = better)
            >>> metrics = {'psnr': psnr, 'ssim': ssim, 'lpips': lpips}
            >>> weights = {'psnr': 0.4, 'ssim': 0.4, 'lpips': -0.2}  # Negative weight for LPIPS
            >>>
            >>> balanced_metric = MetricRegistry.create_composite_metric(metrics, weights)
        """
        return CompositeMetric(metrics, weights)

    @classmethod
    def clear(cls) -> None:
        """Clear all registered metrics from the registry.

        This is primarily useful for testing and reinitialization scenarios. The
        mapping is emptied in place, so every class sharing it sees the change.
        """
        cls._metrics.clear()

    @classmethod
    def available_metrics(cls) -> List[str]:
        """Get a list of all available metrics in the registry.

        This is an alias for list_metrics() for backward compatibility.

        Returns:
            List[str]: List of registered metric names.
        """
        return cls.list_metrics()