"""Utility functions for constraints.
This module provides helper functions for creating, testing, validating, and working with
constraints in wireless communication systems. These utilities streamline the process of
configuring common constraint combinations and verifying constraint effectiveness.
"""
from typing import Any, Dict, List, Optional
import torch
from .antenna import PerAntennaPowerConstraint
from .base import BaseConstraint
from .composite import CompositeConstraint
from .power import TotalPowerConstraint
from .signal import PeakAmplitudeConstraint, SpectralMaskConstraint
__all__ = [
"create_ofdm_constraints",
"create_mimo_constraints",
"combine_constraints",
"verify_constraint",
"apply_constraint_chain",
"measure_signal_properties",
]
# Factory functions for common constraint combinations
[docs]
def create_ofdm_constraints(
    total_power: float,
    max_papr: Optional[float] = 6.0,
    is_complex: bool = True,
    peak_amplitude: Optional[float] = None,
) -> CompositeConstraint:
    """Create constraints commonly used in OFDM systems.

    Configures a set of constraints suitable for Orthogonal Frequency Division
    Multiplexing (OFDM) signals: a Peak-to-Average Power Ratio (PAPR) constraint
    to handle the high dynamic range typical of OFDM, an optional peak amplitude
    constraint, and a total power constraint.

    The constraints are appended to the composite in this order:

    1. ``PAPRConstraint(max_papr=max_papr)`` - only if ``max_papr`` is not None
    2. ``PeakAmplitudeConstraint(peak_amplitude)`` - only if ``peak_amplitude``
       is not None
    3. ``TotalPowerConstraint(total_power)`` - always

    Args:
        total_power (float): Total power constraint value in linear units.
        max_papr (float, optional): Maximum allowed PAPR in linear units (not dB).
            Pass None to omit the PAPR constraint entirely. Defaults to 6.0
            (approximately 7.8 dB).
        is_complex (bool, optional): Whether the signal is complex-valued.
            Currently not used when building the constraints; accepted for API
            compatibility. Defaults to True.
        peak_amplitude (float, optional): If provided, adds a peak amplitude
            constraint. Defaults to None.

    Returns:
        CompositeConstraint: Combined OFDM constraints ready to be applied to signals.

    Example:
        >>> ofdm_constraints = create_ofdm_constraints(total_power=1.0, max_papr=4.0)
        >>> constrained_signal = ofdm_constraints(input_signal)
    """
    constraints = []

    # PAPR limiting is optional: callers may pass max_papr=None to skip it.
    if max_papr is not None:
        from .power import PAPRConstraint

        constraints.append(PAPRConstraint(max_papr=max_papr))

    # Only add a peak amplitude constraint when explicitly requested.
    if peak_amplitude is not None:
        constraints.append(PeakAmplitudeConstraint(peak_amplitude))

    # NOTE(review): the total power constraint is appended last. If it rescales
    # the signal, the peak-amplitude bound above may no longer hold on the final
    # output - confirm this ordering is intended.
    constraints.append(TotalPowerConstraint(total_power))
    return CompositeConstraint(constraints)
[docs]
def create_mimo_constraints(
    num_antennas: int,
    uniform_power: Optional[float] = None,
    max_papr: Optional[float] = None,
    spectral_mask: Optional[torch.Tensor] = None,
    total_power: Optional[float] = None,
) -> CompositeConstraint:
    """Build a constraint chain for Multiple-Input Multiple-Output (MIMO) transmitters.

    Exactly one power budget must be chosen. It is either a fixed power per
    antenna (``uniform_power``) or a single budget shared by all antennas
    (``total_power``). After the power stage, a PAPR limit and a spectral mask
    can optionally be appended, in that order.

    Args:
        num_antennas (int): Number of antennas in the MIMO system. Currently not
            used when building the constraints.
        uniform_power (float, optional): Power per antenna in linear units. Mutually
            exclusive with ``total_power``. Defaults to None.
        max_papr (float, optional): Maximum allowed PAPR in linear units (not dB).
            When None, no PAPR stage is added. Defaults to None.
        spectral_mask (torch.Tensor, optional): When given, a spectral mask stage
            is appended. Defaults to None.
        total_power (float, optional): Total transmit power in linear units, freely
            allocatable across antennas. Mutually exclusive with ``uniform_power``.
            Defaults to None.

    Returns:
        CompositeConstraint: The power stage, followed by the optional PAPR and
        spectral-mask stages.

    Raises:
        ValueError: If neither ``uniform_power`` nor ``total_power`` is given, or
            if both are given.

    Example:
        >>> # Per-antenna power budget
        >>> mimo_constraints = create_mimo_constraints(
        ...     num_antennas=4, uniform_power=0.25, max_papr=4.0
        ... )
        >>> # Shared total power budget
        >>> mimo_constraints = create_mimo_constraints(
        ...     num_antennas=4, total_power=1.0, max_papr=4.0
        ... )
        >>> constrained_signal = mimo_constraints(input_signal)
    """
    # Pick the power stage. Each branch narrows the Optional it uses, so no
    # assert is needed. Both error cases are raised before anything is built.
    if uniform_power is not None:
        if total_power is not None:
            raise ValueError("Cannot specify both uniform_power and total_power; use one or the other")
        power_stage = PerAntennaPowerConstraint(uniform_power=uniform_power)
    elif total_power is not None:
        power_stage = TotalPowerConstraint(total_power=total_power)
    else:
        raise ValueError("Either uniform_power or total_power must be provided")

    stages = [power_stage]

    if max_papr is not None:
        from .power import PAPRConstraint

        stages.append(PAPRConstraint(max_papr=max_papr))

    if spectral_mask is not None:
        stages.append(SpectralMaskConstraint(spectral_mask))

    return CompositeConstraint(stages)
[docs]
def combine_constraints(constraints: List[BaseConstraint]) -> BaseConstraint:
    """Merge several constraints into one callable constraint.

    With a single constraint, that same object is returned unwrapped. With two
    or more, they are wrapped in a ``CompositeConstraint``, which applies them
    in list order.

    Args:
        constraints (List[BaseConstraint]): Constraints to merge, in application order.

    Returns:
        BaseConstraint: The lone input constraint, or a ``CompositeConstraint``
        wrapping all of them.

    Raises:
        ValueError: If ``constraints`` is empty.

    Example:
        >>> power_constraint = TotalPowerConstraint(1.0)
        >>> papr_constraint = PAPRConstraint(4.0)
        >>> amp_constraint = PeakAmplitudeConstraint(1.5)
        >>> combined = combine_constraints([power_constraint, papr_constraint, amp_constraint])
        >>> constrained_signal = combined(input_signal)
    """
    if not constraints:
        raise ValueError("Cannot combine an empty list of constraints")
    # A single constraint needs no wrapper; hand it back as-is.
    return constraints[0] if len(constraints) == 1 else CompositeConstraint(constraints)
# Verification and testing utilities
[docs]
def verify_constraint(
    constraint: BaseConstraint,
    input_tensor: torch.Tensor,
    expected_property: str,
    expected_value: float,
    tolerance: float = 1e-5,
) -> Dict[str, Any]:
    """Verify that a constraint produces the expected property in the output.

    Applies ``constraint`` to ``input_tensor`` and checks one property of the
    result against ``expected_value``. The property name is validated before
    the constraint is run, so an invalid name fails fast without doing any work.

    Args:
        constraint (BaseConstraint): Constraint to test; it is called as
            ``constraint(input_tensor)``.
        input_tensor (torch.Tensor): Input tensor to pass through the constraint.
        expected_property (str): Name of the property to check.
            Valid values: 'power', 'papr', 'amplitude'.
        expected_value (float): Expected value for the property in linear units.
        tolerance (float, optional): Tolerance for numerical comparison.
            Defaults to 1e-5. How it is applied depends on the property:

            - 'power': two-sided check on the total power (sum of ``|y|**2``
              over every element of the output). The allowed deviation is
              ``tolerance * max(1.0, abs(expected_value))``, which is absolute
              for small targets and relative for large ones.
            - 'papr': upper-bound check
              ``papr <= expected_value + max(tolerance, 1.0)``. The slack is
              never smaller than 1.0 (linear units).
            - 'amplitude': upper-bound check
              ``max(|y|) <= expected_value + tolerance``.

    Returns:
        Dict[str, Any]: Results dictionary containing:
            - input_shape: Shape of the input tensor
            - output_shape: Shape of the constrained output
            - success: Whether the constraint achieved the expected property
            - measured_power / expected_power (for 'power')
            - measured_papr / expected_papr (for 'papr')
            - measured_max_amplitude / expected_max_amplitude (for 'amplitude')

    Raises:
        ValueError: If expected_property is not one of the supported values.

    Example:
        >>> power_constraint = TotalPowerConstraint(1.0)
        >>> input_signal = torch.randn(8, 64)
        >>> result = verify_constraint(power_constraint, input_signal, 'power', 1.0)
        >>> print(f"Constraint satisfied: {result['success']}")
    """
    # Validate the request before running a potentially expensive constraint.
    if expected_property not in ("power", "papr", "amplitude"):
        raise ValueError(f"Unsupported property: {expected_property}. Supported values are: power, papr, amplitude")

    constrained_output = constraint(input_tensor)
    results: Dict[str, Any] = {
        "input_shape": input_tensor.shape,
        "output_shape": constrained_output.shape,
        "success": False,
    }
    magnitude = torch.abs(constrained_output)

    if expected_property == "power":
        # Total power over every element of the output (not per sample/batch).
        power = torch.sum(magnitude ** 2).item()
        results["measured_power"] = power
        results["expected_power"] = expected_value
        # Absolute tolerance for |expected| <= 1, relative tolerance above that.
        allowed_deviation = tolerance * max(1.0, abs(expected_value))
        results["success"] = abs(power - expected_value) <= allowed_deviation
    elif expected_property == "papr":
        sample_power = magnitude ** 2
        mean_power = torch.mean(sample_power).item()
        peak_power = torch.max(sample_power).item()
        papr = peak_power / mean_power if mean_power > 0 else float("inf")
        results["measured_papr"] = papr
        results["expected_papr"] = expected_value
        # PAPR is only required to stay below the target. The slack is floored at
        # 1.0 because PAPR constraints are treated as approximate.
        papr_tolerance = max(tolerance, 1.0)
        results["success"] = papr <= expected_value + papr_tolerance
    else:  # "amplitude"
        max_amp = torch.max(magnitude).item()
        results["measured_max_amplitude"] = max_amp
        results["expected_max_amplitude"] = expected_value
        results["success"] = max_amp <= expected_value + tolerance

    return results
[docs]
def apply_constraint_chain(constraints: List[BaseConstraint], input_tensor: torch.Tensor) -> torch.Tensor:
    """Apply a list of constraints to a tensor in sequence.

    Each constraint is called on the output of the previous one, so the order of
    ``constraints`` matters: a later constraint may change properties that an
    earlier one established. No debug output is produced. If ``constraints`` is
    empty, ``input_tensor`` itself is returned unchanged (not a copy).

    Args:
        constraints (List[BaseConstraint]): Constraint objects to apply, in order.
        input_tensor (torch.Tensor): Input tensor to be constrained.

    Returns:
        torch.Tensor: Output tensor after applying all constraints.

    Example:
        >>> constraints = [
        ...     TotalPowerConstraint(1.0),
        ...     PAPRConstraint(4.0)
        ... ]
        >>> output = apply_constraint_chain(constraints, input_signal)
    """
    output = input_tensor
    for constraint in constraints:
        output = constraint(output)
    return output
[docs]
def measure_signal_properties(x: torch.Tensor) -> Dict[str, float]:
    """Measure common signal properties for a given tensor.

    Calculates key signal properties like power, PAPR, and peak amplitude
    that are commonly constrained in communication systems. All statistics are
    taken over every element of ``x``; there is no per-batch or per-dimension
    reduction.

    Args:
        x (torch.Tensor): Input signal tensor (real or complex). Integer and boolean
            tensors are first converted to the default floating-point dtype,
            because ``torch.mean`` does not accept integer inputs. ``x`` must be
            non-empty: ``torch.max`` cannot reduce an empty tensor.

    Returns:
        Dict[str, float]: Dictionary of measured signal properties:
            - mean_power: mean of ``|x|**2``
            - peak_power: maximum of ``|x|**2``
            - peak_amplitude: maximum of ``|x|``
            - papr: ``peak_power / mean_power`` (``inf`` when mean power is 0)
            - papr_db: ``10 * log10(papr)`` in double precision (``inf`` when
              mean power is 0)

    Example:
        >>> signal = torch.randn(1, 64)
        >>> props = measure_signal_properties(signal)
        >>> print(f"Signal PAPR: {props['papr']:.2f}")
    """
    import math

    # torch.mean rejects integer/bool dtypes; promote them to a float dtype.
    if not (x.is_floating_point() or x.is_complex()):
        x = x.to(torch.get_default_dtype())

    magnitude = torch.abs(x)
    sample_power = magnitude ** 2
    mean_power = torch.mean(sample_power).item()
    peak_power = torch.max(sample_power).item()
    peak_amplitude = torch.max(magnitude).item()

    if mean_power > 0:
        papr = peak_power / mean_power
        # math.log10 works on the float64 value directly. Round-tripping through
        # torch.tensor() would truncate it to float32 first.
        papr_db = 10 * math.log10(papr)
    else:
        papr = float("inf")
        papr_db = float("inf")

    return {
        "mean_power": mean_power,
        "peak_power": peak_power,
        "peak_amplitude": peak_amplitude,
        "papr": papr,
        "papr_db": papr_db,
    }