# data/3_torch_data.py

import os

from astrodata.data.loaders import TorchDataLoaderWrapper, TorchLoader

def main() -> None:
    """Demonstrate loading an image dataset and wrapping it in PyTorch DataLoaders.

    Loads the image data with ``TorchLoader``, prints the ``class_to_idx``
    mapping of the train and test splits, builds DataLoaders with
    ``TorchDataLoaderWrapper`` and prints the shapes of the first training
    batch.
    """
    loader = TorchLoader()
    print("TorchLoader initialized.")

    # Resolve the data directory relative to this script instead of the
    # current working directory, so the example works regardless of where it
    # is launched from. The relative part (including its trailing slash) is
    # unchanged, so running from the script's own folder behaves as before.
    data_dir = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "../../testdata/torch/mnist/"
    )

    # Load the image data from the specified directory structure.
    # The directory should contain split folders (train/val/test), each with
    # one subdirectory per class.
    raw_data = loader.load(data_dir)
    print("Image data loaded from directory structure.")
    print(f"train data class_to_idx: {raw_data.get_dataset('train').class_to_idx}")
    print(f"test data class_to_idx: {raw_data.get_dataset('test').class_to_idx}")
    # NOTE(review): the "val" split is not accessed in this example —
    # presumably the bundled test data ships only train/test; confirm. If a
    # val folder exists, use raw_data.get_dataset("val") and
    # dataloaders.get_dataloader("val") exactly like the other splits.

    # Define the DataLoader wrapper with desired settings for training.
    # This wrapper will create PyTorch DataLoaders.
    dataloader_wrapper = TorchDataLoaderWrapper(
        batch_size=32,
        num_workers=0,  # load in the main process (simplest for a demo)
        pin_memory=False,
    )
    print("TorchDataLoaderWrapper initialized with training configuration.")

    # Create the actual PyTorch DataLoaders from the raw data.
    dataloaders = dataloader_wrapper.create_dataloaders(raw_data)
    print("PyTorch DataLoaders created successfully.")

    # Extract individual DataLoaders for each data split.
    # These DataLoaders can be directly used in PyTorch training loops.
    train_dataloader = dataloaders.get_dataloader("train")
    test_dataloader = dataloaders.get_dataloader("test")

    # Show in detail what's inside the train DataLoader.
    print("--" * 30)
    print("Train DataLoader details:")
    print(f"Number of batches in train DataLoader: {len(train_dataloader)}")
    print(f"Batch size: {train_dataloader.batch_size}")
    print("--" * 30)

    # Show how a train_dataloader can be accessed in a training loop; only
    # the first batch is inspected, hence the early break.
    for images, labels in train_dataloader:
        print(f"Batch of images shape: {images.shape}")
        print(f"Batch of labels shape: {labels.shape}")
        break


if __name__ == "__main__":
    main()