data/3_torch_data.pyΒΆ

from astrodata.data.loaders import TorchDataLoaderWrapper, TorchLoader
from utils import setup_datasets


if __name__ == "__main__":

    # ==============================================================
    # DATA SETUP (executed once for both classic image + FITS examples)
    # This will:
    # 1. Download CIFAR10.
    # 2. Materialize its train/test splits into an ImageFolder-like directory tree.
    # 3. Download a single example FITS file and duplicate it into a dummy
    #    train/test, multi-class directory tree for demonstration.
    # ==============================================================
    print("Setting up datasets...")
    cifar_dir = "../../testdata/torch/cifar10"
    fits_dir = "../../testdata/torch/fits"
    setup_datasets(cifar_dir, fits_dir)

    # Initialize the TorchLoader, which can handle both classic images and FITS
    loader = TorchLoader()
    print("TorchImageLoader initialized.")

    # ==============================================================
    # SECTION 1: CLASSIC (RGB) IMAGE DATA EXAMPLE (CIFAR10)
    # --------------------------------------------------------------
    # Goal: Show how a standard natural-image style dataset (already in a
    #       folder layout with class subdirectories) is loaded and wrapped
    #       into PyTorch DataLoaders.
    # ==============================================================

    # Load directory-structured CIFAR10 dataset (train/test folders)
    cifar_data = loader.load(cifar_dir)
    print("Loaded CIFAR10 structured dataset object.")

    # Display class mappings for both splits
    print(f"CIFAR10 train class_to_idx: {cifar_data.get_dataset('train').class_to_idx}")
    print(f"CIFAR10 test  class_to_idx: {cifar_data.get_dataset('test').class_to_idx}")

    # Create DataLoader wrapper (shared config, could be tuned per section)
    # This step is optional, as the PytorchModule already implements dataloader
    # creation internally. However, this shows how to use the wrapper.
    dataloader_wrapper = TorchDataLoaderWrapper(
        batch_size=32,
        num_workers=0,
        pin_memory=False,
    )
    print("Initialized TorchDataLoaderWrapper.")

    # Build actual PyTorch DataLoaders
    cifar_dataloaders = dataloader_wrapper.create_dataloaders(cifar_data)
    print("Created CIFAR10 DataLoaders.")

    # Extract split-specific loaders
    cifar_train_loader = cifar_dataloaders.get_dataloader("train")
    cifar_test_loader = cifar_dataloaders.get_dataloader("test")

    # Introspect the train DataLoader
    print("--" * 30)
    print("CIFAR10 Train DataLoader details:")
    print(f"Number of batches: {len(cifar_train_loader)}")
    print(f"Batch size: {cifar_train_loader.batch_size}")
    print("--" * 30)

    # Demonstrate single training batch access
    for images, labels in cifar_train_loader:
        print(f"[CIFAR10] Example batch images tensor shape: {images.shape}")
        print(f"[CIFAR10] Example batch labels tensor shape: {labels.shape}")
        break

    # ==============================================================
    # SECTION 2: FITS (ASTRONOMICAL) IMAGE DATA EXAMPLE
    # --------------------------------------------------------------
    # Goal: Show how the same loader abstraction can handle FITS images
    #       arranged in the same folder pattern:
    #         fits_dir/
    #             train/first/*.fits
    #             train/second/*.fits
    #             test/first/*.fits
    #             test/second/*.fits
    #
    # The FITS setup duplicated one sample file to keep this extremely small.
    # We still wrap it in DataLoaders to mimic a real workflow.
    # ==============================================================

    fits_data = loader.load(fits_dir)
    print("Loaded FITS structured dataset object.")
    print(f"FITS train class_to_idx: {fits_data.get_dataset('train').class_to_idx}")
    print(f"FITS test  class_to_idx: {fits_data.get_dataset('test').class_to_idx}")

    # Reuse the same wrapper (could alternatively instantiate with different params)
    fits_dataloaders = dataloader_wrapper.create_dataloaders(fits_data)
    print("Created FITS DataLoaders.")

    fits_train_loader = fits_dataloaders.get_dataloader("train")
    fits_test_loader = fits_dataloaders.get_dataloader("test")

    print("--" * 30)
    print("FITS Train DataLoader details:")
    print(f"Number of batches: {len(fits_train_loader)}")
    print(f"Batch size: {fits_train_loader.batch_size}")
    print("--" * 30)

    # Show a single batch (each element a single-channel tensor)
    for images, labels in fits_train_loader:
        print(f"[FITS] Example batch images tensor shape: {images.shape}")
        print(f"[FITS] Example batch labels tensor shape: {labels.shape}")
        break