Source code for kaira.modulations.utils

"""Utility functions for digital modulation schemes."""

from typing import List, Optional, Tuple, Union

import matplotlib.pyplot as plt  # type: ignore
import torch

# Names exported by ``from <this module> import *``; every entry is a function
# defined in this module.
__all__ = [
    "binary_to_gray",
    "gray_to_binary",
    "binary_array_to_gray",
    "gray_array_to_binary",
    "plot_constellation",
    "calculate_theoretical_ber",
    "calculate_spectral_efficiency",
]


def binary_to_gray(num: int) -> int:
    """Convert a natural-binary integer to its reflected Gray code.

    The mapping is a bijection on the non-negative integers, and the codes of
    two consecutive integers differ in exactly one bit.

    Args:
        num: Binary number to convert (must be non-negative).

    Returns:
        Gray-coded number.

    Raises:
        ValueError: If num is negative.
    """
    if num < 0:
        raise ValueError("Input must be a non-negative integer")
    # Each Gray bit is the XOR of a binary bit with its more-significant
    # neighbour. No input needs special treatment (e.g. 1023 -> 512).
    return num ^ (num >> 1)
def gray_to_binary(num: int) -> int:
    """Convert a reflected Gray-coded integer back to natural binary.

    This is the inverse of the ``n ^ (n >> 1)`` Gray encoding.

    Args:
        num: Gray-coded number to convert (must be non-negative).

    Returns:
        Binary number.

    Raises:
        ValueError: If num is negative.
    """
    if num < 0:
        raise ValueError("Input must be a non-negative integer")
    # Every binary bit is the XOR of all Gray bits at or above its position,
    # so fold in successively right-shifted copies until nothing is left.
    result = num
    mask = num >> 1
    while mask:
        result ^= mask
        mask >>= 1
    return result
def binary_array_to_gray(binary: Union[List[int], torch.Tensor]) -> torch.Tensor:
    """Convert an array of natural-binary integers to Gray code element-wise.

    The conversion is vectorized, so tensors of any shape (including 0-d and
    multi-dimensional tensors) are supported and the input shape is preserved.

    Args:
        binary: Binary values to convert. A list is interpreted as int64; a
            tensor keeps its own dtype and device. Floating-point values are
            truncated to integers before conversion.

    Returns:
        Gray-coded array as PyTorch tensor, with the dtype and device of the
        input tensor (int64 on the default device for list input).

    Raises:
        ValueError: If any element is negative after integer truncation.
    """
    if isinstance(binary, torch.Tensor):
        source = binary.detach()
        target_dtype = binary.dtype
    else:
        source = torch.tensor(binary, dtype=torch.int64)
        target_dtype = torch.int64

    # Work in int64 so the bitwise operators are defined for every input dtype
    # (bool, float, small integer types).
    values = source.to(torch.int64)
    if bool((values < 0).any()):
        raise ValueError("Input must be a non-negative integer")

    # Element-wise n ^ (n >> 1); an empty input simply yields an empty result.
    gray = values ^ (values >> 1)
    return gray.to(dtype=target_dtype)
def gray_array_to_binary(gray: Union[List[int], torch.Tensor]) -> torch.Tensor:
    """Convert an array of Gray-coded integers to natural binary element-wise.

    The conversion is vectorized, so tensors of any shape (including 0-d and
    multi-dimensional tensors) are supported and the input shape is preserved.

    Args:
        gray: Gray-coded values to convert. A list is interpreted as int64; a
            tensor keeps its own dtype and device. Floating-point values are
            truncated to integers before conversion.

    Returns:
        Binary array as PyTorch tensor, with the dtype and device of the input
        tensor (int64 on the default device for list input).

    Raises:
        ValueError: If any element is negative after integer truncation.
    """
    if isinstance(gray, torch.Tensor):
        source = gray.detach()
        target_dtype = gray.dtype
    else:
        source = torch.tensor(gray, dtype=torch.int64)
        target_dtype = torch.int64

    # Work in int64 so the bitwise operators are defined for every input dtype.
    values = source.to(torch.int64)
    if bool((values < 0).any()):
        raise ValueError("Input must be a non-negative integer")

    # Each binary bit is the XOR of all Gray bits at or above its position:
    # keep folding in right-shifted copies until every element is exhausted.
    decoded = values
    mask = values >> 1
    while bool(mask.any()):
        decoded = decoded ^ mask
        mask = mask >> 1
    return decoded.to(dtype=target_dtype)
def plot_constellation(
    constellation: torch.Tensor,
    labels: Optional[List[str]] = None,
    title: str = "Constellation Diagram",
    figsize: Tuple[int, int] = (8, 8),
    annotate: bool = True,
    grid: bool = True,
    axis_labels: bool = True,
    marker: str = "o",
    marker_size: int = 100,
    color: str = "blue",
    **kwargs,
) -> Tuple[plt.Figure, plt.Axes]:
    """Draw a constellation diagram as a scatter plot of I/Q points.

    Args:
        constellation: Complex-valued tensor of constellation points
        labels: Optional list of labels, one per point; points beyond the end
            of the list are labelled with their index
        title: Plot title
        figsize: Figure size (width, height) in inches; only used when a new
            figure is created
        annotate: Whether to annotate points with labels (has an effect only
            when ``labels`` is given)
        grid: Whether to show grid
        axis_labels: Whether to show axis labels
        marker: Marker style for constellation points
        marker_size: Marker size
        color: Marker color
        **kwargs: Additional arguments passed to the scatter call. An ``ax``
            entry, if present, is removed and used as the axes to draw on.

    Returns:
        Tuple of (matplotlib figure object, axes object)

    Raises:
        ValueError: If the constellation tensor has no elements
    """
    if constellation.numel() == 0:
        raise ValueError("Constellation cannot be empty")

    points = constellation.detach().cpu()

    # Draw on a caller-supplied axes when one is given, otherwise make a figure.
    axes = kwargs.pop("ax", None)
    if axes is not None:
        figure = axes.figure
    else:
        figure, axes = plt.subplots(figsize=figsize)

    in_phase = points.real
    quadrature = points.imag

    # Remaining kwargs are forwarded untouched to scatter.
    axes.scatter(in_phase, quadrature, marker=marker, s=marker_size, color=color, **kwargs)

    if annotate and labels is not None:
        label_count = len(labels)
        for index, (x, y) in enumerate(zip(in_phase, quadrature)):
            # Fall back to the point index when the label list is too short.
            text = labels[index] if index < label_count else str(index)
            axes.annotate(text, (x, y), xytext=(5, 5), textcoords="offset points", fontsize=12)

    # Faint lines marking the I and Q axes through the origin.
    axes.axhline(y=0, color="k", linestyle="-", alpha=0.3)
    axes.axvline(x=0, color="k", linestyle="-", alpha=0.3)

    if grid:
        axes.grid(True, alpha=0.3)

    if axis_labels:
        axes.set_xlabel("In-Phase (I)")
        axes.set_ylabel("Quadrature (Q)")

    axes.set_title(title)
    axes.set_aspect("equal")

    return figure, axes
def calculate_theoretical_ber(snr_db: Union[float, List[float], torch.Tensor], modulation: str) -> torch.Tensor:
    """Calculate theoretical Bit Error Rate (BER) for common modulations.

    Args:
        snr_db: Signal-to-noise ratio(s) in dB, given as a Python number, a
            list of numbers, or a tensor. For QPSK, this is interpreted as
            Eb/N0 for proper comparison with BPSK.
        modulation: Modulation scheme name, case-insensitive. Supported:
            'bpsk', 'qpsk'/'4qam', '16qam', '64qam', '4pam', '8pam',
            'dpsk'/'dbpsk', 'dqpsk'.

    Returns:
        Theoretical BER values as a float32 PyTorch tensor. Scalar input gives
        a tensor of shape ``(1,)``; list/tensor input keeps its shape. Tensor
        input stays on its device.

    Raises:
        ValueError: If ``snr_db`` has an unsupported type or ``modulation`` is
            not one of the supported schemes.
    """
    if isinstance(snr_db, torch.Tensor):
        snr_tensor = snr_db.float()
    elif isinstance(snr_db, list):
        snr_tensor = torch.tensor(snr_db, dtype=torch.float32, device=torch.device("cpu"))
    elif isinstance(snr_db, (int, float)) and not isinstance(snr_db, bool):
        # Plain ints (e.g. ``10``) are as valid an SNR as floats; bool is
        # excluded even though it subclasses int.
        snr_tensor = torch.tensor([float(snr_db)], dtype=torch.float32, device=torch.device("cpu"))
    else:
        raise ValueError(f"Unsupported type for snr_db: {type(snr_db)}")

    # Convert SNR from dB to linear scale
    snr = 10 ** (snr_tensor / 10)

    modulation = modulation.lower()
    erfc = torch.special.erfc

    if modulation in ("bpsk", "qpsk", "4qam"):
        # For QPSK, snr_db is treated as Eb/N0, so QPSK and BPSK share the
        # same BER expression for the same Eb/N0.
        return 0.5 * erfc(snr**0.5)
    if modulation == "16qam":
        # Approximate BER for 16-QAM
        return 0.75 * erfc((snr / 10) ** 0.5)
    if modulation == "64qam":
        # Approximate BER for 64-QAM
        return (7 / 12) * erfc((snr / 60) ** 0.5)
    if modulation == "4pam":
        # BER for 4-PAM
        return 0.75 * erfc((snr / 5) ** 0.5)
    if modulation == "8pam":
        # Approximate BER for 8-PAM
        return (7 / 12) * erfc((snr / 21) ** 0.5)
    if modulation in ("dpsk", "dbpsk"):
        # BER for DBPSK (exponential form)
        return 0.5 * torch.exp(-snr)
    if modulation == "dqpsk":
        # Approximate BER for DQPSK
        tail = erfc((snr / 2) ** 0.5)
        return tail - 0.25 * tail**2

    raise ValueError(f"Modulation scheme '{modulation}' not supported for theoretical BER")
def calculate_spectral_efficiency(modulation: str, coding_rate: float = 1.0) -> float:
    """Return the spectral efficiency of a modulation scheme in bits/s/Hz.

    Well-known schemes are looked up in a fixed table. Any other name that
    contains "qam", "psk" or "pam" has its digits read as the constellation
    order and contributes ``log2(order)`` bits per symbol.

    Args:
        modulation: Modulation scheme name (case-insensitive)
        coding_rate: Coding rate (between 0 and 1), default is 1.0 (no coding)

    Returns:
        Spectral efficiency in bits/s/Hz (bits per symbol times coding rate)

    Raises:
        ValueError: If coding_rate is not between 0 and 1, or if no spectral
            efficiency can be determined for the modulation name
    """
    import math

    if coding_rate <= 0 or coding_rate > 1:
        raise ValueError("Coding rate must be between 0 and 1")

    # Uncoded bits per symbol for the schemes known by name.
    known_efficiencies = {
        "bpsk": 1.0,
        "qpsk": 2.0,
        "4qam": 2.0,
        "pi4qpsk": 2.0,
        "oqpsk": 2.0,
        "dqpsk": 2.0,
        "8psk": 3.0,
        "16qam": 4.0,
        "64qam": 6.0,
        "256qam": 8.0,
        "4pam": 2.0,
        "8pam": 3.0,
        "16pam": 4.0,
    }

    name = modulation.lower()
    efficiency = known_efficiencies.get(name)

    if efficiency is None:
        # Fall back to deriving the order from the digits in the name, but
        # only for names that look like a standard QAM/PSK/PAM scheme.
        if any(family in name for family in ("qam", "psk", "pam")):
            digits = "".join(ch for ch in name if ch.isdigit())
            try:
                efficiency = math.log2(int(digits))
            except ValueError:
                # No digits, or an order log2 cannot handle (e.g. 0).
                efficiency = None
        if efficiency is None:
            raise ValueError(f"Spectral efficiency for '{modulation}' not defined")

    # Apply coding rate
    return efficiency * coding_rate