Source code for kaira.models.fec.encoders.reed_muller_code

"""Reed-Muller coding module for forward error correction.

This module implements Reed-Muller codes for binary data transmission, a powerful class of
error-correcting codes defined by two parameters: r (order) and m (length parameter).
Reed-Muller codes are versatile linear block codes with well-defined structure and
theoretical properties. They include several other codes as special cases, such as
repetition codes, single parity-check codes, and extended Hamming codes.

Reed-Muller codes are useful in scenarios requiring reliable transmission with
different error correction capabilities based on application requirements.
For more theoretical background on Reed-Muller codes, see :cite:`lin2004error,moon2005error`.

The implementation follows standard conventions in coding theory for binary linear block codes,
with elements belonging to the binary field GF(2) :cite:`richardson2008modern`.
"""

from itertools import combinations, product
from typing import Any, List, Tuple

import torch

from kaira.models.registry import ModelRegistry

from .linear_block_code import LinearBlockCodeEncoder


def _generate_evaluation_vectors(m: int) -> torch.Tensor:
    """Build the m evaluation (coordinate) vectors used by Reed-Muller codes.

    Row ``i`` consists of alternating runs of ``2^(m - i - 1)`` zeros and ones,
    which is the same as bit ``m - 1 - i`` of each column index ``0 .. 2^m - 1``.

    Args:
        m: Length parameter (resulting in 2^m code length)

    Returns:
        Tensor of shape (m, 2^m) and dtype ``torch.int64`` containing evaluation vectors
    """
    num_points = 2**m
    table = torch.zeros((m, num_points), dtype=torch.int64)
    positions = torch.arange(num_points, dtype=torch.int64)

    # Row 0 reads the most significant bit of the column index, the last row
    # reads the least significant one.
    for row, shift in enumerate(range(m - 1, -1, -1)):
        table[row] = (positions >> shift) & 1

    return table


def _generate_reed_muller_matrix(r: int, m: int) -> torch.Tensor:
    """Construct the generator matrix of the RM(r, m) code.

    Rows are the elementwise products of every subset of at most ``r``
    evaluation vectors, listed from degree ``r`` down to degree 1, followed by
    the all-ones row (the degree-0 term).

    Args:
        r: Order parameter (0 ≤ r < m)
        m: Length parameter (resulting in 2^m code length)

    Returns:
        Generator matrix G for Reed-Muller code, with 2^m columns and dtype ``torch.int64``

    Raises:
        ValueError: If the parameters do not satisfy 0 ≤ r < m
    """
    if not 0 <= r < m:
        raise ValueError(f"Parameters must satisfy 0 ≤ r < m, got r={r}, m={m}")

    basis = _generate_evaluation_vectors(m)

    # One row per monomial: the product of the chosen evaluation vectors,
    # highest degree first.
    generator_rows = [
        basis[list(subset)].prod(dim=0)
        for degree in range(r, 0, -1)
        for subset in combinations(range(m), degree)
    ]

    # Degree-0 monomial: the constant (all-ones) row goes last.
    generator_rows.append(torch.ones(2**m, dtype=torch.int64))

    stacked = torch.stack(generator_rows)
    return stacked % 2  # Ensure binary values


def calculate_reed_muller_dimension(r: int, m: int) -> int:
    """Return the dimension k of the Reed-Muller code RM(r, m).

    The dimension is the number of monomials of degree at most ``r`` in ``m``
    variables, i.e. ``sum_{i=0}^{r} binomial(m, i)``.

    Args:
        r: Order parameter (0 ≤ r < m)
        m: Length parameter

    Returns:
        Dimension k of the code

    Raises:
        ValueError: If the parameters do not satisfy 0 ≤ r < m
    """
    if not 0 <= r < m:
        raise ValueError(f"Parameters must satisfy 0 ≤ r < m, got r={r}, m={m}")

    import math

    # Accumulate one binomial coefficient per monomial degree.
    dimension = 0
    for degree in range(r + 1):
        dimension += math.comb(m, degree)
    return dimension


@ModelRegistry.register_model("reed_muller_code_encoder")
class ReedMullerCodeEncoder(LinearBlockCodeEncoder):
    """Encoder for Reed-Muller codes.

    Reed-Muller codes are a family of linear error-correcting codes with
    parameters (r,m) where:
        - r is the order (0 ≤ r < m)
        - m is the length parameter

    The resulting code has the following properties:
        - Length: n = 2^m
        - Dimension: k = sum_{i=0}^r (m choose i)
        - Minimum distance: d = 2^(m-r)

    Special cases:
        - RM(0,m) is a repetition code
        - RM(m-1,m) is a single parity-check code
        - RM(1,m) is a first-order Reed-Muller code (also known as a lengthened simplex code)
        - RM(m-2,m) is an extended Hamming code

    This implementation follows the standard approach to Reed-Muller coding
    described in error control coding literature :cite:`lin2004error,moon2005error`.

    Attributes:
        order (int): The order r of the Reed-Muller code (0 ≤ r < m)
        length_param (int): The length parameter m (code length will be 2^m)
        minimum_distance (int): The minimum Hamming distance of the code (2^(m-r))

    Args:
        order (int): The order r of the Reed-Muller code (0 ≤ r < m)
        length_param (int): The length parameter m
        *args: Variable positional arguments passed to the base class.
        **kwargs: Variable keyword arguments passed to the base class.
    """

    def __init__(self, order: int, length_param: int, *args: Any, **kwargs: Any):
        """Initialize the Reed-Muller encoder.

        Args:
            order: The order r of the Reed-Muller code (0 ≤ r < m)
            length_param: The length parameter m (code length will be 2^m)
            *args: Variable positional arguments passed to the base class.
            **kwargs: Variable keyword arguments passed to the base class.

        Raises:
            ValueError: If the parameters do not satisfy 0 ≤ r < m
        """
        if not (0 <= order < length_param):
            raise ValueError(f"Parameters must satisfy 0 ≤ r < m, got r={order}, m={length_param}")

        # Generate the generator matrix for the code and cast to float for tensor operations
        matrix = _generate_reed_muller_matrix(order, length_param)
        generator_matrix = matrix.to(torch.float)

        # Initialize the base class with the (float) generator matrix
        super().__init__(generator_matrix, *args, **kwargs)

        # Store Reed-Muller specific parameters
        self.order = order
        self.length_param = length_param
        self.minimum_distance = 2 ** (length_param - order)

    def __repr__(self) -> str:
        """Formal string representation including key parameters."""
        return (
            f"ReedMullerCodeEncoder(order={self.order}, length_param={self.length_param}, "
            f"length={self.code_length}, dimension={self.code_dimension})"
        )

    @classmethod
    def from_parameters(cls, order: int, length_param: int, *args: Any, **kwargs: Any) -> "ReedMullerCodeEncoder":
        """Create a Reed-Muller encoder from parameters.

        This is an alternative constructor that creates the encoder directly
        from the Reed-Muller parameters.

        Args:
            order: The order r of the Reed-Muller code (0 ≤ r < m)
            length_param: The length parameter m
            *args: Variable positional arguments passed to the constructor.
            **kwargs: Variable keyword arguments passed to the constructor.

        Returns:
            A ReedMullerCodeEncoder instance
        """
        return cls(order, length_param, *args, **kwargs)

    def get_reed_partitions(self) -> List[torch.Tensor]:
        """Get the Reed partitions of the code.

        Reed partitions are useful for certain decoding algorithms,
        particularly majority-logic decoding.

        One partition is produced for every index subset I of ``range(m)`` with
        ``|I| = ell``, for ``ell = r, r - 1, ..., 0`` (same ordering as
        ``itertools.combinations``). Each partition is an int64 tensor of shape
        ``(2^ell, 2^(m - ell))`` whose entries are integers built as sums of
        distinct powers of two ``2^i`` for ``i`` in ``range(m)``.

        Returns:
            List of tensors representing the Reed partitions
        """
        r, m = self.order, self.length_param

        # binary_vectors[ell] holds all 2^ell binary vectors of length ell, one per row.
        # product() emits them with the first coordinate changing slowest; flipping
        # the columns makes column 0 the fastest-changing coordinate.
        # For ell = 0, product() yields a single empty tuple, so the tensor already
        # has shape (1, 0) and no special case is required.
        binary_vectors = [
            torch.flip(
                torch.tensor(list(product([0, 1], repeat=ell)), dtype=torch.int64),
                dims=[1],
            )
            for ell in range(m + 1)
        ]

        reed_partitions = []
        for ell in range(r, -1, -1):
            for indices in combinations(range(m), ell):
                # Selected index set I and its complement E within range(m)
                set_I = torch.tensor(list(indices), dtype=torch.int64)
                set_E = torch.tensor([i for i in range(m) if i not in indices], dtype=torch.int64)

                # Weight every binary vector by the powers of two of its index set:
                # set_S enumerates all sums over subsets of I, set_Q over subsets of E.
                set_S = torch.matmul(binary_vectors[ell], torch.pow(2, set_I))
                set_Q = torch.matmul(binary_vectors[m - ell], torch.pow(2, set_E))

                # Broadcasted sum: entry (a, b) is set_S[a] + set_Q[b]
                reed_partitions.append(set_S.unsqueeze(1) + set_Q.unsqueeze(0))

        return reed_partitions

    def inverse_encode(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> Tuple[torch.Tensor, torch.Tensor]:
        """Decode codeword(s) by brute-force nearest-codeword search.

        This method implements a maximum-likelihood decoder for Reed-Muller
        codes by finding the closest valid codeword to the received word(s).
        All ``2^k`` messages (``k = self.code_dimension``) are enumerated on
        every call, so the cost grows exponentially with the code dimension.

        Args:
            x: The received codeword(s). Can be a single vector or a batch.
            *args: Additional positional arguments (unused).
            **kwargs: Additional keyword arguments (unused).

        Returns:
            Tuple containing:
                - Decoded message(s)
                - Syndrome (difference between closest valid codeword and received word)
        """
        # Make input a batch
        single = x.dim() == 1
        received = x.unsqueeze(0) if single else x
        device = received.device

        # Enumerate all possible messages (2^k of them)
        msgs = torch.tensor(
            list(product([0, 1], repeat=self.code_dimension)),
            dtype=torch.float,
            device=device,
        )  # (M, k)

        # Align the generator matrix with the enumerated messages so the product
        # does not fail on a device or dtype mismatch with the input.
        generator = self.generator_matrix.to(device=device, dtype=msgs.dtype)

        # Generate their codewords
        cws = (msgs @ generator) % 2  # (M, n)

        # Compute Hamming distances to each received word
        diff = (cws.unsqueeze(1) != received.unsqueeze(0)).float()
        dists = diff.sum(dim=2)  # (M, B)
        best = dists.argmin(dim=0)  # (B,)

        # Pick best messages and their codewords
        decoded = msgs[best]  # (B, k)
        pred_cw = cws[best]  # (B, n)
        syndrome = (pred_cw != received).float()  # (B, n)

        if single:
            return decoded[0], syndrome[0]
        return decoded, syndrome

    def calculate_syndrome(self, y: torch.Tensor) -> torch.Tensor:
        """Return the syndrome (error pattern) for given codeword(s).

        For Reed-Muller codes, we define the syndrome as the difference between
        the received word and the nearest valid codeword.

        Args:
            y: The received codeword(s)

        Returns:
            Syndrome tensor (difference between nearest valid codeword and received word)
        """
        _, syn = self.inverse_encode(y)
        return syn