Source code for kaira.models.fec.decoders.brute_force_ml

"""Brute Force Maximum Likelihood decoder for forward error correction.

This module implements a brute-force maximum likelihood decoding approach for block codes by
searching through all possible codewords to find the one closest to the received word in terms of
Hamming distance, which is equivalent to maximum likelihood decoding in symmetric channels.

The maximum likelihood decoding principle selects the codeword that has the highest probability
of having been transmitted, given the received word. For binary symmetric channels, this is
equivalent to finding the codeword with the minimum Hamming distance to the received word.
While this approach guarantees optimal decoding performance, its computational complexity
grows exponentially with the code dimension, making it practical only for small codes.

:cite:`lin2004error`
:cite:`moon2005error`
:cite:`proakis2008digital`
"""

from typing import Any, Optional, Tuple, Union

import torch

from kaira.models.fec.encoders.base import BaseBlockCodeEncoder

from .base import BaseBlockDecoder


[docs] class BruteForceMLDecoder(BaseBlockDecoder[BaseBlockCodeEncoder]): """Brute Force Maximum Likelihood decoder for block codes. This decoder implements a brute-force maximum likelihood approach by searching through all possible codewords to find the one that is closest to the received word in terms of Hamming distance. It is optimal in the sense that it performs maximum-likelihood decoding for symmetric channels :cite:`proakis2008digital`, but becomes computationally infeasible for larger codes. The brute force ML decoding principle is based on finding: argmin_{c ∈ C} d(r, c) where d(r, c) is the Hamming distance between the received word r and codeword c, and C is the set of all codewords in the code. The decoder works by: 1. Generating all possible codewords (or using a precomputed codebook) 2. Computing the Hamming distance between the received word and each codeword 3. Selecting the codeword with the minimum distance (maximum likelihood) 4. Extracting the message bits from the selected codeword Attributes: encoder (BaseBlockCodeEncoder): The encoder instance providing encoding functionality _codebook (torch.Tensor): Precomputed tensor of all codewords with shape (2^k, n) _message_map (torch.Tensor): Mapping from codeword indices to message bits with shape (2^k, k) Args: encoder (BaseBlockCodeEncoder): The encoder for the code being decoded precompute_codebook (bool): Whether to precompute the entire codebook during initialization. Default is True, which is more efficient for multiple decoding operations but requires more memory. *args: Variable positional arguments passed to the base class **kwargs: Variable keyword arguments passed to the base class Examples: >>> from kaira.models.fec.encoders import LinearBlockCodeEncoder >>> from kaira.models.fec.decoders import BruteForceMLDecoder >>> import torch >>> >>> # Create a simple (7,4) Hamming code encoder and decoder >>> G = torch.tensor([ ... [1, 0, 0, 0, 1, 1, 0], ... [0, 1, 0, 0, 1, 0, 1], ... [0, 0, 1, 0, 0, 1, 1], ... [0, 0, 0, 1, 1, 1, 1] ... ], dtype=torch.float) >>> encoder = LinearBlockCodeEncoder(generator_matrix=G) >>> decoder = BruteForceMLDecoder(encoder) >>> >>> # Encode a message >>> message = torch.tensor([1., 0., 1., 1.]) >>> codeword = encoder(message) >>> >>> # Introduce two bit errors >>> received = codeword.clone() >>> received[2] = 1 - received[2] >>> received[5] = 1 - received[5] >>> >>> # Decode using ML decoding >>> decoded = decoder(received) >>> print(torch.all(decoded == message)) True """
[docs] def __init__(self, encoder: BaseBlockCodeEncoder, precompute_codebook: bool = True, *args: Any, **kwargs: Any): """Initialize the brute force ML decoder. Sets up the decoder with an encoder instance and optionally precomputes the complete codebook for more efficient decoding operations. Args: encoder: The encoder instance for the code being decoded precompute_codebook: Whether to generate all possible codewords during initialization (True) or on-demand (False) *args: Variable positional arguments passed to the base class **kwargs: Variable keyword arguments passed to the base class Note: Precomputing the codebook requires O(2^k * n) memory, where k is the code dimension and n is the code length. This can be prohibitive for larger codes, so set precompute_codebook=False for such cases. """ super().__init__(encoder, *args, **kwargs) # Initialize attributes as Optional to satisfy mypy self._codebook: Optional[torch.Tensor] = None self._message_map: Optional[torch.Tensor] = None if precompute_codebook: codebook, message_map = self._generate_codebook() self._codebook = codebook self._message_map = message_map
def _generate_codebook(self) -> Tuple[torch.Tensor, torch.Tensor]: """Generate all possible codewords for the code. This method enumerates all 2^k possible messages (where k is the code dimension) and encodes each one to create the complete codebook of the code. It also maintains a mapping from each codeword back to its original message. The codebook is essential for the brute force ML decoding approach, as it allows the decoder to compare the received word with all valid codewords to find the most likely transmitted codeword. Returns: Tuple containing: - Tensor of shape (2^k, n) containing all possible codewords - Tensor of shape (2^k, k) containing all possible messages Note: For codes with dimension k > 20, this function may require excessive memory and computation time. Consider alternative decoding methods for such large codes. """ k = self.code_dimension n = self.code_length # Determine the appropriate dtype based on the encoder encoder_dtype = next(self.encoder.parameters(), torch.zeros(1)).dtype # Generate all possible messages (2^k) num_messages = 2**k messages = torch.zeros((num_messages, k), dtype=encoder_dtype) # Fill in binary representations for i in range(num_messages): for j in range(k): messages[i, k - j - 1] = float((i >> j) & 1) # Cast to float to match encoder dtype # Encode all messages to get codewords codewords = torch.zeros((num_messages, n), dtype=encoder_dtype) for i in range(num_messages): codewords[i] = self.encoder(messages[i].unsqueeze(0)).squeeze(0) return codewords, messages def _hamming_distance(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: """Compute the Hamming distance between two binary vectors. The Hamming distance counts the number of positions at which the corresponding elements of two vectors are different. It is a key metric in coding theory and is used by the ML decoder to determine the closest codeword to a received word. Args: x: First binary vector or batch of vectors y: Second binary vector or batch of vectors of the same shape as x Returns: Tensor containing the Hamming distance(s) along the last dimension Example: >>> x = torch.tensor([1, 0, 1, 0, 1]) >>> y = torch.tensor([1, 1, 0, 0, 1]) >>> self._hamming_distance(x, y) tensor(2) """ # Cast to same type if needed and ensure comparison works correctly if x.dtype != y.dtype: x = x.to(y.dtype) # Count differences elementwise and sum them along the last dimension diff = (x != y).to(x.dtype) return torch.sum(diff, dim=-1)
[docs] def forward(self, received: torch.Tensor, *args: Any, **kwargs: Any) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """Decode received codewords using maximum likelihood decoding. This method implements the complete maximum likelihood decoding process: 1. Compare the received word with every possible codeword 2. Find the codeword that minimizes the Hamming distance to the received word 3. Return the message bits corresponding to that codeword This provides optimal decoding performance for the binary symmetric channel (BSC) in terms of minimizing the word error probability. Args: received: Received codeword tensor with shape (..., n) or (..., m*n) where n is the code length and m is some multiple *args: Additional positional arguments **kwargs: Additional keyword arguments return_errors: If True, also return the estimated error patterns Returns: Either: - Decoded tensor containing estimated messages with shape (..., k) or (..., m*k) - A tuple of (decoded tensor, error pattern tensor) if return_errors=True Raises: ValueError: If the last dimension of received is not a multiple of the code length Note: This decoder provides optimal (maximum likelihood) decoding for the binary symmetric channel at the cost of exponential complexity in the code dimension k. For larger codes, consider using more efficient decoders that may sacrifice some performance for computational tractability. """ return_errors = kwargs.get("return_errors", False) # Check input dimensions *leading_dims, L = received.shape if L % self.code_length != 0: raise ValueError(f"Last dimension ({L}) must be divisible by code length ({self.code_length})") # Non-batched processing if not leading_dims: # If received is just a single vector received = received.unsqueeze(0) # Add batch dimension result = self._decode_batch(received, return_errors) # Remove batch dimension from result if return_errors: decoded, errors = result # type: ignore[misc] # We know this is a tuple when return_errors=True return decoded.squeeze(0), errors.squeeze(0) else: # We know result is a tensor here, not a tuple return result.squeeze(0) # type: ignore[union-attr] # For batched inputs, reshape to handle multiple blocks if needed if L == self.code_length: # Simple case: each row is a single codeword return self._decode_batch(received, return_errors) else: # Complex case: each row contains multiple codewords # Reshape to (..., L/n, n) blocks_per_row = L // self.code_length reshaped = received.view(*leading_dims, blocks_per_row, self.code_length) # Combine batch dimensions for processing combined_shape = (-1, self.code_length) flattened = reshaped.reshape(*combined_shape) # Decode the combined batch if return_errors: batch_result = self._decode_batch(flattened, return_errors) decoded, errors = batch_result # type: ignore[misc] # We know this is a tuple # Reshape back to original structure decoded = decoded.view(*leading_dims, blocks_per_row, self.code_dimension) errors = errors.view(*leading_dims, blocks_per_row, self.code_length) # Flatten the block dimension return decoded.reshape(*leading_dims, -1), errors.reshape(*leading_dims, -1) else: # We know the result is a tensor here decoded = self._decode_batch(flattened, return_errors) # type: ignore[assignment] # Reshape back to original structure decoded = decoded.view(*leading_dims, blocks_per_row, self.code_dimension) # type: ignore[union-attr] # Flatten the block dimension return decoded.reshape(*leading_dims, -1) # type: ignore[union-attr]
def _decode_batch(self, received_batch: torch.Tensor, return_errors: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """Decode a batch of received words. Args: received_batch: Tensor with shape (batch_size, code_length) return_errors: Whether to return the error pattern Returns: Either: - Decoded tensor with shape (batch_size, code_dimension) - Tuple of (decoded tensor, error pattern tensor) if return_errors=True """ batch_size = received_batch.shape[0] decoded = torch.zeros(batch_size, self.code_dimension, dtype=received_batch.dtype, device=received_batch.device) errors = torch.zeros_like(received_batch) # Generate codebook if needed if self._codebook is None or self._message_map is None: temp_codebook, temp_message_map = self._generate_codebook() codebook = temp_codebook message_map = temp_message_map else: # We know these are not None at this point codebook = self._codebook message_map = self._message_map # Move codebook to same device as received codebook = codebook.to(received_batch.device) message_map = message_map.to(received_batch.device) for i in range(batch_size): # Get the current received word r = received_batch[i] # Compute Hamming distance to all codewords distances = self._hamming_distance(r.unsqueeze(0).expand(codebook.shape[0], -1), codebook) # Find the closest codeword (maximum likelihood) min_idx = torch.argmin(distances) closest_codeword = codebook[min_idx] # Compute the error pattern error_pattern = (r != closest_codeword).to(torch.int) errors[i] = error_pattern # Get the corresponding message decoded[i] = message_map[min_idx] return (decoded, errors) if return_errors else decoded