# data/4_tensorflow_data.py

import tensorflow as tf
from utils import setup_datasets

from astrodata.data.loaders.tensorflow_loader import TensorflowLoader


def describe_dataset(name, ds):
    """Print the cardinality and the first element's tensor shapes of *ds*.

    Args:
        name: Label used as a prefix in every printed line.
        ds: A ``tf.data.Dataset`` whose elements unpack into two tensors,
            ``(images, labels)``. Presumably batched, given the
            "Cardinality (batches)" label; TODO confirm with the loader.
    """
    card = int(tf.data.experimental.cardinality(ds).numpy())
    # cardinality() uses negative sentinels: INFINITE_CARDINALITY (-1) for
    # endlessly repeating datasets and UNKNOWN_CARDINALITY (-2) when it cannot
    # be computed statically. Report them distinctly instead of lumping both
    # into "unknown".
    if card >= 0:
        card_str = str(card)
    elif card == tf.data.experimental.INFINITE_CARDINALITY:
        card_str = "infinite"
    else:
        card_str = "unknown"

    print("--" * 30)
    print(f"{name} tf.data.Dataset details:")
    print(f"Cardinality (batches): {card_str}")

    # Peek at one element; take(1) yields at most one, so no break is needed.
    for images, labels in ds.take(1):
        print(f"[{name}] Example batch images tensor shape: {images.shape}")
        print(f"[{name}] Example batch labels tensor shape: {labels.shape}")
    print("--" * 30)


if __name__ == "__main__":
    # NOTE: these relative paths resolve against the current working
    # directory, not against this file's location.
    cifar_dir = "../../testdata/torch/cifar10"
    fits_dir = "../../testdata/torch/fits"
    setup_datasets(cifar_dir, fits_dir)

    # A single loader instance is reused for both datasets.
    loader = TensorflowLoader()
    print("TensorflowLoader initialized.")

    # Each entry: (display label, dataset directory, extra kwargs for load()).
    demos = [
        ("CIFAR10", cifar_dir, {"image_size": (32, 32)}),
        ("FITS", fits_dir, {"batch_size": 1}),
    ]

    for label, data_dir, load_kwargs in demos:
        print(f"Loading {label} directory-structured dataset with TensorFlow...")
        data = loader.load(data_dir, **load_kwargs)
        print(f"Loaded {label} data.")

        # Class mappings as exposed in the loaded object's metadata.
        print(f"{label} class_names: {data.metadata.get('class_names')}")
        print(f"{label} class_to_idx: {data.metadata.get('class_to_idx')}")

        train_ds = data.get_dataset("train")
        test_ds = data.get_dataset("test")

        describe_dataset(f"{label} Train", train_ds)
        describe_dataset(f"{label} Test", test_ds)

        # Demonstrate iterating a single batch from the train dataset;
        # take(1) yields at most one element, so no explicit break is needed.
        for images, labels in train_ds.take(1):
            print(f"[{label}] Batch example images shape: {images.shape}")
            print(f"[{label}] Batch example labels shape: {labels.shape}")