Source code for kaira.models.fec.encoders.ldpc_code
"""Low-Density Parity-Check (LDPC) Code module for forward error correction.
This module provides an implementation of Low-Density Parity-Check (LDPC) codes for binary data transmission,
a class of linear block codes widely used in error correction for digital communication. LDPC codes are known for
their sparse parity-check matrices, which enable efficient encoding and decoding using iterative algorithms.
The implementation follows common conventions in coding theory with particular focus
on LDPC codes which are defined by a sparse parity-check matrix H.
References:
:cite:`gallager1962low`, :cite:`gallager1963low`, :cite:`richardson2008modern`
"""
from typing import Any
import torch
from kaira.models.registry import ModelRegistry
from ..encoders.linear_block_code import LinearBlockCodeEncoder
from ..rptu_database import CITATION, EXISTING_CODES, get_code_from_database, parse_alist
from ..utils import row_reduction
[docs]
@ModelRegistry.register_model("ldpc_code_encoder")
class LDPCCodeEncoder(LinearBlockCodeEncoder):
"""Encoder for LDPC code :cite:`gallager1962low`, :cite:`gallager1963low`.
This encoder follows conventional approach of linear block codes and
transforms binary input messages into codewords according to
the calculated generator matrix. It serves as the encoding component of
a linear block code system.
The encoder applies the formula: c = mG, where:
- c is the codeword
- m is the message
- G is the generator matrix
This implementation follows the standard approach to linear block coding described in the
error control coding literature :cite:`lin2004error,moon2005error,sklar2001digital`.
Attributes:
generator_matrix (torch.Tensor): The generator matrix G of the code
check_matrix (torch.Tensor): The parity check matrix H
"""
[docs]
def __init__(self, check_matrix: torch.Tensor = None, rptu_database: bool = False, *args: Any, **kwargs: Any):
"""Initializes the linear block encoder for LDPC codes.
Args:
check_matrix (torch.Tensor, optional): The parity check matrix for encoding.
Should be a binary matrix of shape (code_length - code_dimension, code_length), where
code_dimension is the message length and code_length is the codeword length.
If None and `rptu_database` is True, the matrix
will be loaded from the RPTU database.
rptu_database (bool, optional): If True, loads the check matrix from the RPTU
code database using parameters provided in `kwargs`. Default is False.
*args: Additional positional arguments passed to the base class.
**kwargs: Additional keyword arguments. Expected keys when `rptu_database` is True:
- code_length (int): Codeword length.
- code_dimension (int): Message length.
- rptu_standart (str, optional): Standard name for the LDPC code. If not provided,
the first available standard is used.
- device (str, optional): Device to place the tensors on (e.g., "cpu" or "cuda").
Raises:
ValueError: If the requested (code_length, code_dimension) code or standard is not found in the RPTU database.
"""
# Validate input parameters
if not rptu_database and check_matrix is None:
raise ValueError("Either a valid `check_matrix` must be provided or `rptu_database` must be set to True.")
# Initialize the base class from rptu_database or provided check_matrix
if rptu_database:
print("Loading LDPC code from RPTU database...")
print(CITATION)
print("------------------------------------")
code_length = kwargs.get("code_length", None)
code_dimension = kwargs.get("code_dimension", None)
rptu_standart = kwargs.get("rptu_standart", None)
if code_length is None or code_dimension is None:
raise ValueError("code_length and code_dimension must be provided when using rptu_database.")
code_key = (code_length, code_dimension)
if code_key in EXISTING_CODES.keys():
if rptu_standart is not None:
if rptu_standart not in EXISTING_CODES[code_key].keys():
raise ValueError(f"LDPC code with (code_length={code_length}, code_dimension={code_dimension}) and rptu_standart='{rptu_standart}' not found in rptu_database.")
else:
rptu_standart = list(EXISTING_CODES[code_key].keys())[0] # Default to first available standard
print(f"Using default rptu_standart='{rptu_standart}' for (code_length={code_length}, code_dimension={code_dimension}).")
content = get_code_from_database(EXISTING_CODES[code_key][rptu_standart])
else:
print(f"Available LDPC codes from rptu database: {EXISTING_CODES.keys()}")
raise ValueError(f"LDPC code with (code_length={code_length}, code_dimension={code_dimension}) not found in rptu_database.")
check_matrix = parse_alist(content)
self.device = kwargs.get("device", "cpu")
# Ensure generator matrix is a torch tensor
if not isinstance(check_matrix, torch.Tensor):
check_matrix = torch.tensor(check_matrix).to(self.device)
if check_matrix.device != self.device:
check_matrix = check_matrix.to(self.device)
generator_matrix = self.get_generator_matrix(check_matrix)
# Initialize the base class with dimensions
super().__init__(generator_matrix=generator_matrix, check_matrix=check_matrix)
[docs]
def get_generator_matrix(self, check_matrix_: torch.Tensor) -> torch.Tensor:
"""Derive the generator matrix from a parity check matrix.
This method computes the generator matrix for an LDPC code by:
1. Transposing the parity check matrix
2. Appending an identity matrix to obtain [H | I]
3. Performing Gaussian elimination (row reduction) to obtain [A | B]
4. Extracting the generator matrix from the result
The process ensures that G·Hᵀ = 0, which is the defining property of a valid
generator matrix for the code.
Args:
check_matrix_: The parity check matrix of the LDPC code
Returns:
The generator matrix for the LDPC code
"""
check_matrix = check_matrix_.clone().to(torch.int64).t()
check_matrix_eye = torch.cat((check_matrix, torch.eye(check_matrix.shape[0]).to(bool).to(check_matrix.device)), dim=1)
check_matrix_eye, rank = row_reduction(check_matrix_eye, num_cols=check_matrix.shape[1])
generator_matrix = row_reduction(check_matrix_eye[rank:, check_matrix.shape[1] :])[0]
return generator_matrix