Source code for kaira.models.fec.encoders.bch_code

"""BCH code implementation for forward error correction.

This module implements Bose-Chaudhuri-Hocquenghem (BCH) codes, a class of cyclic error-correcting
codes that are constructed using polynomials over finite fields. BCH codes are powerful and
versatile, providing the ability to control the trade-off between redundancy and error-correcting
capability.

For given parameters μ ≥ 2 and δ satisfying 2 ≤ δ ≤ 2^μ - 1, a binary BCH code has
the following parameters, where δ = 2τ + 1:

- Length: n = 2^μ - 1
- Dimension: k ≥ n - μτ
- Redundancy: m ≤ μτ
- Minimum distance: d ≥ δ

This implementation handles narrow-sense, primitive BCH codes, which are optimal
for many applications requiring reliable transmission over noisy channels.

    :cite:`lin2004error`
    :cite:`moon2005error`
    :cite:`richardson2008modern`
"""

from functools import cache, lru_cache
from typing import Any, Dict, List, Optional, Union

import torch

from kaira.models.registry import ModelRegistry

from ..algebra import BinaryPolynomial, FiniteBifield
from .cyclic_code import CyclicCodeEncoder


@cache
def compute_bch_generator_polynomial(mu: int, delta: int) -> BinaryPolynomial:
    """Compute the generator polynomial for a BCH code.

    Args:
        mu: The parameter μ of the BCH code.
        delta: The design distance δ of the BCH code.

    Returns:
        The generator polynomial.
    """
    # Create the finite field
    field = FiniteBifield(mu)

    # Get the primitive element
    alpha = field.primitive_element()

    # Compute the minimal polynomials of alpha^1, alpha^2, ..., alpha^(delta-1)
    minimal_polys = set()
    for i in range(1, delta):
        minimal_poly = (alpha**i).minimal_polynomial()
        minimal_polys.add(minimal_poly)

    # Compute the LCM of the minimal polynomials
    if not minimal_polys:
        raise ValueError("No minimal polynomials found")

    # Convert the set to a list for consistent ordering
    minimal_polys_list = sorted(list(minimal_polys), key=lambda p: p.value)

    # Compute the LCM
    generator_poly = minimal_polys_list[0]
    for poly in minimal_polys_list[1:]:
        generator_poly = generator_poly.lcm(poly)

    return generator_poly


@cache
def get_valid_bose_distances(mu: int) -> List[int]:
    """Get all valid Bose distances for a given mu.

    Args:
        mu: The parameter μ of the BCH code.

    Returns:
        List of all valid Bose distances for the given mu.
    """

    valid_distances = []
    for delta in range(2, 2**mu):
        if is_bose_distance(mu, delta):
            valid_distances.append(delta)

    return valid_distances


@lru_cache(maxsize=64)
def is_bose_distance(mu: int, delta: int) -> bool:
    """Check if delta is a Bose distance for the given mu.

    A Bose distance is a value δ such that the BCH code with parameters μ and δ
    has a different generator polynomial than the BCH code with parameters μ and δ-1.

    Args:
        mu: The parameter μ of the BCH code.
        delta: The potential Bose distance δ.

    Returns:
        True if delta is a Bose distance, False otherwise.
    """
    # Simple checks first
    if delta < 2 or delta > 2**mu - 1:
        return False

    if delta == 2:
        return True  # δ=2 is always a Bose distance

    # Special cases for efficiency
    if delta == 3:
        return True  # δ=3 is always a Bose distance
    if delta == 5 and mu >= 3:
        return True  # δ=5 is a Bose distance for mu >= 3
    if delta == 2**mu - 1:
        return True  # Maximum possible δ is always a Bose distance

    # Check if the minimal polynomial of alpha^delta is already in the LCM set
    field = FiniteBifield(mu)
    alpha = field.primitive_element()

    # Get minimal polynomials for powers 1 to delta-1
    minimal_polys = set()
    for i in range(1, delta):
        minimal_poly = (alpha**i).minimal_polynomial()
        minimal_polys.add(minimal_poly.value)  # Store the value for easier comparison

    # Check if the minimal polynomial of alpha^delta is already included
    delta_poly = (alpha**delta).minimal_polynomial()
    return delta_poly.value not in minimal_polys


def create_bch_generator_matrix(length: int, generator_poly: BinaryPolynomial, dtype: torch.dtype = torch.float32, device: Optional[torch.device] = None) -> torch.Tensor:
    """Create the generator matrix for a BCH code.

    This function creates a systematic generator matrix for a BCH code.

    Args:
        length: The length of the code.
        generator_poly: The generator polynomial.
        dtype: The data type for the resulting tensor. Default is torch.float32.
        device: The device to place the tensor on. Default is None (uses current device).

    Returns:
        The generator matrix.
    """
    # Compute dimensions
    n = length
    redundancy = generator_poly.degree
    dimension = n - redundancy

    # Create the generator matrix
    G = torch.zeros((dimension, n), dtype=dtype, device=device)

    # First, set the identity matrix in the first k columns (for systematic form)
    for i in range(dimension):
        G[i, i] = 1.0

    # For each row, compute the parity part
    for i in range(dimension):
        # Multiply the message polynomial x^i by x^(n-k)
        message_poly = BinaryPolynomial(1 << i)
        shifted_poly = BinaryPolynomial(message_poly.value << redundancy)

        # Find the remainder when divided by the generator polynomial
        remainder = shifted_poly % generator_poly

        # Set the parity bits in the generator matrix
        # The remainder corresponds to the parity bits
        coeffs = remainder.to_coefficient_list()
        for j in range(min(len(coeffs), redundancy)):
            if coeffs[j] == 1:
                G[i, dimension + j] = 1.0

    return G


[docs] @ModelRegistry.register_model("bch_code_encoder") class BCHCodeEncoder(CyclicCodeEncoder): r"""Encoder for BCH (Bose–Chaudhuri–Hocquenghem) codes. BCH codes are a class of powerful cyclic error-correcting codes that can be designed to correct multiple errors. They are constructed using polynomials over finite fields and provide great flexibility in the trade-off between redundancy and error-correcting capability :cite:`lin2004error,richardson2008modern`. For given parameters μ ≥ 2 and δ satisfying 2 ≤ δ ≤ 2^μ - 1, a binary BCH code has the following parameters, where δ = 2τ + 1: - Length: n = 2^μ - 1 - Dimension: k ≥ n - μτ - Redundancy: m ≤ μτ - Minimum distance: d ≥ δ This implementation handles narrow-sense, primitive BCH codes :cite:`lin2004error,moon2005error,sklar2001digital`. Args: mu (int): The parameter μ of the code. Must satisfy μ ≥ 2. delta (int): The design distance δ of the code. Must satisfy 2 ≤ δ ≤ 2^μ - 1 and be a valid Bose distance. information_set (Union[List[int], torch.Tensor, str], optional): Information set specification. Default is "left". dtype (torch.dtype, optional): Data type for internal tensors. Default is torch.float32. **kwargs: Additional keyword arguments passed to the parent class. Examples: >>> encoder = BCHCodeEncoder(mu=4, delta=5) >>> print(f"Length: {encoder.length, Dimension: {encoder.dimension}, Redundancy: {encoder.redundancy}") Length: 15, Dimension: 7, Redundancy: 8 >>> message = torch.tensor([1., 0., 1., 1., 0., 1., 0.]) >>> codeword = encoder(message) >>> print(codeword) tensor([1., 0., 1., 1., 0., 1., 0., 1., 0., 0., 1., 1., 0., 0., 1.]) """
[docs] def __init__(self, mu: int, delta: int, information_set: Union[List[int], torch.Tensor, str] = "left", dtype: torch.dtype = torch.float32, **kwargs: Any): """Initialize the BCH code encoder. Args: mu: The parameter μ of the code. Must satisfy μ ≥ 2. delta: The design distance δ of the code. Must satisfy 2 ≤ δ ≤ 2^μ - 1 and be a valid Bose distance. information_set: Either indices of information positions, which must be a k-sublist of [0...n), or one of the strings 'left' or 'right'. Default is 'left'. dtype: Data type for internal tensors. Default is torch.float32. **kwargs: Additional keyword arguments passed to the parent class. Raises: ValueError: If μ < 2 or if δ is not a valid Bose distance. """ if mu < 2: raise ValueError("'mu' must satisfy mu >= 2") if not 2 <= delta <= 2**mu - 1: raise ValueError("'delta' must satisfy 2 <= delta <= 2**mu - 1") # Store parameters self._mu = mu self._delta = delta self._dtype = dtype # Check if delta is a valid Bose distance if not is_bose_distance(mu, delta): # Find the next valid Bose distance valid_distances = get_valid_bose_distances(mu) next_deltas = [d for d in valid_distances if d > delta] if next_deltas: next_delta = next_deltas[0] raise ValueError(f"'delta' must be a Bose distance (the next one is {next_delta})") else: raise ValueError("'delta' must be a Bose distance") # Compute the generator polynomial self._generator_polynomial = compute_bch_generator_polynomial(mu, delta) # Compute code parameters n = 2**mu - 1 m = self._generator_polynomial.degree k = n - m # Calculate error correction capability self._error_correction_capability = (delta - 1) // 2 # Get device from kwargs if provided # device = kwargs.get("device", None) # Create generator matrix # generator_matrix = create_bch_generator_matrix(length=n, generator_poly=self._generator_polynomial, dtype=dtype, device=device) # Initialize the parent class with proper parameters super().__init__(code_length=n, generator_polynomial=self._generator_polynomial.value, information_set=information_set, dtype=dtype, **kwargs) # Store dimensions self._length = n self._dimension = k self._redundancy = m # Create the finite field (used for decoding) self._field = FiniteBifield(mu) self._alpha = self._field.primitive_element() # Compute the check matrix self._compute_check_matrix() # Register the check matrix buffer self.register_buffer("check_matrix", self._check_matrix)
def _compute_check_matrix(self) -> None: """Compute the parity check matrix from the generator matrix.""" # For a systematic code, the check matrix H can be derived from the generator matrix G. # If G = [I_k | P], then H = [P^T | I_(n-k)] identity_part = torch.eye(self._redundancy, dtype=self._dtype, device=self.generator_matrix.device) parity_part = self.generator_matrix[:, self._dimension :].T # Construct H = [P^T | I_m] self._check_matrix = torch.cat([parity_part, identity_part], dim=1) @property def mu(self) -> int: """Parameter μ of the code.""" return self._mu @property def delta(self) -> int: """Design distance δ of the code.""" return self._delta @property def error_correction_capability(self) -> int: """Error correction capability of the code (t = ⌊(δ-1)/2⌋).""" return self._error_correction_capability
[docs] @lru_cache(maxsize=None) def minimum_distance(self) -> int: """Get the minimum distance of the code. For BCH codes, the minimum distance is at least the design distance. Returns: The minimum distance of the code, which is at least δ. """ return self._delta
[docs] @classmethod def from_design_rate(cls, mu: int, target_rate: float, **kwargs: Any) -> "BCHCodeEncoder": """Create a BCH code with a design rate close to the target rate. Args: mu: The parameter μ of the BCH code. target_rate: The target rate (k/n) of the code. **kwargs: Additional arguments passed to the constructor. Returns: A BCH code encoder with rate close to the target rate. Raises: ValueError: If no suitable code can be found. """ if mu < 2: raise ValueError("'mu' must satisfy mu >= 2") if not 0 < target_rate < 1: raise ValueError("'target_rate' must be between 0 and 1") # Get all valid Bose distances for this mu valid_distances = get_valid_bose_distances(mu) # Calculate the code length n = 2**mu - 1 # Find the delta that gives the closest rate to the target best_delta = None best_diff = float("inf") for delta in valid_distances: # Compute the generator polynomial to get the dimension generator_poly = compute_bch_generator_polynomial(mu, delta) k = n - generator_poly.degree rate = k / n diff = abs(rate - target_rate) if diff < best_diff: best_diff = diff best_delta = delta if best_delta is None: raise ValueError(f"Could not find a suitable BCH code for mu={mu} and rate={target_rate}") return cls(mu=mu, delta=best_delta, **kwargs)
[docs] @classmethod def get_standard_codes(cls) -> Dict[str, Dict[str, Any]]: """Get a dictionary of standard BCH codes with their parameters. Returns: Dictionary mapping code names to their parameters. """ return { "BCH(7,4)": {"mu": 3, "delta": 3}, # Equivalent to Hamming(7,4) "BCH(15,7)": {"mu": 4, "delta": 5}, # Can correct 2 errors "BCH(15,5)": {"mu": 4, "delta": 7}, # Can correct 3 errors "BCH(31,16)": {"mu": 5, "delta": 7}, # Can correct 3 errors "BCH(31,11)": {"mu": 5, "delta": 11}, # Can correct 5 errors "BCH(63,36)": {"mu": 6, "delta": 11}, # Can correct 5 errors "BCH(63,24)": {"mu": 6, "delta": 15}, # Can correct 7 errors "BCH(127,64)": {"mu": 7, "delta": 21}, # Can correct 10 errors "BCH(127,36)": {"mu": 7, "delta": 31}, # Can correct 15 errors "BCH(255,123)": {"mu": 8, "delta": 39}, # Can correct 19 errors "BCH(255,71)": {"mu": 8, "delta": 59}, # Can correct 29 errors }
[docs] @classmethod def create_standard_code(cls, name: str, **kwargs: Any) -> "BCHCodeEncoder": """Create a standard BCH code by name. Args: name: Name of the standard code from get_standard_codes(). **kwargs: Additional arguments passed to the constructor. Returns: A BCH code encoder for the requested standard code. Raises: ValueError: If the requested code is not recognized. """ standard_codes = cls.get_standard_codes() if name not in standard_codes: valid_names = list(standard_codes.keys()) raise ValueError(f"Unknown standard code: {name}. Valid options are: {valid_names}") params = standard_codes[name].copy() params.update(kwargs) return cls(**params)
def __repr__(self) -> str: """Return a string representation of the encoder. Returns: A string representation with key parameters """ return f"{self.__class__.__name__}(" f"mu={self._mu}, " f"delta={self._delta}, " f"length={self._length}, " f"dimension={self._dimension}, " f"redundancy={self._redundancy}, " f"t={self._error_correction_capability}, " f"dtype={self._dtype.__repr__()}" f")"
[docs] def calculate_syndrome_polynomial(self, received: List[Any]) -> List[Any]: """Calculate the syndrome polynomial for a received word. This method computes the syndrome polynomial S(x) for a received codeword by evaluating the received polynomial at powers of alpha, which are the roots of the generator polynomial. Args: received: List of field elements representing the received word Returns: List of syndrome values in the field, S = [S_0, S_1, ..., S_{2t-1}] """ syndrome = [] for i in range(1, 2 * self._error_correction_capability + 1): # Evaluate the received polynomial at alpha^i alpha_i = self._alpha**i eval_result = self._field(0) # Initialize with field zero element for j, bit in enumerate(received): if bit != self._field.zero: # For each non-zero bit, add alpha^(j*i) to the result eval_result = eval_result + (alpha_i**j) syndrome.append(eval_result) return syndrome