# Source code for kaira.modulations.psk

"""Phase-Shift Keying (PSK) modulation schemes."""

from typing import Literal, Optional, Tuple, Union

import matplotlib.pyplot as plt  # type: ignore
import torch

from .base import BaseDemodulator, BaseModulator
from .registry import ModulationRegistry
from .utils import plot_constellation


@ModulationRegistry.register_modulator()
class BPSKModulator(BaseModulator):
    """Binary Phase-Shift Keying (BPSK) modulator.

    Each input bit selects one of two antipodal points on the real axis,
    following the usual convention:

    - bit 0 -> +1
    - bit 1 -> -1
    """

    constellation: torch.Tensor  # Registered buffer: complex points [+1, -1]

    def __init__(self, complex_output: bool = True, *args, **kwargs) -> None:
        """Set up the BPSK modulator.

        Args:
            complex_output: If True (the default, for consistency with the other PSK
                modulators), ``forward`` returns complex symbols with a zero imaginary part;
                if False it returns real-valued symbols.
            *args: Positional arguments forwarded to ``BaseModulator``.
            **kwargs: Keyword arguments forwarded to ``BaseModulator``.
        """
        super().__init__(*args, **kwargs)
        self.complex_output = complex_output

        # Index 0 is the point for bit 0 (+1), index 1 the point for bit 1 (-1).
        points = torch.complex(torch.tensor([1.0, -1.0]), torch.tensor([0.0, 0.0]))
        self.register_buffer("constellation", points)

        # Every BPSK symbol carries exactly one bit.
        self._bits_per_symbol = 1

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """Map bits to BPSK symbols element-wise.

        Args:
            x: Bit tensor of shape (..., N); values are cast to float before mapping.
            *args: Unused; accepted for interface compatibility.
            **kwargs: Unused; accepted for interface compatibility.

        Returns:
            Tensor of shape (..., N) holding ``1 - 2*x``: complex with zero imaginary part
            when ``complex_output`` is True, real-valued otherwise.
        """
        bits = x.float()
        # Antipodal mapping: 0 -> +1, 1 -> -1.
        symbols = 1.0 - 2.0 * bits
        if not self.complex_output:
            return symbols
        return torch.complex(symbols, torch.zeros_like(bits))

    def plot_constellation(self, **kwargs) -> Tuple[plt.Figure, plt.Axes]:
        """Draw the BPSK constellation diagram with its bit labels.

        Args:
            **kwargs: Extra keyword arguments forwarded to the ``plot_constellation`` utility.

        Returns:
            The value returned by the ``plot_constellation`` utility (annotated as a
            Matplotlib ``(Figure, Axes)`` pair).
        """
        point_labels = ["0", "1"]
        return plot_constellation(self.constellation, labels=point_labels, title="BPSK Constellation", **kwargs)
@ModulationRegistry.register_demodulator()
class BPSKDemodulator(BaseDemodulator):
    """Binary Phase-Shift Keying (BPSK) demodulator.

    Decisions use only the real part of the received signal:

    - non-negative values -> bit 0
    - negative values -> bit 1
    """

    def __init__(self, *args, **kwargs) -> None:
        """Set up the BPSK demodulator.

        Args:
            *args: Positional arguments forwarded to ``BaseDemodulator``.
            **kwargs: Keyword arguments forwarded to ``BaseDemodulator``.
        """
        super().__init__(*args, **kwargs)
        # Every BPSK symbol carries exactly one bit.
        self._bits_per_symbol = 1

    def forward(self, y: torch.Tensor, noise_var: Optional[Union[float, torch.Tensor]] = None, *args, **kwargs) -> torch.Tensor:
        """Recover bits (hard) or LLRs (soft) from received BPSK symbols.

        Args:
            y: Received symbols; only ``y.real`` is used.
            noise_var: If None, hard decisions are returned. Otherwise the noise variance
                (Python number or tensor broadcastable against ``y.real``) used for LLRs.
            *args: Unused; accepted for interface compatibility.
            **kwargs: Unused; accepted for interface compatibility.

        Returns:
            Hard decisions as a float tensor of 0/1 values, or LLRs ``2*Re(y)/noise_var``
            where positive values favour bit 0.
        """
        received = y.real  # BPSK lives on the in-phase (real) axis only
        if noise_var is None:
            # Hard decision: negative -> 1, non-negative -> 0.
            return (received < 0).float()

        if isinstance(noise_var, torch.Tensor):
            variance = noise_var
        else:
            variance = torch.tensor(noise_var, device=y.device)
        # LLR = log(P(y|b=0)/P(y|b=1)) = log(exp(-(y-1)^2/2s^2) / exp(-(y+1)^2/2s^2)) = 2y/s^2
        return 2.0 * received / variance
@ModulationRegistry.register_modulator()
class QPSKModulator(BaseModulator):
    """Quadrature Phase-Shift Keying (QPSK) modulator.

    Each bit pair ``(b0, b1)`` selects one of four points: ``b0`` sets the sign of the
    real part and ``b1`` the sign of the imaginary part (0 -> +, 1 -> -), a Gray labelling:

    - 00 -> (1+j)/sqrt(2)   (first quadrant)
    - 01 -> (1-j)/sqrt(2)   (fourth quadrant)
    - 10 -> (-1+j)/sqrt(2)  (second quadrant)
    - 11 -> (-1-j)/sqrt(2)  (third quadrant)

    The 1/sqrt(2) factor is applied only when ``normalize`` is True.
    """

    constellation: torch.Tensor  # Registered buffer: the four complex QPSK points
    bit_patterns: torch.Tensor  # Registered buffer: bit pair of each point, shape (4, 2)

    def __init__(self, normalize: bool = True, *args, **kwargs) -> None:
        """Set up the QPSK modulator.

        Args:
            normalize: If True, scale the points by 1/sqrt(2) so each has unit energy.
            *args: Positional arguments forwarded to ``BaseModulator``.
            **kwargs: Keyword arguments forwarded to ``BaseModulator``.
        """
        super().__init__(*args, **kwargs)
        self.normalize = normalize
        self._normalization = 1 / (2**0.5) if normalize else 1.0

        scale = self._normalization
        real_axis = torch.tensor([1.0, 1.0, -1.0, -1.0], dtype=torch.float) * scale
        imag_axis = torch.tensor([1.0, -1.0, 1.0, -1.0], dtype=torch.float) * scale
        self.register_buffer("constellation", torch.complex(real_axis, imag_axis))

        # Row k is the bit pair carried by constellation[k].
        labels = torch.tensor(
            [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]],
            dtype=torch.float,
        )
        self.register_buffer("bit_patterns", labels)

        # Every QPSK symbol carries two bits.
        self._bits_per_symbol = 2

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """Map consecutive bit pairs to QPSK symbols.

        Args:
            x: Bit tensor of shape (..., 2*N); the last dimension is read as (b0, b1) pairs.
            *args: Unused; accepted for interface compatibility.
            **kwargs: Unused; accepted for interface compatibility.

        Returns:
            Complex tensor of shape (..., N).

        Raises:
            ValueError: If the last dimension of ``x`` has odd length.
        """
        batch_shape = x.shape[:-1]
        if x.shape[-1] % 2:
            raise ValueError("Input bit length must be even for QPSK modulation")

        pairs = x.reshape(*batch_shape, -1, 2)
        msb = pairs[..., 0].to(torch.long)
        lsb = pairs[..., 1].to(torch.long)
        symbol_idx = 2 * msb + lsb  # row index into bit_patterns / constellation

        if symbol_idx.numel() == 0:
            # Nothing to modulate: return an empty complex tensor with the leading shape kept.
            return torch.empty((*batch_shape, 0), dtype=torch.complex64, device=x.device)
        return self.constellation[symbol_idx]

    def plot_constellation(self, **kwargs) -> Tuple[plt.Figure, plt.Axes]:
        """Draw the QPSK constellation diagram with each point's bit pair as its label.

        Args:
            **kwargs: Extra keyword arguments forwarded to the ``plot_constellation`` utility.

        Returns:
            The value returned by the ``plot_constellation`` utility (annotated as a
            Matplotlib ``(Figure, Axes)`` pair).
        """
        point_labels = [f"{int(self.bit_patterns[k][0])}{int(self.bit_patterns[k][1])}" for k in range(4)]
        return plot_constellation(self.constellation, labels=point_labels, title="QPSK Constellation", **kwargs)
@ModulationRegistry.register_demodulator()
class QPSKDemodulator(BaseDemodulator):
    """Quadrature Phase-Shift Keying (QPSK) demodulator.

    Demodulates QPSK symbols back to bit pairs using the constellation and Gray labelling
    of an internal ``QPSKModulator``. Soft outputs are max-log LLRs with the convention
    ``LLR = log(P(bit=0|y) / P(bit=1|y))``: positive values favour bit 0, as in
    ``BPSKDemodulator``.
    """

    def __init__(self, normalize: bool = True, *args, **kwargs) -> None:
        """Set up the QPSK demodulator.

        Args:
            normalize: If True, assume the unit-energy (1/sqrt(2)-scaled) constellation.
            *args: Positional arguments forwarded to ``BaseDemodulator``.
            **kwargs: Keyword arguments forwarded to ``BaseDemodulator``.
        """
        super().__init__(*args, **kwargs)
        self.normalize = normalize
        self._normalization = 1 / (2**0.5) if normalize else 1.0
        # Internal modulator supplies the reference constellation and bit labels.
        self.modulator = QPSKModulator(normalize)
        self._bits_per_symbol = 2  # QPSK has 2 bits per symbol

    def forward(self, y: torch.Tensor, noise_var: Optional[Union[float, torch.Tensor]] = None, *args, **kwargs) -> torch.Tensor:
        """Demodulate QPSK symbols.

        Args:
            y: Received complex symbols of shape (..., N).
            noise_var: If None, hard decisions are returned. Otherwise the noise variance
                (Python number or tensor broadcastable to (..., N)) that scales the squared
                distances in the max-log LLR.
            *args: Unused; accepted for interface compatibility.
            **kwargs: Unused; accepted for interface compatibility.

        Returns:
            Tensor of shape (..., 2*N): hard bit decisions (0./1.) when ``noise_var`` is
            None, otherwise LLRs where positive values favour bit 0.
        """
        batch_shape = y.shape[:-1]
        constellation = self.modulator.constellation
        bit_patterns = self.modulator.bit_patterns

        if noise_var is None:
            # Hard decision: nearest constellation point, then look up its bit pair.
            distances = torch.abs(y.unsqueeze(-1) - constellation)  # (..., N, 4)
            closest_idx = torch.argmin(distances, dim=-1)  # (..., N)
            bits = bit_patterns[closest_idx]  # (..., N, 2)
            return bits.reshape(*batch_shape, -1)

        if isinstance(noise_var, torch.Tensor):
            noise_var_tensor = noise_var
        else:
            noise_var_tensor = torch.tensor(noise_var, device=y.device)

        symbol_shape = y.shape[-1]
        llrs = torch.zeros((*batch_shape, symbol_shape, 2), device=y.device)
        for bit_idx in range(2):
            bit_0_mask = bit_patterns[:, bit_idx] == 0
            bit_1_mask = ~bit_0_mask
            # Max-log metrics: max over points of -|y - s|^2 / noise_var, i.e. approx. log P(y|bit).
            metric_0 = self._min_distance_to_points(y, constellation[bit_0_mask], noise_var_tensor)
            metric_1 = self._min_distance_to_points(y, constellation[bit_1_mask], noise_var_tensor)
            # LLR = log P(bit=0|y) - log P(bit=1|y); the previous code subtracted in the
            # opposite order, which inverted the sign of every QPSK LLR.
            llrs[..., bit_idx] = metric_0 - metric_1

        return llrs.reshape(*batch_shape, -1)

    def _min_distance_to_points(self, y: torch.Tensor, points: torch.Tensor, noise_var: Union[float, torch.Tensor]) -> torch.Tensor:
        """Return the max-log metric ``max_s(-|y - s|^2 / noise_var)`` over ``points``.

        Despite the name, this is the largest (least negative) scaled negative squared
        distance, i.e. the negated minimum scaled distance.

        Args:
            y: Received symbols with shape (..., N).
            points: Constellation points to compare against, shape (M,).
            noise_var: Noise variance, a number or tensor broadcastable to (..., N).

        Returns:
            Tensor of shape (..., N).
        """
        if not isinstance(noise_var, torch.Tensor):
            noise_var = torch.as_tensor(noise_var, device=y.device)
        sq_dist = torch.abs(y.unsqueeze(-1) - points) ** 2  # (..., N, M)
        # Add a trailing axis so noise_var broadcasts over the M candidate points.
        scaled = -sq_dist / noise_var.unsqueeze(-1)
        return scaled.max(dim=-1).values
@ModulationRegistry.register_modulator()
class PSKModulator(BaseModulator):
    """General M-ary Phase-Shift Keying (PSK) modulator.

    Maps groups of ``log2(order)`` bits to complex constellation points. The standard
    constellation places ``order`` points evenly on the unit circle, starting at angle 0.
    Point ``i`` is labelled with the Gray code ``i ^ (i >> 1)`` when ``gray_coding`` is
    True, otherwise with the plain binary representation of ``i`` (MSB first).
    """

    constellation: torch.Tensor  # Registered buffer: complex points, shape (order,)
    bit_patterns: torch.Tensor  # Registered buffer: bit label of each point, shape (order, bits_per_symbol)
    bit_to_symbol_map: torch.Tensor  # Registered buffer: integer value of a bit group -> constellation index

    def __init__(self, order: Literal[4, 8, 16, 32, 64] = 4, gray_coding: bool = True, constellation: Optional[torch.Tensor] = None, *args, **kwargs) -> None:
        """Initialize the PSK modulator.

        Args:
            order: Modulation order (must be a power of 2). Ignored when ``constellation``
                is given.
            gray_coding: Whether to label constellation points with Gray codes.
            constellation: Optional custom constellation points; its length becomes the
                order and must be a power of 2.
            *args: Positional arguments forwarded to ``BaseModulator``.
            **kwargs: Keyword arguments forwarded to ``BaseModulator``.

        Raises:
            ValueError: If the order (or custom constellation length) is not a power of 2.
        """
        super().__init__(*args, **kwargs)
        self.gray_coding = gray_coding

        if constellation is not None:
            self.order = len(constellation)
            if not (self.order > 0 and (self.order & (self.order - 1) == 0)):
                raise ValueError(f"Custom constellation length must be a power of 2, got {self.order}")
            # log2 of a power of two, computed exactly on integers.
            self._bits_per_symbol: int = int(self.order).bit_length() - 1
            self.register_buffer("constellation", constellation)
            # Custom points still need bit labels; without them forward() and
            # plot_constellation() would fail with AttributeError.
            self._create_bit_mapping()
        else:
            if not (order > 0 and (order & (order - 1) == 0)):
                raise ValueError(f"PSK order must be a power of 2, got {order}")
            self.order = order
            self._bits_per_symbol = int(order).bit_length() - 1
            self._create_constellation()

    def _create_constellation(self) -> None:
        """Create the standard PSK constellation and its bit mapping."""
        # Points evenly spaced on the unit circle; the first lies on the positive real axis.
        angles = torch.arange(0, self.order) * (2 * torch.pi / self.order)
        constellation = torch.complex(torch.cos(angles), torch.sin(angles))
        self.register_buffer("constellation", constellation)
        self._create_bit_mapping()

    def _create_bit_mapping(self) -> None:
        """Register ``bit_patterns`` and ``bit_to_symbol_map`` for the current order."""
        bit_patterns = torch.zeros(self.order, self._bits_per_symbol)
        bit_to_symbol_map = torch.zeros(self.order, dtype=torch.long)
        for symbol_idx in range(self.order):
            label = symbol_idx ^ (symbol_idx >> 1) if self.gray_coding else symbol_idx
            for bit_pos, bit in enumerate(format(label, f"0{self._bits_per_symbol}b")):
                bit_patterns[symbol_idx, bit_pos] = int(bit)
            # Inverse lookup: a bit group whose integer value is `label` selects this point.
            bit_to_symbol_map[label] = symbol_idx
        self.register_buffer("bit_patterns", bit_patterns)
        self.register_buffer("bit_to_symbol_map", bit_to_symbol_map)

    def _is_valid_index_tensor(self, x: torch.Tensor) -> bool:
        """Return True if every element of ``x`` is an integer in ``[0, order)``."""
        return bool((x == x.long()).all() and (x < self.order).all() and (x >= 0).all())

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """Modulate bit groups (or direct indices) to PSK symbols.

        Input interpretation:

        - a single-element (or 0-dim) tensor holding a valid index returns that
          constellation point, squeezed to a scalar;
        - a tensor containing values other than 0/1 that are all valid integer indices is
          used as direct constellation indices (same shape out);
        - otherwise ``x`` is a bit tensor of shape (..., k * bits_per_symbol), read MSB first.

        Args:
            x: Bit tensor or index tensor as described above.
            *args: Unused; accepted for interface compatibility.
            **kwargs: Unused; accepted for interface compatibility.

        Returns:
            Complex tensor of PSK symbols; for bit input its shape is (..., k).

        Raises:
            ValueError: If ``x`` holds values that are neither bits nor valid indices, or if
                the bit length is not a multiple of ``bits_per_symbol``.
        """
        scalar_input = x.dim() == 0
        if scalar_input:
            x = x.unsqueeze(0)

        # Single value that is a valid index: direct lookup with scalar output.
        if x.numel() == 1 and self._is_valid_index_tensor(x):
            return self.constellation[x.long()].squeeze()

        if torch.any((x != 0) & (x != 1)):
            if not self._is_valid_index_tensor(x):
                raise ValueError("Input tensor must contain only binary values (0s and 1s)")
            # Non-binary integers can only be direct constellation indices; packing them as
            # bits (as previously happened for multi-dimensional input) gave invalid indices.
            return self.constellation[x.long()]

        batch_shape = x.shape[:-1]
        bit_len = x.shape[-1]
        if bit_len % self._bits_per_symbol != 0:
            raise ValueError(f"Input bit length must be a multiple of {self._bits_per_symbol} for {self.order}-PSK modulation")

        # Pack each MSB-first bit group into its integer value.
        x_reshaped = x.reshape(*batch_shape, -1, self._bits_per_symbol)
        powers = torch.tensor(
            [1 << (self._bits_per_symbol - 1 - k) for k in range(self._bits_per_symbol)],
            dtype=torch.long,
            device=x.device,
        )
        indices = (x_reshaped.long() * powers).sum(dim=-1)

        # Bit-group value -> constellation index -> complex symbol.
        symbols = self.constellation[self.bit_to_symbol_map[indices]]

        if scalar_input and bit_len == self._bits_per_symbol:
            symbols = symbols.squeeze()
        return symbols

    def plot_constellation(self, **kwargs) -> Tuple[plt.Figure, plt.Axes]:
        """Plot the PSK constellation diagram labelled with each point's bit pattern.

        Args:
            **kwargs: Extra keyword arguments forwarded to the ``plot_constellation`` utility.

        Returns:
            The value returned by the ``plot_constellation`` utility (annotated as a
            Matplotlib ``(Figure, Axes)`` pair).
        """
        labels = ["".join(str(int(bit)) for bit in self.bit_patterns[i]) for i in range(self.order)]
        return plot_constellation(self.constellation, labels=labels, title=f"{self.order}-PSK Constellation", **kwargs)
@ModulationRegistry.register_demodulator()
class PSKDemodulator(BaseDemodulator):
    """General M-ary Phase-Shift Keying (PSK) demodulator.

    Demodulates complex symbols back to bits using the constellation and bit labels of an
    internal ``PSKModulator``. Soft outputs are max-log LLRs with the convention
    ``LLR = log(P(bit=0|y) / P(bit=1|y))``: positive values favour bit 0.
    """

    def __init__(self, order: Literal[4, 8, 16, 32, 64] = 4, gray_coding: bool = True, *args, **kwargs) -> None:
        """Initialize the PSK demodulator.

        Args:
            order: Modulation order (must be a power of 2).
            gray_coding: Whether Gray coding was used for the constellation labels.
            *args: Positional arguments forwarded to ``BaseDemodulator``.
            **kwargs: Keyword arguments forwarded to ``BaseDemodulator``.
        """
        super().__init__(*args, **kwargs)
        self.order = order
        self.gray_coding = gray_coding
        self._bits_per_symbol: int = int(torch.log2(torch.tensor(order, dtype=torch.float)).item())
        # Internal modulator supplies the reference constellation and bit labels.
        self.modulator = PSKModulator(order, gray_coding)

    def forward(self, y: torch.Tensor, noise_var: Optional[Union[float, torch.Tensor]] = None, *args, **kwargs) -> torch.Tensor:
        """Demodulate PSK symbols.

        Args:
            y: Received complex symbols of shape (..., N); a 0-dim tensor is treated as a
                one-symbol sequence.
            noise_var: If None, hard decisions are returned. Otherwise the noise variance
                (Python number or tensor broadcastable to (..., N)) that scales the squared
                distances in the max-log LLR.
            *args: Unused; accepted for interface compatibility.
            **kwargs: Unused; accepted for interface compatibility.

        Returns:
            Tensor of shape (..., N * bits_per_symbol): hard bit decisions (0./1.) when
            ``noise_var`` is None, otherwise LLRs where positive values favour bit 0.
        """
        scalar_input = y.dim() == 0
        if scalar_input:
            y = y.unsqueeze(0)

        batch_shape = y.shape[:-1]
        constellation = self.modulator.constellation
        bit_patterns = self.modulator.bit_patterns

        if noise_var is None:
            # Hard decision: nearest constellation point, then look up its bit label.
            distances = torch.abs(y.unsqueeze(-1) - constellation)  # (..., N, order)
            closest_indices = torch.argmin(distances, dim=-1)  # (..., N)
            result = bit_patterns[closest_indices].reshape(*batch_shape, -1).float()
        else:
            if isinstance(noise_var, torch.Tensor):
                noise_var_tensor = noise_var
            else:
                noise_var_tensor = torch.tensor(noise_var, device=y.device)

            # Squared distance from every received symbol to every constellation point.
            sq_dist = torch.abs(y.unsqueeze(-1) - constellation) ** 2  # (..., N, order)

            llrs = torch.zeros((*batch_shape, y.shape[-1], self._bits_per_symbol), device=y.device)
            for bit_idx in range(self._bits_per_symbol):
                bit_0_mask = bit_patterns[:, bit_idx] == 0
                # Nearest point with this bit = 0 / = 1, scaled by the noise variance
                # (noise_var broadcasts against the (..., N) minima).
                dist_0 = sq_dist[..., bit_0_mask].min(dim=-1).values / noise_var_tensor
                dist_1 = sq_dist[..., ~bit_0_mask].min(dim=-1).values / noise_var_tensor
                # LLR = log(exp(-dist_0) / exp(-dist_1)) = dist_1 - dist_0
                llrs[..., bit_idx] = dist_1 - dist_0
            result = llrs.reshape(*batch_shape, -1)

        if scalar_input:
            result = result.squeeze(0)
        return result