# Source code for kaira.models.fec.encoders.systematic_linear_block_code

"""Systematic linear block coding module for forward error correction.

This module implements systematic linear block coding for binary data transmission, a specific
form of linear block coding where the information bits appear unchanged in predefined positions
of the codeword. The remaining positions, called the parity set, contain parity bits calculated
from the information bits :cite:`lin2004error,moon2005error,richardson2008modern`.

Systematic codes have the advantage that the original message can be directly extracted from
the codeword without decoding, making them practical for many applications. Their generator
matrices have a specific structure where the columns indexed by the information set form an
identity matrix.
"""

from typing import Any, List, Tuple, Union

import torch

from kaira.models.registry import ModelRegistry

from ..utils import apply_blockwise
from .linear_block_code import LinearBlockCodeEncoder


def create_systematic_generator_matrix(parity_submatrix: torch.Tensor, information_set: Union[List[int], torch.Tensor, str] = "left") -> torch.Tensor:
    """Create a systematic generator matrix from a parity submatrix and information set.

    The returned matrix has the identity matrix in the columns selected by the information
    set and ``parity_submatrix`` in the remaining (parity) columns. It is allocated with the
    dtype and device of ``parity_submatrix``.

    Args:
        parity_submatrix: The parity submatrix P of shape (k, m) where k is the dimension
            and m is the redundancy.
        information_set: Either indices of information positions, which must be a k-sublist
            of [0...n), or one of the strings 'left' or 'right'. Default is 'left'.

    Returns:
        A systematic generator matrix of shape (k, n), where n = k + m.

    Raises:
        ValueError: If parity_submatrix is not two-dimensional or information_set is invalid.
    """
    if parity_submatrix.dim() != 2:
        raise ValueError(f"parity_submatrix must be a 2-D tensor of shape (k, m), got shape {tuple(parity_submatrix.shape)}")

    k, m = parity_submatrix.shape
    n = k + m

    # Resolve which columns carry information bits and which carry parity bits
    information_indices, parity_indices = get_information_and_parity_sets(k, n, information_set)

    # Allocate on the parity submatrix's device: an indexed assignment cannot mix a
    # destination and a (non-scalar) source that live on different devices.
    generator_matrix = torch.zeros((k, n), dtype=parity_submatrix.dtype, device=parity_submatrix.device)

    # Identity block at the information positions
    generator_matrix[:, information_indices] = torch.eye(k, dtype=parity_submatrix.dtype, device=parity_submatrix.device)

    # Parity submatrix at the parity positions
    generator_matrix[:, parity_indices] = parity_submatrix

    return generator_matrix


def get_information_and_parity_sets(k: int, n: int, information_set: Union[List[int], torch.Tensor, str]) -> Tuple[torch.Tensor, torch.Tensor]:
    """Determine the information and parity sets for a systematic code.

    Args:
        k: Code dimension (information length)
        n: Code length
        information_set: Either indices of information positions, which must be a k-sublist
            of [0...n) without repeated entries (list, tuple or 1-D tensor), or one of the
            strings 'left' or 'right'.

    Returns:
        Tuple containing:
            - information_indices: Tensor of information set indices, in the order given
            - parity_indices: Tensor of parity set indices, i.e. the positions of [0...n)
              that are not in the information set, in ascending order (dtype ``torch.long``)

    Raises:
        ValueError: If information_set is invalid (unknown string, wrong length or rank,
            out-of-range index, or repeated index).
    """
    # String shortcuts: information bits occupy the first or the last k positions
    if isinstance(information_set, str):
        if information_set == "left":
            return torch.arange(k), torch.arange(k, n)
        if information_set == "right":
            return torch.arange(n - k, n), torch.arange(n - k)
        raise ValueError("If string, information_set must be 'left' or 'right'")

    # Explicit positions: convert Python sequences, pass tensors through untouched
    if isinstance(information_set, (list, tuple)):
        information_indices = torch.tensor(information_set)
    else:
        information_indices = information_set

    error_message = f"information_set must be a {k}-sublist of [0...{n})"

    if information_indices.dim() != 1 or information_indices.size(0) != k:
        raise ValueError(error_message)
    # min()/max() are undefined on an empty tensor, so only range-check when k > 0
    if k > 0 and (information_indices.min() < 0 or information_indices.max() >= n):
        raise ValueError(error_message)

    # Mark information positions in a boolean mask; what stays True is the parity set.
    # The mask lives on the CPU so the parity indices are CPU tensors regardless of
    # where the caller's index tensor is stored.
    lookup = information_indices.to(device="cpu", dtype=torch.long)
    is_parity = torch.ones(n, dtype=torch.bool)
    is_parity[lookup] = False

    # A repeated index clears fewer than k mask entries, leaving too many parity positions
    if int(is_parity.sum()) != n - k:
        raise ValueError(error_message)

    parity_indices = torch.arange(n)[is_parity]

    return information_indices, parity_indices


@ModelRegistry.register_model("systematic_linear_block_code_encoder")
class SystematicLinearBlockCodeEncoder(LinearBlockCodeEncoder):
    r"""Encoder for systematic linear block coding.

    In a systematic linear block code the information bits appear unchanged at predefined
    positions of the codeword. Those positions form the information set K, a k-sublist of
    [0...n); the remaining positions form the parity set M, an m-sublist of [0...n).

    The generator matrix of such a code has its columns indexed by K equal to I_k (identity
    matrix) and its columns indexed by M equal to P (the parity submatrix). The check matrix
    has its columns indexed by M equal to I_m and its columns indexed by K equal to P^T.

    This implementation follows the standard approach to systematic linear block coding
    described in the error control coding literature
    :cite:`lin2004error,moon2005error,sklar2001digital`.

    Args:
        parity_submatrix (torch.Tensor): The parity submatrix for the code.
        information_set: Either indices of information positions, which must be a k-sublist
            of [0...n), or one of the strings 'left' or 'right'. Default is 'left'.
    """

    def __init__(self, parity_submatrix: torch.Tensor, information_set: Union[List[int], torch.Tensor, str] = "left", **kwargs: Any):
        """Initialize the systematic linear block encoder.

        Args:
            parity_submatrix: The parity submatrix P for the code. Must be a binary matrix
                of shape (k, m) where k is the message length and m is the redundancy.
                Non-tensor inputs are converted to a float32 tensor.
            information_set: Either indices of information positions, which must be a
                k-sublist of [0...n), or one of the strings 'left' or 'right'.
                Default is 'left'.
            **kwargs: Variable keyword arguments passed to the base class. A
                ``generator_matrix`` entry, if present, is dropped because the generator
                matrix is built here from the parity submatrix.

        Raises:
            ValueError: If parity_submatrix or information_set are invalid.
        """
        # Accept array-likes by converting them to a float32 tensor
        if not isinstance(parity_submatrix, torch.Tensor):
            parity_submatrix = torch.tensor(parity_submatrix, dtype=torch.float32)
        self._parity_submatrix = parity_submatrix

        dimension, redundancy = self._parity_submatrix.shape
        length = dimension + redundancy

        # Only a string configuration is remembered; explicit index sets are recorded as None
        if isinstance(information_set, str):
            self._info_set_config = information_set
        else:
            self._info_set_config = None

        # Resolve the information/parity positions before building the generator matrix
        self._information_set, self._parity_set = get_information_and_parity_sets(dimension, length, information_set)

        generator_matrix = create_systematic_generator_matrix(parity_submatrix=parity_submatrix, information_set=information_set)

        # Forward everything except a caller-supplied generator matrix, leaving kwargs untouched
        parent_kwargs = {key: value for key, value in kwargs.items() if key != "generator_matrix"}

        # The parent constructor receives the systematic generator matrix; the _length,
        # _dimension and _redundancy attributes read elsewhere in this class are expected
        # to be set up by it.
        super().__init__(generator_matrix=generator_matrix, **parent_kwargs)

        # Buffers use dedicated names so they do not collide with the read-only properties below
        for buffer_name, tensor in (
            ("_info_set_buffer", self._information_set),
            ("_parity_set_buffer", self._parity_set),
            ("_parity_submatrix_buffer", self._parity_submatrix),
        ):
            self.register_buffer(buffer_name, tensor)

    @property
    def information_set(self) -> torch.Tensor:
        """Information set K of the code, as a tensor of information position indices."""
        return self._info_set_buffer

    @property
    def parity_submatrix(self) -> torch.Tensor:
        """Parity submatrix P of the code."""
        return self._parity_submatrix_buffer

    @property
    def parity_set(self) -> torch.Tensor:
        """Parity set M of the code."""
        return self._parity_set_buffer

    def project_word(self, x: torch.Tensor) -> torch.Tensor:
        """Project a codeword onto the information set.

        Reads the information bits straight out of their positions in the codeword, with
        no decoding step.

        Args:
            x: Input tensor of shape (..., codeword_length) or (..., b*codeword_length)
                where b is a positive integer.

        Returns:
            Projected tensor of shape (..., message_length) or (..., b*message_length)

        Raises:
            ValueError: If the last dimension of the input is not a multiple of n.
        """
        block_length = self._length
        last_dim_size = x.shape[-1]

        if last_dim_size % block_length != 0:
            raise ValueError(f"Last dimension size {last_dim_size} must be a multiple of the code length {block_length}")

        # Per block, keep only the entries at the information positions
        return apply_blockwise(x, block_length, lambda blocks: blocks[..., self.information_set])

    def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Encode the input tensor using systematic encoding.

        The information bits are copied into the information positions and only the parity
        bits are computed (message times parity submatrix, modulo 2), instead of
        multiplying by the full generator matrix.

        Args:
            x: The input tensor of shape (..., message_length) or (..., b*message_length)
                where b is a positive integer.
            *args: Additional positional arguments (unused).
            **kwargs: Additional keyword arguments (unused).

        Returns:
            Encoded tensor of shape (..., codeword_length) or (..., b*codeword_length)

        Raises:
            ValueError: If the last dimension of the input is not a multiple of k.
        """
        block_dimension = self._dimension
        last_dim_size = x.shape[-1]

        if last_dim_size % block_dimension != 0:
            raise ValueError(f"Last dimension size {last_dim_size} must be a multiple of the code dimension {block_dimension}")

        def encode_blocks(messages):
            # Parity bits over GF(2): message @ P, reduced modulo 2
            parity = torch.matmul(messages, self.parity_submatrix.to(messages.dtype)) % 2

            # Empty codewords with the same leading shape, dtype and device as the messages
            leading_shape = messages.shape[:-1]
            encoded = torch.zeros((*leading_shape, self._length), dtype=messages.dtype, device=messages.device)

            # Scatter message bits and parity bits into their respective positions
            encoded[..., self.information_set] = messages
            encoded[..., self.parity_set] = parity
            return encoded

        return apply_blockwise(x, block_dimension, encode_blocks)

    def __repr__(self) -> str:
        """Return a string representation of the object.

        Returns:
            String representation with key parameters
        """
        fields = ", ".join(
            (
                "parity_submatrix=tensor(...)",
                f"information_set=tensor({self._information_set.tolist()})",
                f"dimension={self._dimension}",
                f"length={self._length}",
                f"redundancy={self._redundancy}",
            )
        )
        return f"{self.__class__.__name__}({fields})"