"""Reed-Muller decoder using majority-logic decoding.
This module implements Reed's majority-logic decoding algorithm for Reed-Muller codes. The algorithm
decodes Reed-Muller codes efficiently by exploiting the geometric structure of their codewords: each
information bit is estimated by a majority vote over disjoint parity checks (flats of the code's geometry).
Reed-Muller codes form an important family of linear error-correcting codes with a rich mathematical
structure based on finite geometries. The majority-logic decoding algorithm leverages this structure
to provide an efficient decoding method that can correct multiple errors while avoiding the complexity
of brute-force maximum likelihood decoding.
:cite:`reed1954class`
:cite:`muller1954application`
:cite:`macwilliams1977theory`
:cite:`lin2004error`
"""
from typing import Any, List, Literal, Tuple, Union
import torch
from kaira.models.fec.encoders.reed_muller_code import ReedMullerCodeEncoder
from ..utils import apply_blockwise
from .base import BaseBlockDecoder
class ReedMullerDecoder(BaseBlockDecoder[ReedMullerCodeEncoder]):
    """Reed-Muller decoder using majority-logic decoding.

    This decoder implements the majority-logic decoding algorithm developed by Reed
    for Reed-Muller codes :cite:`reed1954class`. Each row of an RM(r,m) generator
    matrix is the evaluation of a Boolean monomial of degree at most r in m variables
    over the 2^m points of F_2^m. Once all monomials of higher degree have been
    removed from the received word, the coefficient of a degree-d monomial equals the
    parity of the word over each of 2^(m-d) disjoint d-dimensional subcubes (flats in
    the finite-geometry interpretation). The decoder takes a majority vote over these
    check sums, starting from the highest-degree monomials, and cancels every decided
    monomial before moving on to lower degrees.

    For an RM(r,m) code with r < m the algorithm corrects every error pattern of
    weight at most 2^(m-r-1) - 1, i.e. up to half the minimum distance 2^(m-r)
    :cite:`macwilliams1977theory`.

    The decoder supports both hard-decision and soft-decision inputs. In the soft
    variant each check sum is weighted by the smallest reliability (absolute value)
    among its positions, so unreliable check sums carry less weight in the vote.

    Attributes:
        encoder (ReedMullerCodeEncoder): The Reed-Muller encoder instance providing
            code parameters and encoding functionality
        input_type (str): The type of input the decoder accepts:
            'hard' for binary inputs (0s and 1s)
            'soft' for real-valued inputs with reliability information
        _reed_partitions (List[List[torch.Tensor]]): Precomputed Reed partitions, one
            per information bit in message order; each partition is a list of index
            tensors, one per check sum

    Args:
        encoder (ReedMullerCodeEncoder): The encoder for the Reed-Muller code being decoded
        input_type (Literal["hard", "soft"]): The type of input the decoder accepts.
            Default is "hard".
        *args: Variable positional arguments passed to the base class
        **kwargs: Variable keyword arguments passed to the base class

    Raises:
        ValueError: If ``input_type`` is neither "hard" nor "soft", or if the encoder's
            generator matrix is not made of Reed-Muller monomial evaluations.

    Examples:
        >>> from kaira.models.fec.encoders import ReedMullerCodeEncoder
        >>> from kaira.models.fec.decoders import ReedMullerDecoder
        >>> import torch
        >>>
        >>> # Create a RM(1,3) code encoder and decoder
        >>> encoder = ReedMullerCodeEncoder(r=1, m=3)
        >>> decoder = ReedMullerDecoder(encoder)
        >>>
        >>> # Encode a message
        >>> message = torch.tensor([1., 0., 1., 0.])
        >>> codeword = encoder(message)
        >>>
        >>> # Introduce an error
        >>> received = codeword.clone()
        >>> received[2] = 1 - received[2]
        >>>
        >>> # Decode using majority-logic decoding
        >>> decoded = decoder(received)
        >>> print(bool(torch.all(decoded == message)))
        True
    """

    def __init__(self, encoder: ReedMullerCodeEncoder, input_type: Literal["hard", "soft"] = "hard", *args: Any, **kwargs: Any):
        """Initialize the Reed-Muller decoder.

        Sets up the decoder with a Reed-Muller encoder instance and computes the
        Reed partitions needed for majority-logic decoding.

        Args:
            encoder: The Reed-Muller encoder instance for the code being decoded
            input_type: The type of decoder input, either "hard" for binary inputs
                or "soft" for real-valued inputs with reliability information
            *args: Variable positional arguments passed to the base class
            **kwargs: Variable keyword arguments passed to the base class

        Raises:
            ValueError: If ``input_type`` is invalid or the generator matrix does not
                have the Reed-Muller monomial structure.

        Note:
            The Reed partitions are precomputed during initialization to make the
            decoding process more efficient. They are derived from the encoder's
            generator matrix, so they match the encoder's row and column ordering.
        """
        super().__init__(encoder, *args, **kwargs)
        if input_type not in ("hard", "soft"):
            raise ValueError(f"input_type must be 'hard' or 'soft', got {input_type!r}")
        self.input_type = input_type
        # Boolean (k, n) generator matrix; used to build the partitions and to cancel
        # decided monomials during decoding.
        self._generator_rows = self._compute_generator_rows()
        # Compute Reed partitions
        self._reed_partitions = self._generate_reed_partitions()
        # Stacked (num_checks, check_size) index tensors for vectorised decoding.
        self._partition_indices = [torch.stack(partition) for partition in self._reed_partitions]
        # Reed's algorithm must decide higher-degree monomials first. A degree-d
        # monomial has check sums of size 2^d, so sort by decreasing check size; the
        # sort is stable, so equal-degree rows keep their generator order.
        self._decoding_order = sorted(
            range(len(self._partition_indices)),
            key=lambda row: -self._partition_indices[row].shape[1],
        )

    def _compute_generator_rows(self) -> torch.Tensor:
        """Return the encoder's generator matrix as a boolean ``(k, n)`` CPU tensor.

        Row ``i`` is obtained by encoding the ``i``-th unit message, so only the
        encoder's forward pass is relied upon.

        Raises:
            ValueError: If encoding the unit messages does not yield a ``(k, n)`` tensor.
        """
        k, n = self.code_dimension, self.code_length
        device = torch.device("cpu")
        buffers = getattr(self.encoder, "buffers", None)
        if callable(buffers):
            # Encode on the device holding the encoder's tensors, if it has any.
            for buffer in buffers():
                device = buffer.device
                break
        with torch.no_grad():
            rows = self.encoder(torch.eye(k, device=device))
        rows = rows.detach().to("cpu")
        if tuple(rows.shape) != (k, n):
            raise ValueError(f"Encoding the {k} unit messages produced shape {tuple(rows.shape)}, expected ({k}, {n})")
        return torch.remainder(rows.to(torch.float64).round(), 2) != 0

    def _generate_reed_partitions(self) -> List[List[torch.Tensor]]:
        """Generate Reed partitions for majority-logic decoding.

        Row ``i`` of an RM(r,m) generator matrix is the evaluation of a monomial
        ``x_S = prod_{t in S} x_t`` with ``|S| = d <= r`` over the ``n = 2^m`` points
        of F_2^m. Once all monomials of degree greater than ``d`` are cancelled, the
        coefficient of ``x_S`` equals the parity of the word over every subcube on
        which the variables outside ``S`` are fixed. That yields ``2^(m-d)`` disjoint
        check sums of ``2^d`` positions each.

        The point labelling of the positions is recovered from the degree-1 rows
        (weight ``n/2``). The set ``S`` of each row is recovered by testing which
        coordinate rows cover its support. The construction therefore follows the
        encoder's own row and column ordering.

        Returns:
            One partition per generator row (i.e. per information bit, in message
            order); each partition is a list of ``2^(m-d)`` index tensors of length
            ``2^d``.

        Raises:
            ValueError: If the code length is not a power of two or the generator
                matrix is not made of monomial evaluations.
        """
        generator = self._generator_rows
        n = self.code_length
        m = n.bit_length() - 1
        if n < 1 or (1 << m) != n:
            raise ValueError(f"Reed-Muller code length must be a power of two, got {n}")

        weights = [int(weight) for weight in generator.sum(dim=1).tolist()]
        if any(weight < 1 or weight & (weight - 1) for weight in weights):
            raise ValueError("Every Reed-Muller generator row must have a power-of-two weight")
        # A degree-d monomial equals 1 on exactly 2^(m-d) of the 2^m points.
        degrees = [m - (weight.bit_length() - 1) for weight in weights]

        if max(degrees, default=0) >= 1:
            coordinate_rows = [row for row, degree in enumerate(degrees) if degree == 1]
            if len(coordinate_rows) != m:
                raise ValueError(f"Expected {m} degree-1 generator rows, found {len(coordinate_rows)}")
            # coordinates[t, j] is the value of variable x_t at code position j.
            coordinates = generator[coordinate_rows]
            labels = (coordinates.long() * (2 ** torch.arange(m)).unsqueeze(1)).sum(dim=0)
            if torch.unique(labels).numel() != n:
                raise ValueError("Degree-1 generator rows do not label the code positions uniquely")
        else:
            coordinates = torch.zeros(0, n, dtype=torch.bool)

        partitions = []
        for row, degree in zip(generator, degrees):
            if degree == 0:
                # Constant monomial: every position is a check sum on its own.
                groups = torch.arange(n).unsqueeze(1)
            else:
                monomial_vars = [t for t in range(m) if bool(coordinates[t][row].all())]
                if len(monomial_vars) != degree or not torch.equal(coordinates[monomial_vars].all(dim=0), row):
                    raise ValueError("Generator matrix is not in Reed-Muller monomial form")
                free_vars = [t for t in range(m) if t not in monomial_vars]
                # Positions that agree on all free variables form one subcube.
                subcube_id = torch.zeros(n, dtype=torch.long)
                for place, t in enumerate(free_vars):
                    subcube_id += coordinates[t].long() * (1 << place)
                # Labels are unique, so each subcube has exactly 2^degree positions and
                # sorting by subcube id makes them contiguous.
                groups = torch.argsort(subcube_id).view(1 << len(free_vars), 1 << degree)
            partitions.append(list(groups.unbind(0)))
        return partitions

    def _decode_words(self, words: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run Reed's majority-logic decoding on a batch of received words.

        Args:
            words: Tensor of shape ``(N, n)`` holding hard bits or soft values,
                according to ``self.input_type``.

        Returns:
            ``(messages, errors)``: the ``(N, k)`` ``torch.int`` message estimates and
            the ``(N, n)`` boolean error pattern (hard decisions XOR the re-encoded
            codeword).
        """
        device = words.device
        generator = self._generator_rows.to(device)
        messages = torch.zeros(words.shape[0], self.code_dimension, dtype=torch.int, device=device)
        if self.input_type == "soft":
            # Negative soft values (e.g. LLRs) are hard-decided as 1.
            residual = words < 0
            reliability = words.abs()
        else:
            residual = (words.to(torch.int64) % 2) != 0
            reliability = None

        for row in self._decoding_order:
            groups = self._partition_indices[row].to(device)  # (num_checks, check_size)
            # Parity of the current residual over every check group: (N, num_checks).
            checks = residual[:, groups].sum(dim=-1) % 2
            if reliability is None:
                # Strict majority of ones; ties resolve to 0.
                bit = checks.sum(dim=-1) > groups.shape[0] // 2
            else:
                # Each check votes +w for 0 or -w for 1, where w is the smallest
                # reliability in its group.
                weights = reliability[:, groups].amin(dim=-1)
                bit = ((1 - 2 * checks) * weights).sum(dim=-1) < 0
            messages[:, row] = bit.to(torch.int)
            # Cancel the decided monomial so it cannot disturb lower-degree checks.
            residual = residual ^ (bit.unsqueeze(1) & generator[row])
        # After every row is cancelled, the residual is (hard decisions) XOR codeword.
        return messages, residual

    def forward(self, received: torch.Tensor, *args: Any, **kwargs: Any) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Decode received values using the Reed majority-logic algorithm.

        Monomial coefficients are decided from the highest degree down. Each decision
        is a majority vote over the check sums of its Reed partition, and each decided
        monomial is cancelled from the received word before lower degrees are
        processed. For soft-decision decoding every check sum is weighted by the
        smallest reliability among its positions.

        Args:
            received: Received tensor with shape (..., n) or (..., b*n), where n is the
                code length and b the number of codewords per row.
                For hard inputs, values should be 0 or 1.
                For soft inputs, positive values represent likelihood of 0 bits and
                negative values represent likelihood of 1 bits (e.g., LLR values).
            *args: Additional positional arguments (unused)
            **kwargs: Additional keyword arguments
                return_errors: If True, also return the estimated error patterns

        Returns:
            Either:
            - Decoded tensor (``torch.int``) containing estimated messages with shape
              (..., k) or (..., b*k)
            - A tuple of (decoded tensor, error pattern tensor) if return_errors=True.
              The error pattern has the shape and dtype of ``received`` and marks the
              positions where the hard decisions differ from the decoded codeword.

        Raises:
            ValueError: If the last dimension of received is not a multiple of the code length

        Note:
            For an RM(r,m) code with r < m, every error pattern of weight at most
            2^(m-r-1) - 1 is corrected (2^(m-2) - 1 errors for first-order codes).
            Beyond that radius the decoder may still succeed, but it is not a
            maximum-likelihood decoder.
        """
        return_errors = kwargs.get("return_errors", False)
        # Check input dimensions
        *leading_dims, L = received.shape
        if L % self.code_length != 0:
            raise ValueError(f"Last dimension ({L}) must be divisible by code length ({self.code_length})")

        n, k = self.code_length, self.code_dimension

        def decode_block(blocks: torch.Tensor):
            # Decode every length-n word independently, whatever leading shape
            # apply_blockwise provides, then restore that shape on the outputs.
            words = blocks.reshape(-1, n)
            messages, error_bits = self._decode_words(words)
            messages = messages.reshape(*blocks.shape[:-1], k)
            if not return_errors:
                return messages
            return messages, error_bits.to(blocks.dtype).reshape(blocks.shape)

        # Apply decoding blockwise
        return apply_blockwise(received, self.code_length, decode_block)