Source code for kaira.modulations.dpsk

"""Differential Phase-Shift Keying (DPSK) modulation schemes."""

from typing import Any, Literal, Optional, Union

import matplotlib.pyplot as plt  # type: ignore
import torch

from .base import BaseDemodulator, BaseModulator
from .registry import ModulationRegistry
from .utils import plot_constellation


@ModulationRegistry.register_modulator()
class DPSKModulator(BaseModulator):
    """Differential Phase-Shift Keying (DPSK) modulator.

    Encodes information in the phase differences between consecutive symbols
    rather than absolute phases, making it robust to phase ambiguities.

    Each constellation index ``i`` is labelled with the bit pattern
    ``bit_patterns[i]``. A group of input bits is transmitted as the phase
    shift whose label equals that group, so the labels shown by
    :meth:`plot_constellation` are exactly the bits each phase shift carries.
    """

    constellation: torch.Tensor  # Buffer: complex phase shift for each symbol index
    bit_patterns: torch.Tensor  # Buffer: (order, bits_per_symbol) bit label per index
    _phase_memory: torch.Tensor  # Buffer: reference symbol for differential encoding

    def __init__(
        self,
        order: Optional[Literal[2, 4, 8, 16]] = None,
        gray_coding: bool = True,
        bits_per_symbol: Optional[int] = None,
        gray_coded: Optional[bool] = None,
        *args,
        **kwargs,
    ) -> None:
        """Initialize the DPSK modulator.

        Args:
            order: Modulation order (must be a power of 2)
            gray_coding: Whether to use Gray coding for phase mapping
            bits_per_symbol: Alternative way to specify order (2^bits_per_symbol);
                takes precedence over ``order`` when both are given
            gray_coded: Alternative name for gray_coding; takes precedence over
                ``gray_coding`` when not None
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.

        Raises:
            ValueError: If ``order`` is not a power of 2, or if neither
                ``order`` nor ``bits_per_symbol`` is provided.
        """
        # Pass *args and **kwargs to the base class initializer
        super().__init__(*args, **kwargs)

        # Support both initialization styles (order or bits_per_symbol)
        if bits_per_symbol is not None:
            self._bits_per_symbol: int = bits_per_symbol
            self.order = 2**bits_per_symbol
        elif order is not None:
            # A power of two has exactly one bit set
            if not (order > 0 and (order & (order - 1) == 0)):
                raise ValueError(f"DPSK order must be a power of 2, got {order}")
            self.order = order
            # Exact integer log2 (no floating-point round trip)
            self._bits_per_symbol = int(order).bit_length() - 1
        else:
            raise ValueError("Either order or bits_per_symbol must be provided")

        # Support both naming conventions
        self.gray_coding = gray_coded if gray_coded is not None else gray_coding

        # Create constellation and bit labels
        self._create_constellation()

        # Initialize phase memory for differential encoding
        self.register_buffer("_phase_memory", torch.tensor(1.0 + 0.0j))

    def _create_constellation(self) -> None:
        """Create the DPSK constellation and its bit-pattern labels.

        Registers two buffers: ``constellation`` (``order`` unit-magnitude
        phase shifts) and ``bit_patterns`` (the bit label of every index,
        Gray-coded or plain binary depending on ``gray_coding``).
        """
        # Evenly spaced differential phase shifts around the unit circle
        angles = torch.arange(0, self.order) * (2 * torch.pi / self.order)

        if not self.gray_coding:
            # Rotate by half a step so the non-Gray constellation is visibly different
            angles = angles + torch.pi / self.order

        constellation = torch.complex(torch.cos(angles), torch.sin(angles))

        # Label index i with gray(i) or with i itself, MSB first
        bit_patterns = torch.zeros(self.order, self._bits_per_symbol)
        for idx in range(self.order):
            label = idx ^ (idx >> 1) if self.gray_coding else idx
            for pos, bit in enumerate(format(label, f"0{self._bits_per_symbol}b")):
                bit_patterns[idx, pos] = int(bit)

        self.register_buffer("constellation", constellation)
        self.register_buffer("bit_patterns", bit_patterns)

    def _label_to_index(self) -> torch.Tensor:
        """Return a lookup table from bit-label value to constellation index.

        Entry ``v`` holds the index ``i`` whose ``bit_patterns[i]``, read as an
        MSB-first binary number, equals ``v``.
        """
        labels = self.bit_patterns.long()
        values = torch.zeros(self.order, dtype=torch.long, device=labels.device)
        for pos in range(self._bits_per_symbol):
            values = values | (labels[:, pos] << (self._bits_per_symbol - pos - 1))
        # The label values are a permutation of 0..order-1; argsort inverts it
        return torch.argsort(values)

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """Modulate bit groups to DPSK symbols.

        The input is interpreted as bits when every value is 0 or 1 and the
        last dimension is longer than 1; otherwise it is interpreted as direct
        symbol indices.

        Args:
            x: Input tensor of bits with shape (..., K*N), where K is
                bits_per_symbol, or direct symbol indices with shape (..., N)
                where each value is < order
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.

        Returns:
            Complex tensor of DPSK symbols with shape (..., N)

        Raises:
            ValueError: If a bit input's length is not divisible by
                bits_per_symbol, or if a symbol index is outside [0, order).
        """
        batch_shape = x.shape[:-1]
        bit_len = x.shape[-1]

        # Determine if input contains bit patterns or indices.
        # Single-element inputs are always treated as indices, even if 0 or 1.
        is_binary_input = bool(torch.all((x == 0) | (x == 1))) and bit_len != 1

        if is_binary_input:
            if bit_len % self._bits_per_symbol != 0:
                raise ValueError(f"Input bit length must be divisible by {self._bits_per_symbol}")

            symbol_len = bit_len // self._bits_per_symbol
            groups = x.reshape(*batch_shape, symbol_len, self._bits_per_symbol)

            # Read every bit group as an MSB-first binary number
            label_values = torch.zeros((*batch_shape, symbol_len), dtype=torch.long, device=x.device)
            for pos in range(self._bits_per_symbol):
                label_values = label_values | (groups[..., pos].long() << (self._bits_per_symbol - pos - 1))

            # Pick the constellation index that carries this bit label so the
            # transmitted phase agrees with bit_patterns (matters for Gray coding)
            indices = self._label_to_index().to(x.device)[label_values]
        else:
            # Process as direct indices
            indices = x.long()
            if torch.any(indices >= self.order):
                raise ValueError(f"Symbol indices must be less than order ({self.order})")
            # Negative values would otherwise wrap around silently when indexing
            if torch.any(indices < 0):
                raise ValueError("Symbol indices must be non-negative")
            symbol_len = bit_len

        # Map indices to differential phase shifts
        phase_shifts = self.constellation[indices]

        # The phase memory holds a single element (0-dim or shape (1,))
        ref_phase = self._phase_memory.detach().reshape(())

        output = torch.zeros(*batch_shape, symbol_len, dtype=torch.complex64, device=x.device)

        if symbol_len > 0:
            # First symbol is modulated relative to the phase memory
            output[..., 0] = ref_phase * phase_shifts[..., 0]

            # Every later symbol is the previous output rotated by its phase shift
            for i in range(1, symbol_len):
                output[..., i] = output[..., i - 1] * phase_shifts[..., i]

            # Phase memory is only updated in training mode
            if self.training:
                last_symbol = output[..., -1].detach()
                if batch_shape:
                    # A single reference is kept, so batched last symbols are averaged
                    self._phase_memory = last_symbol.mean().view(1)
                else:
                    self._phase_memory = last_symbol.view(1)

        return output

    def reset_state(self) -> None:
        """Reset the internal phase memory to the default state (1 + 0j).

        The new reference is created with the current buffer's dtype and
        device so the module keeps working after being moved with ``.to()``.
        """
        self._phase_memory = torch.tensor(
            1.0 + 0.0j,
            dtype=self._phase_memory.dtype,
            device=self._phase_memory.device,
        )

    def plot_constellation(self, **kwargs) -> plt.Figure:
        """Plot the DPSK constellation diagram.

        Every point is labelled with its bit pattern from ``bit_patterns``.

        Args:
            **kwargs: Additional arguments passed to plot_constellation

        Returns:
            Matplotlib figure object
        """
        labels = ["".join(str(int(bit)) for bit in self.bit_patterns[i]) for i in range(self.order)]
        fig, _ = plot_constellation(
            self.constellation,
            labels=labels,
            title=f"{self.order}-DPSK Constellation",
            **kwargs,
        )
        return fig
@ModulationRegistry.register_demodulator()
class DPSKDemodulator(BaseDemodulator):
    """Differential Phase-Shift Keying (DPSK) demodulator.

    Recovers bits from the phase difference between consecutive received
    symbols, using the constellation and bit labels of an internal
    :class:`DPSKModulator`.
    """

    def __init__(
        self,
        order: Optional[Literal[2, 4, 8, 16]] = None,
        gray_coding: bool = True,
        bits_per_symbol: Optional[int] = None,
        gray_coded: Optional[bool] = None,
        *args,
        **kwargs,
    ) -> None:
        """Initialize the DPSK demodulator.

        Args:
            order: Modulation order (must be a power of 2)
            gray_coding: Whether Gray coding was used for phase mapping
            bits_per_symbol: Alternative way to specify order (2^bits_per_symbol);
                takes precedence over ``order`` when both are given
            gray_coded: Alternative name for gray_coding; takes precedence over
                ``gray_coding`` when not None
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.

        Raises:
            ValueError: If neither ``order`` nor ``bits_per_symbol`` is
                provided, or if the reference modulator rejects the order.
        """
        # Pass *args and **kwargs to the base class initializer
        super().__init__(*args, **kwargs)

        # Support both initialization styles (order or bits_per_symbol)
        if bits_per_symbol is not None:
            self._bits_per_symbol: int = bits_per_symbol
            self.order = 2**bits_per_symbol
        elif order is not None:
            self.order = order
            # Exact integer log2; an invalid order is rejected by the modulator below
            self._bits_per_symbol = int(order).bit_length() - 1
        else:
            raise ValueError("Either order or bits_per_symbol must be provided")

        # Support both naming conventions
        self.gray_coding = gray_coded if gray_coded is not None else gray_coding

        # Reference modulator provides the constellation and bit labels
        self.modulator = DPSKModulator(self.order, self.gray_coding, *args, **kwargs)

    def forward(
        self,
        y: torch.Tensor,
        noise_var: Optional[Union[float, torch.Tensor]] = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """Demodulate DPSK symbols.

        Args:
            y: Received tensor of DPSK symbols with shape (..., N)
            noise_var: Noise variance for soft demodulation (optional). A
                scalar, or a tensor matching the (..., N-1) phase differences.
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.

        Returns:
            If noise_var is provided, returns LLRs (positive values favour
            bit 0); otherwise, returns hard bit decisions. Both have shape
            (..., (N-1)*bits_per_symbol) because the first symbol only serves
            as the phase reference.

        Raises:
            ValueError: If fewer than two symbols are given.
        """
        batch_shape = y.shape[:-1]
        symbol_len = y.shape[-1]
        constellation = self.modulator.constellation

        # Need at least two symbols for differential demodulation
        if symbol_len < 2:
            raise ValueError("Need at least two symbols for differential demodulation")

        # Phase difference between each symbol and its predecessor
        z = y[..., 1:] * torch.conj(y[..., :-1])
        z = z / (torch.abs(z) + 1e-9)  # Normalize to unit magnitude (epsilon avoids 0/0)

        if noise_var is None:
            # Hard decision: pick the constellation phase closest on the circle
            z_angle = torch.angle(z).unsqueeze(-1)  # (..., N-1, 1)
            const_angles = torch.angle(constellation)  # (order,)

            # Wrap the difference into [-pi, pi) before taking its magnitude
            angle_diff = torch.abs((z_angle - const_angles + torch.pi) % (2 * torch.pi) - torch.pi)
            closest_indices = torch.argmin(angle_diff, dim=-1)  # (..., N-1)

            # Look up the bit label of every detected index in one indexing step
            bits = self.modulator.bit_patterns[closest_indices].to(device=y.device, dtype=torch.float)
            return bits.reshape(*batch_shape, -1)

        # Soft decision.
        # Differential detection combines two noisy symbols, so the effective
        # noise variance is doubled.
        if not isinstance(noise_var, torch.Tensor):
            effective_noise_var = torch.tensor(2.0 * noise_var, device=y.device)
        else:
            effective_noise_var = 2.0 * noise_var.to(device=y.device)

        # Handle scalar noise variance
        if effective_noise_var.dim() == 0:
            effective_noise_var = effective_noise_var.expand(*batch_shape, symbol_len - 1)

        llrs = torch.zeros((*batch_shape, symbol_len - 1, self._bits_per_symbol), device=y.device)

        for bit_idx in range(self._bits_per_symbol):
            # Split the constellation by the value of this bit in its label
            bit_0_mask = self.modulator.bit_patterns[:, bit_idx] == 0
            bit_1_mask = ~bit_0_mask

            # Max-log metrics: largest (least negative) -|z - s|^2 / noise_var per hypothesis
            metric_0 = self._min_distance_to_points(z, constellation[bit_0_mask], effective_noise_var)
            metric_1 = self._min_distance_to_points(z, constellation[bit_1_mask], effective_noise_var)

            # LLR = log(P(bit=0)/P(bit=1)); the metrics are negated distances,
            # so the bit-0 metric is the minuend
            llrs[..., bit_idx] = metric_0 - metric_1

        return llrs.reshape(*batch_shape, -1)

    def _min_distance_to_points(self, y: torch.Tensor, points: torch.Tensor, noise_var: torch.Tensor) -> torch.Tensor:
        """Calculate minimum (negative) distance to constellation points for DPSK.

        Uses max-log approximation for computational efficiency.

        Args:
            y: Received symbols with shape (..., N)
            points: Constellation points to compare against with shape (M,)
            noise_var: Noise variance with shape (..., N)

        Returns:
            For each symbol in y, the largest value of
            ``-|y - point|^2 / noise_var`` over all points (i.e. the negated,
            noise-scaled squared distance to the nearest point), shape (..., N)
        """
        # Broadcasting (..., N, 1) against (M,) yields (..., N, M) for any batch rank
        distances = -torch.abs(y.unsqueeze(-1) - points) ** 2 / noise_var.unsqueeze(-1)

        # Return maximum (least negative) value
        return torch.max(distances, dim=-1)[0]
@ModulationRegistry.register_modulator("dbpsk")
class DBPSKModulator(DPSKModulator):
    """Differential Binary Phase-Shift Keying (DBPSK) modulator.

    A :class:`DPSKModulator` fixed to order 2 with Gray coding disabled.
    """

    def __init__(self, *args: Any, **kwargs: Any):
        """Initialize DBPSK Modulator.

        Args:
            *args: Extra positional arguments forwarded to the parent after
                the fixed order and gray_coding values.
            **kwargs: Extra keyword arguments forwarded to the parent. The
                keys ``order``, ``gray_coding``, ``bits_per_symbol`` and
                ``gray_coded`` are ignored because DBPSK fixes them.
        """
        # Drop every key that could override the fixed configuration. The
        # aliases must go too: in the parent they take precedence over
        # order/gray_coding and would silently change the scheme.
        fixed_keys = ("order", "gray_coding", "bits_per_symbol", "gray_coded")
        filtered_kwargs = {k: v for k, v in kwargs.items() if k not in fixed_keys}

        # Pass remaining args and filtered kwargs
        super().__init__(2, False, *args, **filtered_kwargs)
@ModulationRegistry.register_demodulator("dbpsk")
class DBPSKDemodulator(DPSKDemodulator):
    """Differential Binary Phase-Shift Keying (DBPSK) demodulator.

    A :class:`DPSKDemodulator` fixed to order 2 with Gray coding disabled.
    """

    def __init__(self, *args: Any, **kwargs: Any):
        """Initialize DBPSK Demodulator.

        Args:
            *args: Extra positional arguments forwarded to the parent after
                the fixed order and gray_coding values.
            **kwargs: Extra keyword arguments forwarded to the parent. The
                keys ``order``, ``gray_coding``, ``bits_per_symbol`` and
                ``gray_coded`` are ignored because DBPSK fixes them.
        """
        # Drop every key that could override the fixed configuration. The
        # aliases must go too: in the parent they take precedence over
        # order/gray_coding and would silently change the scheme.
        fixed_keys = ("order", "gray_coding", "bits_per_symbol", "gray_coded")
        filtered_kwargs = {k: v for k, v in kwargs.items() if k not in fixed_keys}

        # Pass remaining args and filtered kwargs
        super().__init__(2, False, *args, **filtered_kwargs)
@ModulationRegistry.register_modulator("dqpsk")
class DQPSKModulator(DPSKModulator):
    """Differential Quadrature Phase-Shift Keying (DQPSK) modulator.

    A :class:`DPSKModulator` fixed to order 4 with Gray coding enabled.
    """

    def __init__(self, *args: Any, **kwargs: Any):
        """Initialize DQPSK Modulator.

        Args:
            *args: Extra positional arguments forwarded to the parent after
                the fixed order and gray_coding values.
            **kwargs: Extra keyword arguments forwarded to the parent. The
                keys ``order``, ``gray_coding``, ``bits_per_symbol`` and
                ``gray_coded`` are ignored because DQPSK fixes them.
        """
        # Drop every key that could override the fixed configuration. The
        # aliases must go too: in the parent they take precedence over
        # order/gray_coding and would silently change the scheme.
        fixed_keys = ("order", "gray_coding", "bits_per_symbol", "gray_coded")
        filtered_kwargs = {k: v for k, v in kwargs.items() if k not in fixed_keys}

        # Pass remaining args and filtered kwargs
        super().__init__(4, True, *args, **filtered_kwargs)
@ModulationRegistry.register_demodulator("dqpsk")
class DQPSKDemodulator(DPSKDemodulator):
    """Differential Quadrature Phase-Shift Keying (DQPSK) demodulator.

    A :class:`DPSKDemodulator` fixed to order 4 with Gray coding enabled.
    """

    def __init__(self, *args: Any, **kwargs: Any):
        """Initialize DQPSK Demodulator.

        Args:
            *args: Extra positional arguments forwarded to the parent after
                the fixed order and gray_coding values.
            **kwargs: Extra keyword arguments forwarded to the parent. The
                keys ``order``, ``gray_coding``, ``bits_per_symbol`` and
                ``gray_coded`` are ignored because DQPSK fixes them.
        """
        # Drop every key that could override the fixed configuration. The
        # aliases must go too: in the parent they take precedence over
        # order/gray_coding and would silently change the scheme.
        fixed_keys = ("order", "gray_coding", "bits_per_symbol", "gray_coded")
        filtered_kwargs = {k: v for k, v in kwargs.items() if k not in fixed_keys}

        # Pass remaining args and filtered kwargs
        super().__init__(4, True, *args, **filtered_kwargs)