Source code for kaira.utils
"""General utility functions for the Kaira library."""
import os
import random
from typing import Any, Union
import torch
from .plotting import ( # Core plotting class
PlottingUtils,
)
from .snr import (
add_noise_for_snr,
calculate_snr,
estimate_signal_power,
noise_power_to_snr,
snr_db_to_linear,
snr_linear_to_db,
snr_to_noise_power,
)
[docs]
def to_tensor(x: Any, device: Union[str, torch.device, None] = None) -> torch.Tensor:
    """Convert input data into a torch.Tensor, optionally placing it on a device.

    Args:
        x (Any): The data to be converted. Acceptable types are:
            - torch.Tensor: Returned as is (moved to ``device`` when one is given).
            - int or float: Converted to a scalar tensor.
            - list or numpy.ndarray: Converted to a tensor.
        device (Union[str, torch.device, None]): The target device for the tensor
            (for example, 'cpu' or 'cuda'). Default is None.

    Returns:
        torch.Tensor: The input data as a tensor, on ``device`` if one was provided.

    Raises:
        TypeError: If the input type is not supported for conversion.
    """
    if isinstance(x, torch.Tensor):
        # Avoid a needless .to() call when no device was requested.
        return x.to(device) if device is not None else x
    if isinstance(x, (int, float, list)):
        return torch.tensor(x, device=device)

    # NumPy is imported lazily so this module does not gain a hard import-time
    # dependency on it; without NumPy installed there can be no ndarray input.
    try:
        import numpy as np
    except ImportError:
        np = None
    if np is not None and isinstance(x, np.ndarray):
        return torch.tensor(x, device=device)

    raise TypeError(f"Unsupported type: {type(x)}")
[docs]
def calculate_num_filters_factor_image(num_strided_layers, bw_ratio, channels=3, is_complex_transmission=False):
    """Calculate the number of filters in an image based on the number of strided layers and
    bandwidth ratio.

    The result is ``channels * 2 ** (2 * num_strided_layers) * bw_ratio``, doubled when
    ``is_complex_transmission`` is True.

    Args:
        num_strided_layers (int): The number of strided layers in the network. These
            layers typically reduce the spatial dimensions of the input image.
        bw_ratio (float): The bandwidth ratio, which is the ratio of the number of
            transmitted filters to the number of filters in the image. Integer
            values (e.g. ``1``) are accepted as well.
        channels (int, optional): The number of channels in the input image. Defaults to 3.
        is_complex_transmission (bool, optional): If True, indicates that the transmission
            is complex; the computed value is multiplied by 2. Defaults to False.

    Returns:
        int: The calculated number of filters in an image.

    Raises:
        AssertionError: If the computed value is not a whole number.
    """
    base_filters = channels * (2 ** (2 * num_strided_layers))
    res = base_filters * bw_ratio
    if is_complex_transmission:
        res *= 2

    # ``int`` only gained ``is_integer`` in Python 3.12, so integers are treated
    # as whole up front instead of calling the method on them.
    is_whole = isinstance(res, int) or float(res).is_integer()
    # Raised explicitly (rather than via ``assert``) so the check still runs
    # under ``python -O``; the exception type is unchanged for callers.
    if not is_whole:
        raise AssertionError(f"Result {res} is not an integer")
    return int(res)
def seed_everything(seed: int, cudnn_benchmark: bool = False, cudnn_deterministic: bool = True) -> None:
    """Seed all random number generators to make runs reproducible.

    Seeds Python's ``random`` module, NumPy's global generator (when NumPy is
    installed), and PyTorch's CPU and CUDA generators. It also sets the
    ``PYTHONHASHSEED`` environment variable and configures the CuDNN flags.

    Args:
        seed (int): The seed value for random number generators.
        cudnn_benchmark (bool): If True, allows the use of CuDNN's auto-tuner to find the
            best algorithm for your hardware. Setting this False might have performance
            implications.
        cudnn_deterministic (bool): If True, makes CuDNN operations deterministic.
            Setting this True might have performance implications.
    """
    random.seed(seed)
    # NOTE: this only influences interpreters/subprocesses started afterwards;
    # hash randomization of the running interpreter is fixed at startup.
    os.environ["PYTHONHASHSEED"] = str(seed)

    # NumPy is imported lazily so the module has no hard dependency on it.
    try:
        import numpy as np
    except ImportError:
        np = None
    if np is not None:
        # np.random.seed only accepts values in [0, 2**32 - 1]; fold the seed
        # into that range so seeds valid for torch/random never fail here.
        np.random.seed(seed % (2**32))

    torch.manual_seed(seed)
    # Seed every CUDA device, not only the current one; this call is silently
    # ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)

    torch.backends.cudnn.deterministic = cudnn_deterministic
    torch.backends.cudnn.benchmark = cudnn_benchmark
# Public API of the package: helpers defined in this module plus the names
# re-exported from the ``snr`` and ``plotting`` submodules.
__all__ = [
    "to_tensor",
    "calculate_num_filters_factor_image",
    "seed_everything",
    "snr_db_to_linear",
    "snr_linear_to_db",
    "snr_to_noise_power",
    "noise_power_to_snr",
    "calculate_snr",
    "add_noise_for_snr",
    "estimate_signal_power",
    "PlottingUtils",
]