kaira.metrics.MetricRegistry


class kaira.metrics.MetricRegistry

Bases: object

A registry for metrics in Kaira.

This class provides a centralized registry for all metrics, making it easier to instantiate them by name with appropriate parameters.

Methods

__init__

available_metrics

Get a list of all available metrics in the registry.

clear

Clear all registered metrics from the registry.

create

Create a metric instance from the registry with the specified parameters.

create_composite_metric

Create a composite metric that combines multiple metrics with weights.

create_image_quality_metrics

Create a standard suite of image quality assessment metrics.

get_metric_info

Get detailed information about a registered metric.

list_metrics

List all registered metrics available for creation.

register

Register a new metric in the registry.

register_metric

Decorator to register a metric class in the global registry.

classmethod register(name: str, metric_class: Type[BaseMetric]) → None

Register a new metric in the registry.

Parameters:
  • name (str) – The name to register the metric under.

  • metric_class (Type[BaseMetric]) – The metric class to register.
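A class-level dictionary is the simplest structure that gives the behaviour described for register(). The sketch below is a minimal illustration under stated assumptions, not kaira's implementation: `SimpleRegistry`, `MeanAbsoluteError` and the stand-in `BaseMetric` are hypothetical names, and overwriting an existing key on re-registration is a choice made in this sketch, not documented behaviour.

```python
from typing import Dict, List, Type


class BaseMetric:
    """Hypothetical stand-in for kaira's BaseMetric (assumption)."""

    def __call__(self, x, y):
        raise NotImplementedError


class SimpleRegistry:
    """Minimal class-level registry sketch; not kaira's actual implementation."""

    _metrics: Dict[str, Type[BaseMetric]] = {}

    @classmethod
    def register(cls, name: str, metric_class: Type[BaseMetric]) -> None:
        # Store the class itself (not an instance) under the given key.
        # In this sketch, re-registering a name overwrites the old entry.
        cls._metrics[name] = metric_class

    @classmethod
    def list_metrics(cls) -> List[str]:
        # Return a fresh list so callers cannot mutate the registry's dict.
        return list(cls._metrics.keys())


class MeanAbsoluteError(BaseMetric):
    """Hypothetical example metric."""

    def __call__(self, x, y):
        return sum(abs(a - b) for a, b in zip(x, y)) / len(x)


SimpleRegistry.register("mae", MeanAbsoluteError)
print(SimpleRegistry.list_metrics())  # ['mae']
```

Storing classes rather than instances is what lets a later create() call pass different constructor arguments each time.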

classmethod register_metric(name: str | None = None) → Callable[[Type[BaseMetric]], Type[BaseMetric]]

Decorator to register a metric class in the global registry.

This makes the metric discoverable and instantiable through the registry system. Each registered metric must inherit from BaseMetric to ensure compatibility.

Parameters:

name (Optional[str]) – Optional custom name for the metric. If not provided, the lowercase class name will be used as the registration key. Using custom names is helpful for shorter keys or when the class name is not descriptive enough.

Returns:

Decorator function that registers the metric class

Return type:

Callable

Example

>>> @MetricRegistry.register_metric()  # Registered under the lowercase class name, "mymetric"
... class MyMetric(BaseMetric):
...     pass  # implementation
...
>>> @MetricRegistry.register_metric("better_name")  # Registered under the custom key
... class GenericNameThatNeedsBetterRegistryKey(BaseMetric):
...     pass  # implementation
...
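The decorator described above is a decorator factory: calling register_metric(name) returns the function that actually receives the class. The sketch below reproduces that pattern, including the lowercase-class-name fallback, on a hypothetical `SketchRegistry`; it is an illustration of the technique, not kaira's source.

```python
from typing import Callable, Dict, Optional, Type


class BaseMetric:
    """Hypothetical stand-in for kaira's BaseMetric."""


class SketchRegistry:
    """Hypothetical registry illustrating the register_metric() decorator pattern."""

    _metrics: Dict[str, Type[BaseMetric]] = {}

    @classmethod
    def register_metric(
        cls, name: Optional[str] = None
    ) -> Callable[[Type[BaseMetric]], Type[BaseMetric]]:
        def decorator(metric_class: Type[BaseMetric]) -> Type[BaseMetric]:
            # Fall back to the lowercase class name when no custom key is given.
            key = name if name is not None else metric_class.__name__.lower()
            cls._metrics[key] = metric_class
            # Return the class unchanged so it stays usable as a normal class.
            return metric_class

        return decorator


@SketchRegistry.register_metric()  # stored under "mymetric"
class MyMetric(BaseMetric):
    pass


@SketchRegistry.register_metric("better_name")  # stored under the custom key
class GenericNameThatNeedsBetterRegistryKey(BaseMetric):
    pass


print(sorted(SketchRegistry._metrics))  # ['better_name', 'mymetric']
```

Because this sketch is a factory, it must be called with parentheses even when no name is passed, which matches the `@MetricRegistry.register_metric()` form used in the example.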

classmethod create(name: str, *args: Any, **kwargs: Any) → BaseMetric

Create a metric instance from the registry with the specified parameters.

This function instantiates a registered metric class with the provided parameters, allowing for flexible creation of metrics at runtime based on configuration.

Parameters:
  • name (str) – Name of the metric to create (case-sensitive registry key)

  • *args – Positional arguments to pass to the metric constructor

  • **kwargs – Keyword arguments to pass to the metric constructor. These should match the parameters expected by the metric’s __init__ method.

Returns:

Instantiated metric object ready for use

Return type:

BaseMetric

Raises:
  • KeyError – If the metric name is not found in the registry

  • TypeError – If the provided args/kwargs don’t match the metric’s expected parameters

Example

>>> # Create a PSNR metric with custom parameters
>>> psnr = MetricRegistry.create("psnr", data_range=255.0)
>>>
>>> # Create a custom registered metric with positional arguments
>>> my_metric = MetricRegistry.create("mycustommetric", 10, param2="value")
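The sketch below shows how a create()-style lookup can be driven from a runtime configuration dictionary, and where the two documented errors come from. `ScaledMSE`, `_METRICS` and the config layout are hypothetical; the sketch assumes a lookup-then-instantiate design, which the source describes but does not show.

```python
from typing import Any, Dict, Type


class BaseMetric:
    """Hypothetical stand-in for kaira's BaseMetric."""


class ScaledMSE(BaseMetric):
    """Hypothetical metric taking one positional and one keyword parameter."""

    def __init__(self, scale: float, reduction: str = "mean") -> None:
        self.scale = scale
        self.reduction = reduction


_METRICS: Dict[str, Type[BaseMetric]] = {"scaledmse": ScaledMSE}


def create(name: str, *args: Any, **kwargs: Any) -> BaseMetric:
    """Look up `name` and instantiate it with the forwarded arguments."""
    if name not in _METRICS:
        # Unknown keys fail fast with a KeyError that lists the valid names.
        raise KeyError(f"Metric '{name}' not found; available: {sorted(_METRICS)}")
    # Arguments are not validated here: a mismatch raises the constructor's own TypeError.
    return _METRICS[name](*args, **kwargs)


# Runtime configuration, e.g. loaded from a YAML or JSON file.
config = {"name": "scaledmse", "args": [2.0], "kwargs": {"reduction": "sum"}}
metric = create(config["name"], *config["args"], **config["kwargs"])
```

Forwarding `*args` and `**kwargs` unchanged keeps the factory generic: it needs no knowledge of any particular metric's constructor.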

classmethod list_metrics() → List[str]

List all registered metrics available for creation.

This method returns the names of all registered metrics, each of which can be instantiated with create().

Returns:

Names (registry keys) of all registered metrics

Return type:

List[str]

Example

>>> available_metrics = MetricRegistry.list_metrics()
>>> print(f"Available metrics: {available_metrics}")
>>>
>>> # Check if a specific metric is available
>>> if "lpips" in MetricRegistry.list_metrics():
...     metric = MetricRegistry.create("lpips")

classmethod get_metric_info(name: str) → Dict[str, Any]

Get detailed information about a registered metric.

This method provides introspection capabilities to examine a metric’s parameters, documentation, and other metadata without instantiating it, which is useful for dynamic UI generation or parameter validation.

Parameters:

name (str) – Name of the metric to inspect

Returns:

Dictionary containing:
  • name: Registry key of the metric

  • class: Original class name

  • module: Module where the class is defined

  • docstring: Documentation string

  • parameters: Dictionary of parameter names and default values

Return type:

Dict[str, Any]

Raises:

KeyError – If the metric name is not found in the registry

Example

>>> # Get information about the PSNR metric
>>> psnr_info = MetricRegistry.get_metric_info("psnr")
>>> print(f"PSNR parameters: {psnr_info['parameters']}")
>>> print(f"Documentation: {psnr_info['docstring']}")
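The returned dictionary can be built entirely from the standard library's `inspect` module, without creating an instance. The sketch below assumes that approach; `PSNRSketch` and `_METRICS` are hypothetical, and reporting required parameters (those without a default) as None is a choice made here, since the source does not say how they are represented.

```python
import inspect
from typing import Any, Dict, Type


class BaseMetric:
    """Hypothetical stand-in for kaira's BaseMetric."""


class PSNRSketch(BaseMetric):
    """Hypothetical PSNR-like metric used only for this illustration."""

    def __init__(self, data_range: float = 1.0, eps: float = 1e-10) -> None:
        self.data_range = data_range
        self.eps = eps


_METRICS: Dict[str, Type[BaseMetric]] = {"psnr": PSNRSketch}


def get_metric_info(name: str) -> Dict[str, Any]:
    """Describe a registered class without instantiating it."""
    if name not in _METRICS:
        raise KeyError(f"Metric '{name}' not found")
    metric_class = _METRICS[name]
    parameters = {}
    # inspect.signature on a class reports its __init__ parameters without `self`.
    for param_name, param in inspect.signature(metric_class).parameters.items():
        if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
            continue  # skip *args / **kwargs catch-alls
        # Assumption: required parameters (no default) are reported as None.
        default = param.default
        parameters[param_name] = None if default is inspect.Parameter.empty else default
    return {
        "name": name,
        "class": metric_class.__name__,
        "module": metric_class.__module__,
        "docstring": inspect.getdoc(metric_class),
        "parameters": parameters,
    }


info = get_metric_info("psnr")
print(info["parameters"])  # {'data_range': 1.0, 'eps': 1e-10}
```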

classmethod create_image_quality_metrics(data_range: float = 1.0, lpips_net_type: Literal['vgg', 'alex', 'squeeze'] = 'alex', device: torch.device | None = None) → Dict[str, BaseMetric]

Create a standard suite of image quality assessment metrics.

This factory function creates a collection of commonly used image quality metrics with consistent parameters, making it easy to evaluate images across multiple metrics.

The returned metrics include:

  • PSNR (Peak Signal-to-Noise Ratio): a pixel-level fidelity metric

  • SSIM (Structural Similarity Index): a perceptual metric focusing on structure

  • MS-SSIM (Multi-Scale SSIM): a multi-scale version of SSIM

  • LPIPS (Learned Perceptual Image Patch Similarity): a learned perceptual metric

Parameters:
  • data_range (float) – The data range of the images. Use 1.0 for normalized images in range [0,1] or 255.0 for uint8 images in range [0,255].

  • lpips_net_type (Literal['vgg', 'alex', 'squeeze']) – The backbone network for LPIPS. Options are:

      ◦ ‘alex’: AlexNet (faster, less accurate)

      ◦ ‘vgg’: VGG network (slower, more accurate)

      ◦ ‘squeeze’: SqueezeNet (fastest, least accurate)

  • device (Optional[torch.device]) – Device to place the metrics on. If None, metrics will be on the default device (typically CPU).

Returns:

Dictionary mapping metric names to initialized metrics.

All metrics follow the BaseMetric interface and can be called directly with input tensors.

Return type:

Dict[str, BaseMetric]

Example

>>> import torch
>>>
>>> # Use a GPU when available, otherwise fall back to the CPU
>>> device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
>>>
>>> # Create metrics for normalized images in [0, 1]
>>> metrics = MetricRegistry.create_image_quality_metrics(data_range=1.0, device=device)
>>>
>>> # Generate a batch of random RGB test images on the same device
>>> pred = torch.rand(1, 3, 256, 256, device=device)
>>> target = torch.rand(1, 3, 256, 256, device=device)
>>>
>>> # Compute metrics individually
>>> psnr_value = metrics['psnr'](pred, target)
>>> ssim_value = metrics['ssim'](pred, target)
>>>
>>> # Or combine a subset of them into a composite metric
>>> composite = MetricRegistry.create_composite_metric(
...     {'psnr': metrics['psnr'], 'ssim': metrics['ssim']},
...     weights={'psnr': 0.5, 'ssim': 0.5},
... )
>>> score = composite(pred, target)
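The data_range parameter enters PSNR through the standard definition PSNR = 10 · log10(data_range² / MSE). The pure-Python sketch below (a hypothetical `psnr` helper, not kaira's torch implementation) shows why it must match the pixel scale: the same image pair scores identically in [0, 1] with data_range=1.0 and in [0, 255] with data_range=255.0, while a mismatched setting shifts the result by 20 · log10(255), roughly 48.1 dB.

```python
import math
from typing import Sequence


def psnr(pred: Sequence[float], target: Sequence[float], data_range: float) -> float:
    """PSNR in dB: 10 * log10(data_range**2 / MSE)."""
    mse = sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)
    return 10.0 * math.log10(data_range ** 2 / mse)


target = [0.0, 0.5, 1.0, 0.25]
pred = [0.1, 0.5, 0.9, 0.25]

# Normalized images with the matching range (MSE = 0.005, so 10 * log10(200)).
normalized = psnr(pred, target, data_range=1.0)

# The same images rescaled to [0, 255] with the matching range give the same score.
pred_u8 = [255 * v for v in pred]
target_u8 = [255 * v for v in target]
uint8_scale = psnr(pred_u8, target_u8, data_range=255.0)

# Rescaled images with the wrong range: the score drops by 20 * log10(255) dB.
mismatched = psnr(pred_u8, target_u8, data_range=1.0)
```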

classmethod create_composite_metric(metrics: Dict[str, BaseMetric], weights: Dict[str, float] | None = None) → BaseMetric

Create a composite metric that combines multiple metrics with weights.

This factory function creates a CompositeMetric instance that applies multiple metrics to the same inputs and combines their results according to specified weights.

This is useful for:

  • Creating custom evaluation criteria that balance multiple aspects

  • Combining complementary metrics (e.g., pixel accuracy and perceptual quality)

  • Building task-specific evaluation metrics that focus on relevant properties

Parameters:
  • metrics (Dict[str, BaseMetric]) – Dictionary mapping metric names to metric objects. All provided metrics should follow the BaseMetric interface.

  • weights (Optional[Dict[str, float]]) –

    Optional dictionary mapping metric names to their relative weights. If None, metrics will be equally weighted.

    Use negative weights for metrics where lower values are better (like LPIPS) when combining with metrics where higher values are better (like PSNR/SSIM).

Returns:

A composite metric that combines the provided metrics according to the specified weights. This metric follows the BaseMetric interface and can be used like any other metric.

Return type:

BaseMetric

Example

>>> from kaira.metrics import LPIPS, PSNR, SSIM
>>> from kaira.metrics.registry import MetricRegistry
>>>
>>> # Create individual metrics
>>> psnr = PSNR(data_range=1.0)
>>> ssim = SSIM(data_range=1.0)
>>> lpips = LPIPS(net_type='alex')  # Lower values are better
>>>
>>> # Create a balanced composite metric (higher values = better)
>>> metrics = {'psnr': psnr, 'ssim': ssim, 'lpips': lpips}
>>> weights = {'psnr': 0.4, 'ssim': 0.4, 'lpips': -0.2}  # Negative weight for LPIPS
>>>
>>> balanced_metric = MetricRegistry.create_composite_metric(metrics, weights)
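To see why a lower-is-better metric needs a negative weight, the sketch below implements a plain weighted sum over scalar scores. `WeightedComposite`, `quality` and `distance` are hypothetical; it assumes CompositeMetric combines results as a weighted sum and that "equally weighted" means a weight of 1/n per metric, neither of which the source spells out.

```python
from typing import Callable, Dict, Optional

MetricFn = Callable[[float, float], float]


class WeightedComposite:
    """Hypothetical weighted-sum composite; kaira's CompositeMetric may differ."""

    def __init__(
        self, metrics: Dict[str, MetricFn], weights: Optional[Dict[str, float]] = None
    ) -> None:
        if weights is None:
            # Assumption: equal weighting means 1 / n for each metric.
            weights = {name: 1.0 / len(metrics) for name in metrics}
        missing = set(metrics) - set(weights)
        if missing:
            raise ValueError(f"No weight given for: {sorted(missing)}")
        self.metrics = metrics
        self.weights = weights

    def __call__(self, pred: float, target: float) -> float:
        # Weighted sum of every metric evaluated on the same inputs.
        return sum(self.weights[name] * fn(pred, target) for name, fn in self.metrics.items())


def quality(pred: float, target: float) -> float:
    """Toy higher-is-better score (plays the role of PSNR/SSIM)."""
    return 1.0 - abs(pred - target)


def distance(pred: float, target: float) -> float:
    """Toy lower-is-better score (plays the role of LPIPS)."""
    return abs(pred - target)


both = {"quality": quality, "distance": distance}
signed = WeightedComposite(both, {"quality": 0.5, "distance": -0.5})
unsigned = WeightedComposite(both)

# With the negative weight the score is 0.5 - |pred - target|: closer is better.
print(signed(0.9, 1.0), signed(0.5, 1.0))
# Without it the two terms cancel, so every prediction scores 0.5.
print(unsigned(0.9, 1.0), unsigned(0.5, 1.0))
```

The equal-weight case shows the failure mode the weights parameter warns about: a lower-is-better term added with a positive weight can cancel the signal from the higher-is-better terms.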

classmethod clear() → None

Clear all registered metrics from the registry.

This is primarily useful for testing and reinitialization scenarios.
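Because clear() wipes every entry, including any metrics registered at import time, a test that uses it should normally restore the previous contents afterwards. The sketch below shows that pattern on a hypothetical `SketchRegistry`; the snapshot-and-restore step is a suggested testing habit, not something the source prescribes.

```python
from typing import Dict, List, Type


class SketchRegistry:
    """Hypothetical registry with the register/list_metrics/clear trio."""

    _metrics: Dict[str, Type] = {}

    @classmethod
    def register(cls, name: str, metric_class: Type) -> None:
        cls._metrics[name] = metric_class

    @classmethod
    def list_metrics(cls) -> List[str]:
        return list(cls._metrics)

    @classmethod
    def clear(cls) -> None:
        # Empty the shared dict in place so existing references see the change.
        cls._metrics.clear()


SketchRegistry.register("psnr", object)  # pretend a built-in metric was registered at import


def test_register_in_isolation() -> None:
    snapshot = dict(SketchRegistry._metrics)  # remember the current entries
    SketchRegistry.clear()
    try:
        SketchRegistry.register("dummy", object)
        assert SketchRegistry.list_metrics() == ["dummy"]
    finally:
        # Restore the original registrations so later tests are unaffected.
        SketchRegistry.clear()
        SketchRegistry._metrics.update(snapshot)


test_register_in_isolation()
print(SketchRegistry.list_metrics())  # ['psnr']
```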

classmethod available_metrics() → List[str]

Get a list of all available metrics in the registry. This is an alias for list_metrics() for backward compatibility.

Returns:

List of registered metric names

Return type:

List[str]

__init__()