Source code for kaira.modulations.base
"""Base classes for modulation and demodulation schemes."""
from abc import ABC, abstractmethod
from typing import Optional, Union
import torch
import torch.nn as nn
[docs]
class BaseModulator(nn.Module, ABC):
    """Abstract base class for all modulators.

    A modulator maps bit sequences to complex symbols according to a specific
    modulation scheme. Concrete schemes subclass this and implement
    :meth:`forward`.

    Attributes:
        constellation: Complex-valued tensor of constellation points. This base
            class does not create the attribute; presumably concrete modulators
            provide it (verify against the subclasses).
    """

    def __init__(self, bits_per_symbol: Optional[int] = None, *args, **kwargs) -> None:
        """Initialize the modulator.

        Args:
            bits_per_symbol: Number of bits to encode in each symbol. When left
                as ``None``, reading :attr:`bits_per_symbol` raises until a
                value is stored in ``_bits_per_symbol``.
            *args: Variable length argument list, forwarded to the parent
                initializer.
            **kwargs: Arbitrary keyword arguments, forwarded to the parent
                initializer.
        """
        # Forward the extra arguments so cooperative multiple inheritance works.
        super().__init__(*args, **kwargs)
        self._bits_per_symbol = bits_per_symbol

    @property
    def bits_per_symbol(self) -> int:
        """Number of bits per symbol.

        Raises:
            NotImplementedError: If no value was supplied at construction time
                and none has been stored since.
        """
        if self._bits_per_symbol is None:
            raise NotImplementedError("bits_per_symbol must be defined in subclass")
        return self._bits_per_symbol

    @abstractmethod
    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """Modulate bits to symbols.

        Args:
            x: Input tensor of bits with shape (..., K*N), where K is
                bits_per_symbol.
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.

        Returns:
            Modulated symbols with shape (..., N).
        """

    def plot_constellation(self, **kwargs):
        """Plot the constellation diagram.

        The base implementation only signals that plotting is unsupported;
        subclasses that can draw their constellation override it.

        Args:
            **kwargs: Additional arguments for plotting.

        Returns:
            Matplotlib figure object (when implemented by a subclass).

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError("plot_constellation must be implemented in subclass")

    def reset_state(self) -> None:
        """Reset any stateful components.

        For modulators with memory (like differential schemes). The default
        implementation does nothing.
        """
[docs]
class BaseDemodulator(nn.Module, ABC):
    """Abstract base class for all demodulators.

    A demodulator maps received complex symbols back to bit sequences according
    to a specific demodulation scheme, which may include soft or hard decisions.
    Concrete schemes subclass this and implement :meth:`forward`.
    """

    def __init__(self, bits_per_symbol: Optional[int] = None, *args, **kwargs) -> None:
        """Initialize the demodulator.

        Args:
            bits_per_symbol: Number of bits encoded in each symbol. When left as
                ``None``, reading :attr:`bits_per_symbol` raises until a value
                is stored in ``_bits_per_symbol``.
            *args: Variable length argument list, forwarded to the parent
                initializer.
            **kwargs: Arbitrary keyword arguments, forwarded to the parent
                initializer.
        """
        # Forward the extra arguments so cooperative multiple inheritance works.
        super().__init__(*args, **kwargs)
        self._bits_per_symbol = bits_per_symbol

    @property
    def bits_per_symbol(self) -> int:
        """Number of bits per symbol.

        Raises:
            NotImplementedError: If no value was supplied at construction time
                and none has been stored since.
        """
        if self._bits_per_symbol is None:
            raise NotImplementedError("bits_per_symbol must be defined in subclass")
        return self._bits_per_symbol

    @abstractmethod
    def forward(
        self,
        y: torch.Tensor,
        noise_var: Optional[Union[float, torch.Tensor]] = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """Demodulate symbols to bits or LLRs.

        Args:
            y: Received symbols with shape (..., N).
            noise_var: Noise variance for soft demodulation (optional).
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.

        Returns:
            If noise_var is provided, returns LLRs; otherwise, returns hard bit
            decisions with shape (..., N*bits_per_symbol).
        """

    def reset_state(self) -> None:
        """Reset any stateful components.

        For demodulators with memory. The default implementation does nothing.
        """