Source code for kaira.models.fec.decoders.base

"""Base decoders module for forward error correction.

This module implements base decoder classes for various forward error correction techniques.
Decoders are responsible for recovering the original message from received codewords that
may contain errors introduced during transmission over noisy channels.

The module provides a type-generic architecture that ensures correct pairing between encoders
and their corresponding decoders, while maintaining flexibility for different decoding
algorithms and implementation strategies.

:cite:`lin2004error`
:cite:`moon2005error`
:cite:`richardson2008modern`
"""

from abc import ABC, abstractmethod
from typing import Any, Generic, Tuple, TypeVar, Union

import torch

from kaira.models.base import BaseModel

from ..encoders.base import BaseBlockCodeEncoder

# Encoder type parameter: ``bound`` restricts it to BaseBlockCodeEncoder or one
# of its subclasses. It parameterizes the generic decoder base class below.
T = TypeVar("T", bound=BaseBlockCodeEncoder)


class BaseBlockDecoder(BaseModel, Generic[T], ABC):
    """Abstract foundation shared by every block code decoder.

    Concrete decoders (for example syndrome, maximum likelihood, algebraic,
    or soft-decision decoders) derive from this class and supply the actual
    decoding algorithm by overriding :meth:`forward`.

    The class is generic in ``T``, the type of the paired encoder, so that
    static type checkers can flag an encoder/decoder mismatch during
    development instead of it surfacing at runtime.

    Attributes:
        encoder (T): Encoder paired with this decoder. The code parameters
            exposed by the decoder's properties are read from it.

    Args:
        encoder (T): Encoder instance for the code being decoded.
        *args: Positional arguments forwarded to the base class.
        **kwargs: Keyword arguments forwarded to the base class.

    Note:
        Subclasses must override :meth:`forward`. Depending on the
        implementation, a decoder may consume hard-decision (binary) input
        or soft-decision (real-valued reliability) input.
    """

    def __init__(self, encoder: T, *args: Any, **kwargs: Any):
        """Set up the decoder and attach its encoder.

        The base class is initialized first with the forwarded arguments;
        the encoder is then stored on the instance so that code parameters
        (and any encoder-side helpers a subclass needs) are reachable.
        """
        super().__init__(*args, **kwargs)
        self.encoder = encoder

    @property
    def code_length(self) -> int:
        """Codeword length ``n``, as reported by the encoder.

        This counts every bit of a codeword: information bits plus
        redundancy bits.

        Returns:
            Number of bits in a codeword.
        """
        enc = self.encoder
        return enc.code_length

    @property
    def code_dimension(self) -> int:
        """Code dimension ``k``, as reported by the encoder.

        This is the number of information (payload) bits carried by each
        codeword.

        Returns:
            Number of information bits per codeword.
        """
        enc = self.encoder
        return enc.code_dimension

    @property
    def redundancy(self) -> int:
        """Code redundancy ``r = n - k``, as reported by the encoder.

        These are the parity/check bits appended to the information bits
        to make error detection and correction possible.

        Returns:
            Number of parity bits per codeword.
        """
        enc = self.encoder
        return enc.redundancy

    @property
    def code_rate(self) -> float:
        """Code rate ``k / n``, as reported by the encoder.

        The rate measures coding efficiency: a higher rate uses the channel
        more efficiently but typically corrects fewer errors.

        Returns:
            Ratio of information bits to total bits.
        """
        enc = self.encoder
        return enc.code_rate

    @abstractmethod
    def forward(
        self, received: torch.Tensor, *args: Any, **kwargs: Any
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Estimate the transmitted messages from received codewords.

        Each concrete decoder implements its own algorithm here, chosen
        according to the structure of the code and the desired performance.

        Args:
            received: Received codeword tensor of shape ``(..., n)`` or
                ``(..., m*n)``, with ``n`` the code length and ``m`` some
                multiple.
            *args: Extra positional arguments for a specific decoder.
            **kwargs: Extra keyword arguments for a specific decoder.

        Returns:
            Either the decoded tensor of estimated messages, shaped
            ``(..., k)`` or ``(..., m*k)``, or a tuple of that tensor and
            additional decoding information (such as syndromes, reliability
            metrics, or error patterns).

        Raises:
            ValueError: Expected from implementations when the last
                dimension of ``received`` is not a multiple of ``n``.
            NotImplementedError: Raised by this base implementation when it
                is invoked directly.

        Note:
            The original message may not be recovered exactly when the
            number of errors exceeds the code's error-correcting capability.
        """
        raise NotImplementedError("Subclasses must implement forward method")