# Source code for kaira.data.sample_data
"""Utilities for loading sample data, such as standard test images."""
import os
from typing import Literal, Optional, Tuple
import torch
import torchvision
import torchvision.transforms as transforms
# [docs]
def load_sample_images(dataset: Literal["cifar10", "cifar100", "mnist"] = "cifar10", num_samples: int = 4, seed: Optional[int] = None, normalize: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
    """Load a few randomly chosen images from a standard dataset for demos.

    The training split of the requested torchvision dataset is downloaded (if
    it is not already cached) into ``.cache/data`` located two directory
    levels above this file, and ``num_samples`` images are drawn from it
    without replacement.

    Args:
        dataset: Name of the dataset to sample from ('cifar10', 'cifar100',
            'mnist'). Matching is case-insensitive.
        num_samples: Number of sample images to return. Must be at least 1.
            If it exceeds the dataset size, every image in the dataset is
            returned.
        seed: Random seed for reproducibility. When given, it is applied with
            ``torch.manual_seed`` and therefore also reseeds the global torch
            random number generator.
        normalize: Currently has no effect. Images are always converted with
            ``transforms.ToTensor``, which already rescales 8-bit pixel data
            to the [0, 1] range. The parameter is kept for backward
            compatibility.

    Returns:
        Tuple containing:
            - Tensor of images with shape (num_samples, C, H, W)
            - Tensor of labels with shape (num_samples,)

    Raises:
        ValueError: If ``dataset`` is not a supported name or ``num_samples``
            is smaller than 1.
    """
    supported = ("cifar10", "cifar100", "mnist")

    # Validate everything up front so that bad arguments fail before any
    # directory is created or any download is started.
    dataset_key = dataset.lower()
    if dataset_key not in supported:
        raise ValueError(f"Unsupported dataset: {dataset}. Choose from 'cifar10', 'cifar100', or 'mnist'")
    if num_samples < 1:
        # A negative value would otherwise be interpreted by the slice below
        # as "everything except the last |num_samples| items".
        raise ValueError(f"num_samples must be a positive integer, got {num_samples}")

    # Set random seed if provided (this seeds the global torch generator).
    if seed is not None:
        torch.manual_seed(seed)

    # Both settings of ``normalize`` historically used the same transform.
    transform = transforms.ToTensor()

    # The download cache lives in "<root library dir>/.cache/data", where the
    # root library directory is two levels above this file's directory.
    current_dir = os.path.dirname(os.path.abspath(__file__))
    root_library_dir = os.path.abspath(os.path.join(current_dir, os.pardir, os.pardir))
    root_path = os.path.join(root_library_dir, ".cache", "data")
    os.makedirs(root_path, exist_ok=True)

    dataset_classes = {
        "cifar10": torchvision.datasets.CIFAR10,
        "cifar100": torchvision.datasets.CIFAR100,
        "mnist": torchvision.datasets.MNIST,
    }
    data = dataset_classes[dataset_key](root=root_path, train=True, download=True, transform=transform)

    # Draw a random subset; plain ints keep dataset indexing unambiguous.
    indices = torch.randperm(len(data))[:num_samples].tolist()
    samples = [data[index] for index in indices]

    # Stack into batches
    images = torch.stack([image for image, _ in samples])
    labels = torch.tensor([label for _, label in samples])
    return images, labels