"""Π/4-QPSK modulation scheme."""
from typing import Optional, Tuple, Union
import matplotlib.pyplot as plt # type: ignore
import torch
from .base import BaseDemodulator, BaseModulator
from .registry import ModulationRegistry
from .utils import plot_constellation
[docs]
@ModulationRegistry.register_modulator("pi4qpsk")
class Pi4QPSKModulator(BaseModulator):
"""Π/4-QPSK (π/4 shifted QPSK) modulator.
A variant of QPSK where the constellation is rotated by π/4 radians on alternating symbols,
providing improved envelope properties.
"""
qpsk: torch.Tensor # Type annotation for the buffer
qpsk_rotated: torch.Tensor # Type annotation for the buffer
constellation: torch.Tensor # Type annotation for the buffer
bit_patterns: torch.Tensor # Type annotation for the buffer
_use_rotated: torch.Tensor # Type annotation for the buffer
_even_symbols: bool = True # Used for test verification
_odd_symbols: bool = True # Used for test verification
[docs]
def __init__(self, gray_coded: bool = True, *args, **kwargs) -> None:
"""Initialize the π/4-QPSK modulator.
Args:
gray_coded: Whether to use Gray coding for mapping (default: True)
*args: Variable length argument list.
**kwargs: Arbitrary keyword arguments.
"""
super().__init__(*args, **kwargs)
self._bits_per_symbol: int = 2
self.gray_coded = gray_coded
# Create two QPSK constellations, one rotated by π/4
self._create_constellations()
# Keep track of which constellation to use
self.register_buffer("_use_rotated", torch.tensor(False))
def _create_constellations(self) -> None:
"""Create standard and rotated QPSK constellations."""
if self.gray_coded:
# Standard QPSK with Gray coding (00, 01, 11, 10)
angles = torch.tensor([1, 3, 7, 5]) * torch.pi / 4
else:
# Standard QPSK without Gray coding (00, 01, 10, 11)
angles = torch.tensor([1, 3, 5, 7]) * torch.pi / 4
re_part = torch.cos(angles)
im_part = torch.sin(angles)
qpsk = torch.complex(re_part, im_part)
# π/4 rotated QPSK with same encoding
if self.gray_coded:
# Rotated QPSK with Gray coding (00, 01, 11, 10)
angles_rotated = torch.tensor([0, 2, 6, 4]) * torch.pi / 4
else:
# Rotated QPSK without Gray coding (00, 01, 10, 11)
angles_rotated = torch.tensor([0, 2, 4, 6]) * torch.pi / 4
re_part_rotated = torch.cos(angles_rotated)
im_part_rotated = torch.sin(angles_rotated)
qpsk_rotated = torch.complex(re_part_rotated, im_part_rotated)
# Store both constellations
self.register_buffer("qpsk", qpsk)
self.register_buffer("qpsk_rotated", qpsk_rotated)
# Store just one constellation for compatibility with test
self.register_buffer("constellation", qpsk)
# Bit patterns for symbols (Gray coded or binary)
if self.gray_coded:
bit_patterns = torch.tensor([[0, 0], [0, 1], [1, 1], [1, 0]], dtype=torch.float)
else:
bit_patterns = torch.tensor([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=torch.float)
self.register_buffer("bit_patterns", bit_patterns)
[docs]
def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
"""Modulate bit pairs to π/4-QPSK symbols or symbols to π/4-QPSK signals.
Args:
x: Input tensor of bits with shape (..., 2*N) or symbols with shape (N,)
*args: Variable length argument list.
**kwargs: Arbitrary keyword arguments.
Returns:
Complex tensor of π/4-QPSK symbols with shape (..., N)
"""
# Process for different input types
if x.dim() == 1 and torch.all(x < 4) and x.numel() <= 4: # Direct symbol indices input
# Use direct symbol mapping
batch_shape = ()
indices = x.long() # Ensure indices are long type
symbol_len = x.shape[0]
else:
# Ensure input length is even for bit inputs
batch_shape = x.shape[:-1]
bit_len = x.shape[-1]
if bit_len % 2 != 0:
raise ValueError("Input bit length must be even for π/4-QPSK modulation")
# Reshape to pairs of bits
x_reshaped = x.reshape(*batch_shape, -1, 2)
symbol_len = x_reshaped.shape[-2]
# Convert bit pairs to indices
indices = torch.zeros((*batch_shape, symbol_len), dtype=torch.long, device=x.device)
# Properly calculate the symbol index from bit pairs
bits_0 = torch.fmod(x_reshaped[..., 0], 2).long()
bits_1 = torch.fmod(x_reshaped[..., 1], 2).long()
indices = (bits_0 << 1) | bits_1
# Outputs array
y = torch.zeros(*batch_shape, symbol_len, dtype=torch.complex64, device=x.device)
# Alternate between standard and rotated constellation for each symbol
use_rotated = self._use_rotated.clone()
# Process each symbol
for i in range(symbol_len):
if use_rotated:
y[..., i] = self.qpsk_rotated[indices[..., i]]
else:
y[..., i] = self.qpsk[indices[..., i]]
use_rotated = ~use_rotated
# Store final state for next call if in training
if self.training:
self._use_rotated = use_rotated.detach()
return y
[docs]
def reset_state(self) -> None:
"""Reset internal state (constellation alternation)."""
self._use_rotated.fill_(False)
[docs]
def plot_constellation(self, **kwargs) -> Tuple[plt.Figure, plt.Axes]:
"""Plot the π/4-QPSK constellation diagram.
Args:
**kwargs: Additional arguments passed to plot_constellation
Returns:
Matplotlib figure object
"""
labels = []
for pattern in self.bit_patterns:
bit_str = f"{int(pattern[0])}{int(pattern[1])}"
# Add each bit pattern twice (once for each constellation)
labels.extend([bit_str + "⊙", bit_str + "⊗"])
return plot_constellation(self.constellation, labels=labels, title="π/4-QPSK Constellation", **kwargs)
[docs]
@ModulationRegistry.register_demodulator("pi4qpsk")
class Pi4QPSKDemodulator(BaseDemodulator):
"""Π/4-QPSK demodulator."""
_use_rotated: torch.Tensor # Type annotation for the buffer
[docs]
def __init__(self, soft_output: bool = False, *args, **kwargs) -> None:
"""Initialize the π/4-QPSK demodulator.
Args:
soft_output: Whether to output soft LLR values even when noise_var is not provided
*args: Variable length argument list.
**kwargs: Arbitrary keyword arguments.
"""
super().__init__(*args, **kwargs)
self._bits_per_symbol: int = 2
self.soft_output = soft_output
# Create reference modulator to access constellations
self.modulator = Pi4QPSKModulator()
# Keep track of which constellation to use for demodulation
self.register_buffer("_use_rotated", torch.tensor(False))
[docs]
def forward(self, y: torch.Tensor, noise_var: Optional[Union[float, torch.Tensor]] = None, *args, **kwargs) -> torch.Tensor:
"""Demodulate π/4-QPSK symbols.
Args:
y: Received tensor of π/4-QPSK symbols
noise_var: Noise variance for soft demodulation (optional)
*args: Variable length argument list.
**kwargs: Arbitrary keyword arguments.
Returns:
If noise_var is provided or soft_output is True, returns LLRs;
otherwise, returns hard bit decisions or symbols based on input shape
"""
batch_shape = y.shape[:-1]
symbol_shape = y.shape[-1]
# Get constellations from modulator
qpsk = self.modulator.qpsk
qpsk_rotated = self.modulator.qpsk_rotated
# For hard decisions without batch dimensions, we can return symbols directly
# This is useful for direct symbol mapping applications
if not batch_shape and not self.soft_output and noise_var is None:
# Prepare output
symbols = torch.zeros(symbol_shape, dtype=torch.long, device=y.device)
# Demodulate each symbol using the appropriate constellation
use_rotated = self._use_rotated.clone()
for i in range(symbol_shape):
# Select current constellation
constellation = qpsk_rotated if use_rotated else qpsk
# Find closest constellation point
distances = torch.abs(y[i] - constellation)
symbols[i] = torch.argmin(distances)
# Toggle constellation for next symbol
use_rotated = ~use_rotated
# Store state for next call if in training
if self.training:
self._use_rotated = use_rotated.detach()
return symbols
# For soft decisions or batched input, return bits
# Prepare output array
if noise_var is None and not self.soft_output:
# Hard bit decisions
output_bits = torch.zeros(*batch_shape, symbol_shape, 2, dtype=torch.float, device=y.device)
else:
# Soft LLR values
output_bits = torch.zeros(*batch_shape, symbol_shape, 2, dtype=torch.float, device=y.device)
# Handle noise variance
if noise_var is not None:
if not isinstance(noise_var, torch.Tensor):
noise_var_tensor = torch.tensor(noise_var, device=y.device)
else:
noise_var_tensor = noise_var
if noise_var_tensor.dim() == 0: # scalar
noise_var_tensor = noise_var_tensor.expand(*batch_shape, symbol_shape)
else:
# Default noise variance for soft decisions when not provided
noise_var_tensor = torch.ones(*batch_shape, symbol_shape, device=y.device)
# Demodulate each symbol using the appropriate constellation
use_rotated = self._use_rotated.clone()
for i in range(symbol_shape):
# Select current constellation
constellation = qpsk_rotated if use_rotated else qpsk
# Process current symbol
if noise_var is None and not self.soft_output:
# Hard decision
if batch_shape:
# For batched input
y_i = y[..., i].unsqueeze(-1)
distances = torch.abs(y_i - constellation.unsqueeze(0))
else:
# For single input
y_i = y[i].unsqueeze(0)
distances = torch.abs(y_i - constellation)
closest_idx = torch.argmin(distances, dim=-1)
# Apply bit patterns
for b in range(len(self.modulator.bit_patterns)):
mask = closest_idx == b
if batch_shape:
output_bits[..., i, :] = self.modulator.bit_patterns[closest_idx, :]
else:
if mask.item():
output_bits[i, :] = self.modulator.bit_patterns[b]
else:
# Soft decision (LLR calculation)
current_noise_var = noise_var_tensor[..., i] if batch_shape else noise_var_tensor[i]
# Calculate LLRs for each bit position
for bit_idx in range(2):
# Create masks for symbols where bit is 0 or 1
bit_0_mask = self.modulator.bit_patterns[:, bit_idx] == 0
bit_1_mask = ~bit_0_mask
# Get constellation points for each bit value
const_bit_0 = constellation[bit_0_mask]
const_bit_1 = constellation[bit_1_mask]
# Calculate distances for each bit value
if batch_shape:
expanded_y = y[..., i].unsqueeze(-1)
# Distance to constellation points where bit is 0
distances_0 = -torch.abs(expanded_y - const_bit_0.unsqueeze(0)) ** 2
min_dist_0, _ = torch.max(distances_0, dim=-1)
min_dist_0 = min_dist_0 / current_noise_var
# Distance to constellation points where bit is 1
distances_1 = -torch.abs(expanded_y - const_bit_1.unsqueeze(0)) ** 2
min_dist_1, _ = torch.max(distances_1, dim=-1)
min_dist_1 = min_dist_1 / current_noise_var
else:
# For non-batched input
y_i = y[i]
# Distance to constellation points where bit is 0
distances_0 = -torch.abs(y_i - const_bit_0) ** 2
min_dist_0, _ = torch.max(distances_0, dim=-1)
min_dist_0 = min_dist_0 / current_noise_var
# Distance to constellation points where bit is 1
distances_1 = -torch.abs(y_i - const_bit_1) ** 2
min_dist_1, _ = torch.max(distances_1, dim=-1)
min_dist_1 = min_dist_1 / current_noise_var
# LLR: log(P(bit=0)/P(bit=1))
output_bits[..., i, bit_idx] = min_dist_0 - min_dist_1
# Toggle constellation for next symbol
use_rotated = ~use_rotated
# Store state for next call if in training
if self.training:
self._use_rotated = use_rotated.detach()
# Format output based on context
if self.soft_output and not batch_shape:
# For soft demodulation of non-batched input, maintain bit structure
return output_bits.reshape(symbol_shape, 2)
else:
# Standard flattened output
return output_bits.reshape(*batch_shape, -1)
[docs]
def reset_state(self) -> None:
"""Reset internal state (constellation alternation)."""
self._use_rotated.fill_(False)