Source code for kaira.models.fec.encoders.hamming_code

"""Hamming code implementation for forward error correction.

This module implements Hamming codes, a family of linear error-correcting codes that can detect
up to two-bit errors and correct single-bit errors. For a given parameter μ ≥ 2, a Hamming code
has the following parameters:

- Length: n = 2^μ - 1
- Dimension: k = 2^μ - μ - 1
- Redundancy: m = μ
- Minimum distance: d = 3

In its extended version, the Hamming code has the following parameters:

- Length: n = 2^μ
- Dimension: k = 2^μ - μ - 1
- Redundancy: m = μ + 1
- Minimum distance: d = 4

Standard (non-extended) Hamming codes are perfect codes: they meet the sphere-packing (Hamming)
bound with equality, so every binary word of length n lies within Hamming distance 1 of exactly
one codeword :cite:`lin2004error,moon2005error`.
"""

import itertools
from functools import lru_cache
from typing import Any, List, Optional, Union

import torch

from kaira.models.registry import ModelRegistry

from .systematic_linear_block_code import SystematicLinearBlockCodeEncoder


def create_hamming_parity_submatrix(mu: int, extended: bool = False, dtype: torch.dtype = torch.float32, device: Optional[torch.device] = None) -> torch.Tensor:
    """Create the parity submatrix for a Hamming code.

    The check matrix of a Hamming code has every non-zero binary μ-tuple as a column.
    In systematic form the μ weight-1 tuples belong to the parity positions, so the
    parity submatrix P (shape k × μ) holds the remaining tuples of weight ≥ 2, one per
    row, ordered by increasing weight and, within a weight, in the lexicographic order
    produced by ``itertools.combinations``.

    For the extended code one extra column is appended that holds the overall parity
    bit of each generator row ``[e_i | P_i]``, i.e. ``(1 + weight(P_i)) mod 2``. This
    makes every codeword even-weight, which is what raises the minimum distance
    from 3 to 4 :cite:`lin2004error`.

    Args:
        mu: The parameter μ of the code. Must satisfy μ ≥ 2.
        extended: Whether to create an extended Hamming code. Default is False.
        dtype: The data type for tensor elements. Default is torch.float32.
        device: The device to place the resulting tensor on. Default is None (uses current device).

    Returns:
        The parity submatrix of the Hamming code, of shape (2^μ - μ - 1, μ), or
        (2^μ - μ - 1, μ + 1) when ``extended`` is True.

    Raises:
        ValueError: If ``mu`` is smaller than 2.
    """
    if mu < 2:
        raise ValueError("'mu' must be at least 2")

    # Dimension (information length) of the code.
    k = 2**mu - mu - 1

    # Supports (positions of the ones) of every binary μ-tuple of weight >= 2.
    # There are exactly 2^μ - 1 - μ = k of them, so each one fills one row of P.
    supports = [combo for weight in range(2, mu + 1) for combo in itertools.combinations(range(mu), weight)]

    # Scatter all ones in a single indexed assignment instead of building one
    # temporary tensor per row; the same code path is used for every value of mu.
    parity_submatrix = torch.zeros((k, mu), dtype=dtype, device=device)
    row_index = torch.tensor([row for row, combo in enumerate(supports) for _ in combo], dtype=torch.long, device=device)
    col_index = torch.tensor([col for combo in supports for col in combo], dtype=torch.long, device=device)
    parity_submatrix[row_index, col_index] = 1

    if extended:
        # Overall parity of generator row i = 1 (systematic bit) + weight of P_i, mod 2.
        # An all-ones column would only encode the parity of the message and would
        # leave weight-3 codewords in the code.
        overall_parity = torch.tensor([[(len(combo) + 1) % 2] for combo in supports], dtype=dtype, device=device)
        parity_submatrix = torch.cat([parity_submatrix, overall_parity], dim=1)

    return parity_submatrix


@ModelRegistry.register_model("hamming_code_encoder")
class HammingCodeEncoder(SystematicLinearBlockCodeEncoder):
    r"""Encoder for Hamming codes.

    Hamming codes are linear error-correcting codes that can detect up to two-bit errors
    and correct single-bit errors. They are perfect codes, meaning they achieve the
    theoretical limit for the number of correctable errors given their length and
    dimension :cite:`lin2004error,richardson2008modern`.

    For a given parameter μ ≥ 2, a Hamming code has the following parameters:

    - Length: n = 2^μ - 1
    - Dimension: k = 2^μ - μ - 1
    - Redundancy: m = μ
    - Minimum distance: d = 3

    In its extended version, the Hamming code has the following parameters:

    - Length: n = 2^μ
    - Dimension: k = 2^μ - μ - 1
    - Redundancy: m = μ + 1
    - Minimum distance: d = 4

    The implementation follows standard techniques in error control coding literature
    :cite:`lin2004error,moon2005error,sklar2001digital`.

    Args:
        mu (int): The parameter μ of the code. Must satisfy μ ≥ 2.
        extended (bool, optional): Whether to use the extended version of the Hamming code.
            Default is False.
        information_set (Union[List[int], torch.Tensor, str], optional): Information set
            specification. Default is "left".
        dtype (torch.dtype, optional): Data type for internal tensors. Default is torch.float32.
        **kwargs: Additional keyword arguments passed to the parent class.

    Examples:
        >>> encoder = HammingCodeEncoder(mu=3)
        >>> print(f"Length: {encoder.length}, Dimension: {encoder.dimension}, Redundancy: {encoder.redundancy}")
        Length: 7, Dimension: 4, Redundancy: 3
        >>> message = torch.tensor([1., 0., 1., 1.])
        >>> codeword = encoder(message)
        >>> codeword.shape
        torch.Size([7])

        >>> # Using the extended version
        >>> ext_encoder = HammingCodeEncoder(mu=3, extended=True)
        >>> print(f"Length: {ext_encoder.length}, Dimension: {ext_encoder.dimension}, Redundancy: {ext_encoder.redundancy}")
        Length: 8, Dimension: 4, Redundancy: 4
        >>> codeword = ext_encoder(message)
        >>> codeword.shape
        torch.Size([8])
    """

    def __init__(self, mu: int, extended: bool = False, information_set: Union[List[int], torch.Tensor, str] = "left", dtype: torch.dtype = torch.float32, **kwargs: Any):
        """Initialize the Hamming code encoder.

        Args:
            mu: The parameter μ of the code. Must satisfy μ ≥ 2.
            extended: Whether to use the extended version of the Hamming code. Default is False.
            information_set: Either indices of information positions, which must be a k-sublist
                of [0...n), or one of the strings 'left' or 'right'. Default is 'left'.
            dtype: Data type for internal tensors. Default is torch.float32.
            **kwargs: Additional keyword arguments passed to the parent class. A ``device``
                entry, if present, is also used when building the parity submatrix.

        Raises:
            ValueError: If mu < 2, or if the length, dimension or redundancy reported by the
                parent class after construction differ from the theoretical values.
        """
        if mu < 2:
            raise ValueError("'mu' must be at least 2")

        # Store parameters
        self._mu = mu
        self._extended = extended
        self._dtype = dtype

        # Theoretical parameters implied by mu; checked against the parent's values below.
        self._theoretical_length = 2**mu if extended else 2**mu - 1
        self._theoretical_dimension = 2**mu - mu - 1
        self._theoretical_redundancy = mu + 1 if extended else mu

        # The device is only read here; it stays in kwargs so the parent receives it too.
        device = kwargs.get("device", None)

        parity_submatrix = create_hamming_parity_submatrix(mu=mu, extended=extended, dtype=dtype, device=device)

        super().__init__(parity_submatrix=parity_submatrix, information_set=information_set, **kwargs)

        # Validate that the calculated dimensions match the theoretical ones
        self._validate_dimensions()

    def _validate_dimensions(self) -> None:
        """Validate that the code dimensions match the theoretical values.

        Raises:
            ValueError: If ``_length``, ``_dimension`` or ``_redundancy`` differs from the
                corresponding theoretical value computed in ``__init__``.
        """
        if self._length != self._theoretical_length:
            raise ValueError(f"Code length mismatch: calculated {self._length}, " f"expected {self._theoretical_length}")
        if self._dimension != self._theoretical_dimension:
            raise ValueError(f"Code dimension mismatch: calculated {self._dimension}, " f"expected {self._theoretical_dimension}")
        if self._redundancy != self._theoretical_redundancy:
            raise ValueError(f"Code redundancy mismatch: calculated {self._redundancy}, " f"expected {self._theoretical_redundancy}")

    @property
    def mu(self) -> int:
        """Get the parameter μ of the code."""
        return self._mu

    @property
    def extended(self) -> bool:
        """Get whether this is an extended Hamming code."""
        return self._extended

    # Deliberately not wrapped in functools.lru_cache: a cache on an instance method
    # holds a strong reference to every encoder it is called on, and the value is a
    # constant-time lookup anyway.
    def minimum_distance(self) -> int:
        """Return the minimum Hamming distance of the code.

        Returns:
            The minimum Hamming distance:
            - 3 for standard Hamming code
            - 4 for extended Hamming code
        """
        return 4 if self._extended else 3

    def __repr__(self) -> str:
        """Return a string representation of the encoder.

        Returns:
            A string representation with key parameters
        """
        return (
            f"{self.__class__.__name__}("
            f"mu={self._mu}, "
            f"extended={self._extended}, "
            f"length={self._length}, "
            f"dimension={self._dimension}, "
            f"redundancy={self._redundancy}, "
            f"dtype={self._dtype!r}"
            f")"
        )

    def inverse_encode(self, y: torch.Tensor):
        """Decode a codeword back to its original message, correcting single-bit errors.

        The syndrome of every received word is compared with the columns of the check
        matrix. When it equals column ``j`` the bit at position ``j`` is flipped; when it
        equals no column (in particular when it is all-zero) the word is left untouched.

        Args:
            y: A tensor of codewords to decode. The last dimension should match the code length.
                Supports batch processing with arbitrary batch dimensions.

        Returns:
            tuple: (decoded_message, syndrome)
                - decoded_message: The decoded information bits with corrected errors
                - syndrome: The syndrome vectors for each codeword, as returned by
                  ``calculate_syndrome``
        """
        syndrome = self.calculate_syndrome(y)

        # Flatten all batch dimensions; clone so the caller's tensor is never modified.
        batch_shape = y.size()[:-1]
        flat_words = y.reshape(-1, self.code_length).clone()
        flat_syndromes = syndrome.reshape(-1, self.redundancy).float()

        # One row per codeword position. NOTE(review): assumes check_matrix has shape
        # (redundancy, code_length), as the per-column lookup in
        # _syndrome_to_error_position also does — confirm against the parent class.
        columns = self.check_matrix[:, : self.code_length].float().t()

        # matches[b, j] is True when syndrome b equals check-matrix column j.
        matches = (flat_syndromes.unsqueeze(1) == columns.unsqueeze(0)).all(dim=-1)
        rows = torch.nonzero(matches.any(dim=1), as_tuple=True)[0]

        if rows.numel() > 0:
            # argmax returns the first maximal index, i.e. the first matching column.
            positions = matches[rows].long().argmax(dim=1)
            # Each row index occurs once, so the indexed assignment flips exactly one bit per word.
            flat_words[rows, positions] = 1 - flat_words[rows, positions]

        # Extract information bits
        decoded = flat_words[..., self.information_set]
        decoded = decoded.reshape(*batch_shape, self.code_dimension)

        return decoded, syndrome

    def _syndrome_to_error_position(self, syndrome):
        """Convert a syndrome to an error position by matching check matrix columns.

        Args:
            syndrome: A single syndrome vector.

        Returns:
            The index of the first check-matrix column equal to ``syndrome`` (compared as
            float), or ``self.code_length`` when no column matches.
        """
        target = syndrome.float()
        check_matrix = self.check_matrix  # indexed as (redundancy, code_length)
        for position in range(self.code_length):
            if torch.equal(check_matrix[:, position].float(), target):
                return position
        return self.code_length