"""BPG (Better Portable Graphics) image compressor wrapper."""
import logging
import multiprocessing
import os
import re
import shutil
import subprocess # nosec
import tempfile
import time
import uuid
# CompletedProcess is imported by name so it can be used directly in type annotations.
from subprocess import CompletedProcess # nosec B404
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from joblib import Parallel, delayed
from PIL import Image
from torchvision import transforms
from torchvision.utils import save_image
from kaira.models.base import BaseModel
logger = logging.getLogger(__name__)
[docs]
class BPGCompressor(BaseModel):
    """BPG (Better Portable Graphics) image compression based on bpgenc and bpgdec.

    This class provides BPG-based compression using the external BPG command-line
    tools. Each image is written to a temporary PNG file, encoded with ``bpgenc``,
    decoded again with ``bpgdec`` and loaded back as a tensor. It can operate in
    two modes:

    1. Fixed quality mode: directly uses the specified quality level.
    2. Bit-constrained mode: searches for a quality level whose compressed size
       stays under a bit budget.

    If both ``quality`` and ``max_bits_per_image`` are given, fixed quality mode
    is used.

    BPG (Better Portable Graphics) is a lossy image compression format based on HEVC
    (High Efficiency Video Coding) that provides superior compression efficiency compared
    to JPEG while maintaining good visual quality. The quality value ranges from
    0 to 51, and lower values mean better fidelity.

    Installation:
        The BPG tools (bpgenc and bpgdec) must be installed separately on your system.
        For installation instructions, see:
        https://kaira.readthedocs.io/en/latest/installation.html#bpg-image-compression-support

    Example:
        # Fixed quality compression (return_bits defaults to True, so disable it
        # to get only the reconstructed images back)
        compressor = BPGCompressor(quality=30, return_bits=False)
        compressed_images = compressor(image_batch)

        # Bit-constrained compression
        compressor = BPGCompressor(max_bits_per_image=5000)
        compressed_images, bits_used = compressor(image_batch)

        # With compression statistics
        compressor = BPGCompressor(quality=25, collect_stats=True, return_bits=True)
        compressed_images, bits_per_image = compressor(image_batch)
        stats = compressor.get_stats()

    Note:
        This class requires external BPG tools to be installed and available in PATH
        or specified via bpg_encoder_path and bpg_decoder_path parameters.
    """
[docs]
def __init__(
    self,
    max_bits_per_image: Optional[int] = None,
    quality: Optional[int] = None,
    bpg_encoder_path: str = "bpgenc",
    bpg_decoder_path: str = "bpgdec",
    n_jobs: Optional[int] = None,
    collect_stats: bool = False,
    return_bits: bool = True,
    return_compressed_data: bool = False,
    *args: Any,
    **kwargs: Any,
):
    """Initialize the BPG Compressor.

    Args:
        max_bits_per_image: Maximum bits allowed per compressed image. If provided without
            quality, the compressor searches for a quality level that produces
            files no larger than this limit.
        quality: Fixed quality level for BPG compression (0-51, lower is better).
            If provided, this exact quality is used regardless of resulting file size,
            even when max_bits_per_image is also given.
        bpg_encoder_path: Path to the BPG encoder executable
        bpg_decoder_path: Path to the BPG decoder executable
        n_jobs: Number of parallel jobs to use (default: CPU count // 2, at least 1)
        collect_stats: Whether to collect and return compression statistics
        return_bits: Whether to return bits per image in forward pass
        return_compressed_data: Whether to return the compressed binary data
        *args: Variable positional arguments passed to the base class.
        **kwargs: Variable keyword arguments passed to the base class.

    Raises:
        ValueError: If neither max_bits_per_image nor quality is given, if quality
            is outside 0-51, or if an executable path fails validation.
        RuntimeError: If the encoder or the decoder executable cannot be launched.
    """
    super().__init__(*args, **kwargs)  # Pass args and kwargs to base

    # At least one of the two parameters must be provided
    if max_bits_per_image is None and quality is None:
        raise ValueError("At least one of the two parameters must be provided")
    if quality is not None and (quality < 0 or quality > 51):
        raise ValueError("BPG quality must be between 0 and 51")

    self.max_bits_per_image = max_bits_per_image
    self.quality = quality

    # Validate executable paths before they are ever handed to subprocess
    self._validate_executable_path(bpg_encoder_path)
    self._validate_executable_path(bpg_decoder_path)
    self.bpg_encoder_path = bpg_encoder_path
    self.bpg_decoder_path = bpg_decoder_path

    self.n_jobs = n_jobs if n_jobs is not None else max(1, multiprocessing.cpu_count() // 2)
    self.collect_stats = collect_stats
    self.return_bits = return_bits
    self.return_compressed_data = return_compressed_data
    self.stats: Dict[str, Any] = {}

    # Both tools are needed by every forward pass, so fail early (with a clear
    # message) if either cannot be launched. The exit status is irrelevant here:
    # only a failure to start the process indicates a missing/unusable tool.
    for tool_name, tool_path in (("encoder", self.bpg_encoder_path), ("decoder", self.bpg_decoder_path)):
        try:
            self._safe_subprocess_run([tool_path, "--help"])
        except (subprocess.SubprocessError, OSError) as err:
            logger.error(f"BPG {tool_name} not found at '{tool_path}'. Please install BPG tools.")
            raise RuntimeError(
                f"BPG {tool_name} not found at '{tool_path}'. "
                "Please install BPG tools following the instructions at: "
                "https://kaira.readthedocs.io/en/latest/installation.html#bpg-image-compression-support"
            ) from err
def _validate_executable_path(self, path: str) -> None:
"""Validate that an executable path doesn't contain shell metacharacters.
Args:
path: The executable path to validate
Raises:
ValueError: If the path contains potentially dangerous characters
"""
# Simple validation to prevent basic command injection
if ";" in path or "&" in path or "|" in path or ">" in path or "<" in path:
raise ValueError(f"Executable path '{path}' contains invalid characters")
# Check if path doesn't exist but contains shell metacharacters
if not os.path.exists(path) and re.search(r"[${}()`\[\]\s]", path):
raise ValueError(f"Executable path '{path}' contains potentially dangerous characters")
def _safe_subprocess_run(self, cmd_args: List[str], **kwargs) -> CompletedProcess:
"""Execute subprocess safely with validated arguments.
Args:
cmd_args: Command arguments list
**kwargs: Additional arguments for subprocess.run
Returns:
subprocess.CompletedProcess object
"""
# Always enforce shell=False
kwargs["shell"] = False
# Default to capturing output
if "capture_output" not in kwargs and "stdout" not in kwargs:
kwargs["capture_output"] = True
# Ensure return type matches annotation
return subprocess.run(cmd_args, **kwargs) # type: ignore # nosec B603
[docs]
def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> Union[torch.Tensor, Tuple[torch.Tensor, List[int]], Tuple[torch.Tensor, List[bytes]], Tuple[torch.Tensor, List[int], List[bytes]]]:
    """Process a batch of images through BPG compression.

    The compression method depends on initialization parameters:
    - If quality was provided, that fixed quality is used
    - Otherwise max_bits_per_image is used as a per-image bit budget

    When collect_stats is enabled, ``self.stats`` is reset and refilled on every call.

    Args:
        x: Tensor of shape [batch_size, channels, height, width]
        *args: Additional positional arguments (passed to internal methods).
        **kwargs: Additional keyword arguments (passed to internal methods).

    Returns:
        If no additional returns: Just the reconstructed image tensor
        If return_bits=True: Tuple of (tensor, bits per image)
        If return_compressed_data=True: Tuple of (tensor, compressed binary data)
        If both are True: Tuple of (tensor, bits per image, compressed binary data)
    """
    start_time = time.time()
    batch_size = x.shape[0]

    if self.collect_stats:
        self.stats = {"total_bits": 0, "avg_quality": 0, "img_stats": []}

    # Per-image info dicts are needed for any of the optional outputs
    want_bits = self.return_bits or self.collect_stats
    collect_info = want_bits or self.return_compressed_data

    # Workers only write the image to disk, so hand them CPU tensors. This also
    # keeps failure fallbacks (random tensors shaped like the input) on the same
    # device as the decoded images, so stacking below cannot hit a device mismatch.
    x_cpu = x.detach().cpu()

    # Process images in parallel
    results = Parallel(n_jobs=self.n_jobs)(delayed(self.parallel_forward_bpg)(i, x_cpu[i], collect_info, *args, **kwargs) for i in range(batch_size))

    # Unpack results
    images = []
    bits_per_image: List[int] = []
    compressed_data: List[bytes] = []
    for result in results:
        if not collect_info:
            images.append(result)
            continue

        img, info = result
        images.append(img)
        bits = int(info.get("bits", 0))
        if want_bits:
            bits_per_image.append(bits)
        if self.return_compressed_data:
            compressed_data.append(info.get("compressed_data", b""))
        if self.collect_stats:
            self.stats["total_bits"] += bits
            self.stats["img_stats"].append(info)

    # torch.stack rejects an empty list, so an empty batch maps to an empty result
    if images:
        x_hat = torch.stack(images, dim=0).to(x.device)
    else:
        x_hat = x.new_empty(x.shape)

    # Calculate aggregate statistics if requested
    if self.collect_stats and batch_size > 0:
        img_stats = self.stats["img_stats"]
        self.stats["avg_quality"] = sum(s.get("quality", 0) for s in img_stats) / batch_size
        # Bits per *pixel*, averaged over images (failed images contribute 0)
        self.stats["avg_bpp"] = sum(s.get("bpp", 0) for s in img_stats) / batch_size
        self.stats["avg_bits_per_image"] = self.stats["total_bits"] / batch_size
        self.stats["avg_compression_ratio"] = sum(s.get("compression_ratio", 0) for s in img_stats) / batch_size
        self.stats["processing_time"] = time.time() - start_time

    # Return appropriate output based on flags
    if self.return_bits and self.return_compressed_data:
        return x_hat, bits_per_image, compressed_data
    if self.return_bits:
        return x_hat, bits_per_image
    if self.return_compressed_data:
        return x_hat, compressed_data
    return x_hat
[docs]
def parallel_forward_bpg(self, idx: int, img: torch.Tensor, return_info: bool = False, *args: Any, **kwargs: Any):
    """Process a single image with BPG compression.

    Dispatches to fixed-quality compression when ``self.quality`` is set and to
    bit-constrained compression otherwise.

    Args:
        idx: Image index
        img: Image tensor of shape [channels, height, width]
        return_info: Whether to return compression information
        *args: Additional positional arguments (passed to compression methods).
        **kwargs: Additional keyword arguments (passed to compression methods).

    Returns:
        If return_info=False: Processed image tensor
        If return_info=True: Tuple of (tensor, info_dict)

    Raises:
        ValueError: If neither quality nor max_bits_per_image is set.
    """
    if self.quality is not None:
        return self.compress_with_quality(idx, img, self.quality, return_info, *args, **kwargs)

    # Explicit check rather than `assert`, which is stripped when running with -O
    if self.max_bits_per_image is None:
        raise ValueError("max_bits_per_image must be set if quality is not provided")
    return self.compress_with_target_size(idx, img, self.max_bits_per_image, return_info, *args, **kwargs)
def _setup_temp_paths(self, idx: int) -> Dict[str, str]:
"""Create temporary directory and generate file paths.
This method creates a temporary directory with unique filenames for:
- Input image (PNG format)
- Compressed image (BPG format)
- Decompressed output (PNG format)
- Best output for binary search (PNG format)
Args:
idx: Image index for generating unique filenames
Returns:
Dict containing paths for 'dir', 'input', 'compressed', 'output', 'best_output'
"""
temp_dir = tempfile.mkdtemp(prefix="bpg_")
uid = f"{idx}_{uuid.uuid4()}"
paths = {"dir": temp_dir, "input": os.path.join(temp_dir, f"input_{uid}.png"), "compressed": os.path.join(temp_dir, f"compressed_{uid}.bpg"), "output": os.path.join(temp_dir, f"output_{uid}.png"), "best_output": os.path.join(temp_dir, f"best_{uid}.png")}
return paths
[docs]
def compress_with_quality(self, idx: int, x: torch.Tensor, quality: int, return_info: bool = False, *args: Any, **kwargs: Any):
    """Compress image with a specific quality level.

    The image is written to a temporary PNG, encoded with the BPG encoder,
    decoded again and loaded back as an RGB tensor. The temporary directory is
    removed on every exit path, including exceptions.

    If encoding or decoding fails, the error is logged and a random tensor shaped
    like ``x`` is returned (with ``{"quality": -1, "bits": 0}`` as info).

    Args:
        idx: Image index for generating unique filenames
        x: Image tensor of shape [channels, height, width]
        quality: BPG quality level (0-51)
        return_info: Whether to return compression information
        *args: Additional positional arguments (unused in this method).
        **kwargs: Additional keyword arguments (unused in this method).

    Returns:
        If return_info=False: Compressed-decompressed image tensor
        If return_info=True: Tuple of (tensor, info_dict). On success the dict holds
        'quality', 'bits', 'bpp', 'compression_ratio' (temporary PNG size divided by
        BPG size) and, if compressed data was requested, 'compressed_data'.
    """

    def failure_result():
        # Best-effort fallback: keep the batch shape intact and flag the failure
        fallback = torch.randn_like(x)
        return (fallback, {"quality": -1, "bits": 0}) if return_info else fallback

    paths = self._setup_temp_paths(idx)
    try:
        # Save input image and measure its on-disk size
        save_image(x, paths["input"])
        original_size = os.path.getsize(paths["input"])

        # Compress with specified quality
        result_enc = self._safe_subprocess_run([self.bpg_encoder_path, "-q", str(quality), "-o", paths["compressed"], paths["input"]], text=True)
        if result_enc.returncode != 0:
            logger.error(f"BPG encoding failed: {result_enc.stderr}")
            return failure_result()

        compressed_size = os.path.getsize(paths["compressed"])
        bits = compressed_size * 8

        # Read compressed data if needed
        compressed_data = None
        if self.return_compressed_data and return_info:
            with open(paths["compressed"], "rb") as f:
                compressed_data = f.read()

        # Decompress
        result_dec = self._safe_subprocess_run([self.bpg_decoder_path, "-o", paths["output"], paths["compressed"]], text=True)
        if result_dec.returncode != 0:
            logger.error(f"BPG decoding failed: {result_dec.stderr}")
            return failure_result()

        # Load result; close the file handle before the directory is removed
        with Image.open(paths["output"]) as decoded:
            img = transforms.ToTensor()(decoded.convert("RGB"))

        if not return_info:
            return img

        stats = {"quality": quality, "bits": bits, "bpp": bits / (x.shape[1] * x.shape[2]), "compression_ratio": original_size / compressed_size if compressed_size > 0 else 0}
        if compressed_data is not None:
            stats["compressed_data"] = compressed_data
        return img, stats
    finally:
        # Cleanup must not mask the real outcome, hence ignore_errors
        shutil.rmtree(paths["dir"], ignore_errors=True)
# target_bits is typed as a plain int: the caller checks that max_bits_per_image is not None first.
[docs]
def compress_with_target_size(self, idx: int, x: torch.Tensor, target_bits: int, return_info: bool = False, *args: Any, **kwargs: Any):
    """Find the best quality whose compressed size stays within target_bits using binary search.

    The BPG quality value is a quantizer parameter: smaller values give better
    fidelity and larger files. The search therefore looks for the *smallest*
    value (0-51) whose encoded size does not exceed ``target_bits``. A first probe
    at value 30 halves the search range before the binary search starts.

    A candidate only counts if it can also be decoded. The temporary directory is
    removed on every exit path, including exceptions. If no candidate fits, a
    warning is logged and a random tensor shaped like ``x`` is returned.

    Args:
        idx: Image index for generating unique filenames
        x: Image tensor of shape [channels, height, width]
        target_bits: Maximum bits for the compressed image
        return_info: Whether to return compression information
        *args: Additional positional arguments (unused in this method).
        **kwargs: Additional keyword arguments (unused in this method).

    Returns:
        If return_info=False: Compressed-decompressed image tensor
        If return_info=True: Tuple of (tensor, info_dict). On success the dict holds
        'quality', 'bits', 'bpp', 'compression_ratio', 'target_bits' and, if
        compressed data was requested, 'compressed_data'. On failure it holds
        'quality' = -1, 'bits' = 0 and 'target_bits'.
    """
    min_quality, max_quality, initial_quality = 0, 51, 30
    keep_data = return_info and self.return_compressed_data

    paths = self._setup_temp_paths(idx)
    try:
        # Save input image
        save_image(x, paths["input"])
        original_size = os.path.getsize(paths["input"])

        def encode(level: int) -> Optional[int]:
            """Encode at the given level; return the size in bits, or None on failure."""
            proc = self._safe_subprocess_run([self.bpg_encoder_path, "-q", str(level), "-o", paths["compressed"], paths["input"]], text=True)
            if proc.returncode != 0:
                logger.error(f"BPG encoding failed at quality {level}: {proc.stderr}")
                return None
            return os.path.getsize(paths["compressed"]) * 8

        # Probe once to pick the half of the range worth searching
        probe_bits = encode(initial_quality)
        if probe_bits is None:
            low, high = min_quality, max_quality
        elif probe_bits <= target_bits:
            # The probe fits: better fidelity (smaller values) may fit as well
            low, high = min_quality, initial_quality
        else:
            # The probe is too large: only stronger compression can fit
            low, high = initial_quality + 1, max_quality

        best_quality = -1
        best_bits = 0
        best_data: Optional[bytes] = None

        while low <= high:
            mid = (low + high) // 2
            bits_out = encode(mid)

            if bits_out is None or bits_out > target_bits:
                # Unusable or too large: move towards stronger compression
                low = mid + 1
                continue

            # Fits the budget; accept it only if it can actually be decoded
            result_dec = self._safe_subprocess_run([self.bpg_decoder_path, "-o", paths["output"], paths["compressed"]], text=True)
            if result_dec.returncode == 0:
                # os.replace overwrites any previous best in one step
                os.replace(paths["output"], paths["best_output"])
                best_quality, best_bits = mid, bits_out
                if keep_data:
                    with open(paths["compressed"], "rb") as f:
                        best_data = f.read()
            else:
                logger.error(f"BPG decoding failed at quality {mid}: {result_dec.stderr}")

            # Try better fidelity (smaller value)
            high = mid - 1

        if best_quality == -1:
            logger.warning(f"Could not find any quality level meeting target of {target_bits} bits")
            fallback = torch.randn_like(x)
            if not return_info:
                return fallback
            stats = {"quality": -1, "bits": 0, "target_bits": target_bits}
            if self.return_compressed_data:
                stats["compressed_data"] = b""
            return fallback, stats

        # Load the best image found; close the handle before the directory is removed
        with Image.open(paths["best_output"]) as decoded:
            img = transforms.ToTensor()(decoded.convert("RGB"))

        if not return_info:
            return img

        stats = {"quality": best_quality, "bits": int(best_bits), "bpp": best_bits / (x.shape[1] * x.shape[2]), "compression_ratio": original_size / (best_bits / 8) if best_bits > 0 else 0, "target_bits": target_bits}
        if best_data is not None:
            stats["compressed_data"] = best_data
        return img, stats
    finally:
        # Cleanup must not mask the real outcome, hence ignore_errors
        shutil.rmtree(paths["dir"], ignore_errors=True)
[docs]
def get_stats(self) -> Dict[str, Any]:
    """Return the statistics gathered by the most recent forward pass.

    Statistics are only gathered when the compressor was created with
    collect_stats=True; otherwise a warning is logged and an empty dict is
    returned.

    Returns:
        The stats dict filled in by the forward pass. Keys written there include:
        - total_bits: Total bits used for all images
        - avg_quality: Average BPG quality level used
        - avg_bpp: Average bit-rate figure computed by the forward pass
        - avg_compression_ratio: Average compression ratio (original/compressed)
        - processing_time: Time taken for compression
        - img_stats: List of per-image statistics

    Note:
        The returned dict is the compressor's own ``stats`` object, not a copy.
    """
    if self.collect_stats:
        return self.stats

    logger.warning("Statistics not collected. Initialize with collect_stats=True to enable.")
    return {}
# Convenience wrapper: runs forward() and returns only the per-image bit counts.
[docs]
def get_bits_per_image(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> List[int]:
    """Run the compressor on a batch and report only the bits used per image.

    ``return_bits`` is switched on for the duration of the call and restored
    afterwards, even if the forward pass raises. Which compression mode is used
    depends on whether quality or max_bits_per_image was given at initialization.

    Args:
        x: Tensor of shape [batch_size, channels, height, width]
        *args: Additional positional arguments passed to forward.
        **kwargs: Additional keyword arguments passed to forward.

    Returns:
        List[int]: Number of bits used for each compressed image

    Raises:
        TypeError: If forward does not return a tuple whose second item is a list.
    """
    saved_flag = self.return_bits
    self.return_bits = True
    try:
        outcome = self.forward(x, *args, **kwargs)

        # With return_bits enabled, forward must hand back (tensor, bits, ...)
        if not (isinstance(outcome, tuple) and len(outcome) >= 2):
            raise TypeError(f"Forward method did not return expected tuple (tensor, bits), got {type(outcome)}")

        bit_counts = outcome[1]
        if not isinstance(bit_counts, list):
            raise TypeError(f"Expected list of bits, but got {type(bit_counts)}")
    finally:
        # Put the caller's setting back no matter what happened above
        self.return_bits = saved_flag
    return bit_counts