kaira.models.image.compressors.NeuralCompressor

Inheritance diagram of NeuralCompressor

class kaira.models.image.compressors.NeuralCompressor(method: str, metric: str = 'mse', max_bits_per_image: int | None = None, quality: int | None = None, lazy_loading: bool = True, return_bits: bool = False, collect_stats: bool = False, return_compressed_data: bool = False, device: str | device | None = None, early_stopping_threshold: float | None = None, *args: Any, **kwargs: Any)

Bases: BaseModel

Neural network-based image compression model.

This class provides neural network-based compression using various pretrained models from the CompressAI library. It can operate in two modes:

  1. Fixed quality mode: directly uses the specified quality level

  2. Bit-constrained mode: finds the highest quality that stays under a bit budget

Models can be loaded lazily (see lazy_loading) and are cached per quality level (see get_model), which limits memory usage. The method argument selects which CompressAI compression method is used.
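The bit-constrained mode can be pictured as a search over quality levels. The following is a minimal sketch of that idea, not the library's actual search: `pick_quality`, `bits_for_quality`, and the bit counts are hypothetical names and made-up numbers, and treating a result that exactly meets the budget as acceptable is an assumption.

```python
from typing import Callable, Iterable, Optional


def pick_quality(
    bits_for_quality: Callable[[int], float],
    qualities: Iterable[int],
    max_bits: float,
) -> Optional[int]:
    """Return the highest quality whose bit count fits within max_bits.

    Returns None when no candidate quality fits the budget.
    """
    # Try the best quality first so the first fit is also the best fit.
    for quality in sorted(qualities, reverse=True):
        if bits_for_quality(quality) <= max_bits:
            return quality
    return None


# Hypothetical bit counts for one image at each quality level (made-up numbers).
example_bits = {1: 9_000, 2: 14_000, 3: 22_000, 4: 35_000}

chosen = pick_quality(example_bits.__getitem__, example_bits.keys(), max_bits=25_000)
print(chosen)  # 3
```

In fixed quality mode no search is needed, since the requested level is used directly.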

Methods

__init__

Initialize the neural compressor.

compute_bits_compressai

Compute bits required for each image in the batch.

forward

Forward pass of the neural compressor.

get_bits_per_image

Compress images and return only the bit counts per image.

get_model

Get a model with the specified quality, using cache if available.

get_stats

Return compression statistics if collect_stats=True was set.

reset_stats

Reset compression statistics.

__init__(method: str, metric: str = 'mse', max_bits_per_image: int | None = None, quality: int | None = None, lazy_loading: bool = True, return_bits: bool = False, collect_stats: bool = False, return_compressed_data: bool = False, device: str | device | None = None, early_stopping_threshold: float | None = None, *args: Any, **kwargs: Any)

Initialize the neural compressor.

Parameters:
  • method – Compression method to use (e.g., “bmshj2018_factorized”)

  • metric – Metric used for training (“mse” or “ms-ssim”)

  • max_bits_per_image – Maximum bits allowed per image

  • quality – Specific quality level to use

  • lazy_loading – Whether to load models only when needed (saves memory)

  • return_bits – Whether to return bits per image in forward pass

  • collect_stats – Whether to collect and return compression statistics

  • return_compressed_data – Whether to return compressed representation

  • device – Device to load models on (e.g., “cuda”, “cpu”)

  • early_stopping_threshold – Fraction of the bit budget at which the quality search may stop early (e.g., 0.95 means stop once a result is within 5% of the bit budget)

  • *args – Variable positional arguments passed to the base class.

  • **kwargs – Variable keyword arguments passed to the base class.
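As an illustration of the parameters above, the two operating modes might be configured as follows. This is a sketch under stated assumptions: the quality level `4` and the budget of `20_000` bits are illustrative values, the pairing of each mode with `quality` or `max_bits_per_image` is inferred from the parameter descriptions, and `build_compressor` is a hypothetical helper that imports kaira only when called.

```python
from typing import Any, Dict

# Fixed quality mode: pass `quality` and leave `max_bits_per_image` unset.
fixed_quality_kwargs: Dict[str, Any] = {
    "method": "bmshj2018_factorized",
    "metric": "mse",
    "quality": 4,  # assumed to be a valid level for this method
}

# Bit-constrained mode: pass `max_bits_per_image` and leave `quality` unset.
bit_constrained_kwargs: Dict[str, Any] = {
    "method": "bmshj2018_factorized",
    "metric": "mse",
    "max_bits_per_image": 20_000,  # illustrative budget
    "return_bits": True,
    "early_stopping_threshold": 0.95,
}


def build_compressor(**kwargs: Any):
    """Construct a NeuralCompressor from keyword arguments.

    kaira is imported here, not at module level, so the configuration
    dictionaries above can be inspected without the library installed.
    """
    from kaira.models.image.compressors import NeuralCompressor

    return NeuralCompressor(**kwargs)
```

Calling `build_compressor(**bit_constrained_kwargs)` would then create the model, provided kaira and its CompressAI dependency are installed.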

get_model(quality: int) → Module

Get a model with the specified quality, using cache if available.
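The caching behaviour can be sketched with a dictionary keyed by quality level. This illustrates the general pattern only and is not the class's internal code; `QualityModelCache` and `fake_loader` are hypothetical names, and strings stand in for real models.

```python
from typing import Callable, Dict, Generic, List, TypeVar

M = TypeVar("M")


class QualityModelCache(Generic[M]):
    """Load a model for a quality level on first use, then reuse it."""

    def __init__(self, loader: Callable[[int], M]) -> None:
        self._loader = loader
        self._models: Dict[int, M] = {}

    def get(self, quality: int) -> M:
        # Only call the (potentially expensive) loader on a cache miss.
        if quality not in self._models:
            self._models[quality] = self._loader(quality)
        return self._models[quality]


load_calls: List[int] = []


def fake_loader(quality: int) -> str:
    """Stand-in for loading pretrained weights; records each call."""
    load_calls.append(quality)
    return f"model-q{quality}"


cache = QualityModelCache(fake_loader)
cache.get(3)
cache.get(3)  # served from the cache, loader is not called again
cache.get(5)
print(load_calls)  # [3, 5]
```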

compute_bits_compressai(r: Dict) → Tensor

Compute bits required for each image in the batch.

Parameters:

r – CompressAI model output dictionary

Returns:

Tensor containing bits per image
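CompressAI models return a dictionary with a "likelihoods" entry holding the estimated probability of each latent symbol, and the bit count of an image is conventionally estimated as the sum of -log2 of those probabilities. The sketch below shows that arithmetic with plain Python lists standing in for tensors. It assumes this is the quantity the method reports; `bits_per_image` and the example values are hypothetical.

```python
import math
from typing import Dict, List


def bits_per_image(likelihoods: Dict[str, List[List[float]]]) -> List[float]:
    """Sum -log2(p) over every likelihood entry, separately for each image.

    `likelihoods` maps a latent name to one flat list of probabilities
    per image in the batch.
    """
    batch_size = len(next(iter(likelihoods.values())))
    totals = [0.0] * batch_size
    for per_image in likelihoods.values():
        for i, probs in enumerate(per_image):
            totals[i] += -sum(math.log2(p) for p in probs)
    return totals


# Two images, two latents ("y" and "z"); probabilities are made up.
example = {
    "y": [[0.5, 0.25], [0.125, 0.5]],
    "z": [[0.5], [0.25]],
}
print(bits_per_image(example))  # [4.0, 6.0]
```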

forward(x: Tensor, *args: Any, **kwargs: Any) → Tensor | Tuple[Tensor, Tensor] | Tuple[Tensor, List[Any]] | Tuple[Tensor, Tensor, List[Any]]

Forward pass of the neural compressor.

Parameters:
  • x – Input image tensor

  • *args – Additional positional arguments (unused in this method).

  • **kwargs – Additional keyword arguments (unused in this method).

Returns:

  • If neither option is set: just the reconstructed image

  • If return_bits=True: tuple of (reconstructed image, bits per image tensor)

  • If return_compressed_data=True: tuple of (reconstructed image, compressed data list)

  • If both are True: tuple of (reconstructed image, bits per image tensor, compressed data list)

Return type:

Tensor | Tuple[Tensor, Tensor] | Tuple[Tensor, List[Any]] | Tuple[Tensor, Tensor, List[Any]]
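Because the shape of the return value depends on two constructor flags, calling code may find it convenient to normalize it. The helper below is a hypothetical sketch, not part of kaira: it maps any of the four documented return shapes to a fixed three-element tuple, with None for the parts that were not requested.

```python
from typing import Any, List, Optional, Tuple


def unpack_forward_output(
    output: Any,
    return_bits: bool,
    return_compressed_data: bool,
) -> Tuple[Any, Optional[Any], Optional[List[Any]]]:
    """Normalize a forward() result to (reconstruction, bits, compressed_data)."""
    if return_bits and return_compressed_data:
        x_hat, bits, data = output
        return x_hat, bits, data
    if return_bits:
        x_hat, bits = output
        return x_hat, bits, None
    if return_compressed_data:
        x_hat, data = output
        return x_hat, None, data
    # Neither flag set: the output is the reconstruction itself.
    return output, None, None


# Strings stand in for tensors in this illustration.
print(unpack_forward_output(("x_hat", "bits"), True, False))
# ('x_hat', 'bits', None)
```

The flags passed to the helper need to match the ones the compressor was constructed with.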

reset_stats()

Reset compression statistics.

get_stats()

Return compression statistics if collect_stats=True was set.

get_bits_per_image(x: Tensor, *args: Any, **kwargs: Any) → Tensor

Compress images and return only the bit counts per image.

Parameters:
  • x – Tensor of shape [batch_size, channels, height, width]

  • *args – Additional positional arguments passed to forward.

  • **kwargs – Additional keyword arguments passed to forward.

Returns:

Number of bits used for each compressed image

Return type:

Tensor