# real_world/tensorflow_galaxyMNIST.py

import keras
import tensorflow as tf
from keras.metrics import Accuracy, SparseCategoricalAccuracy

from astrodata.data.loaders.tensorflow_loader import TensorflowLoader
from astrodata.ml.metrics import TensorflowMetric
from astrodata.ml.models import TensorflowModel
from astrodata.tracking.MLFlowTracker import TensorflowMLflowTracker
from testdata import download_galaxy_mnist

# Force eager execution of tf.functions when eager mode is globally
# disabled — presumably required by the astrodata hooks below (TODO confirm).
eager_enabled = tf.executing_eagerly()
if not eager_enabled:
    tf.config.run_functions_eagerly(True)

# ============================================================
# DATA LOADING
# ============================================================
data_path = download_galaxy_mnist()
print(f"GalaxyMNIST data downloaded and extracted to: {data_path}")

img_size = (64, 64)

loader = TensorflowLoader()
# Renamed from `cifar_data` (copy-paste residue from a CIFAR example): this
# holds the GalaxyMNIST dataset. All uses are local to this section.
galaxy_data = loader.load(data_path, image_size=img_size)

print(f"GalaxyMNIST class_names: {galaxy_data.metadata.get('class_names')}")
print(f"GalaxyMNIST class_to_idx: {galaxy_data.metadata.get('class_to_idx')}")

# Split handles consumed by the training phases below.
galaxymnist_train = galaxy_data.get_dataset("train")
galaxymnist_test = galaxy_data.get_dataset("test")
galaxymnist_val = galaxy_data.get_dataset("val")

# Get number of classes from the metadata (defaults to 0 if absent).
num_classes = len(galaxy_data.metadata.get("class_names", []))
print(f"Number of classes: {num_classes}")

# ============================================================
# MODEL ARCHITECTURE
# ============================================================

# ImageNet-pretrained ResNet50 backbone with its classification head removed.
base_model = keras.applications.ResNet50(
    weights="imagenet",
    include_top=False,
    input_shape=(*img_size, 3),  # GalaxyMNIST image size
)

# Data augmentation layers for training robustness (inactive at inference).
augmentation = keras.Sequential(
    [
        keras.layers.RandomFlip(mode="horizontal"),
        keras.layers.RandomRotation(0.2),
        keras.layers.RandomContrast(0.2),
    ],
    name="data_augmentation",
)

# Custom classification head sized to the GalaxyMNIST label count.
classifier_head = keras.Sequential(
    [
        keras.layers.GlobalAveragePooling2D(),
        keras.layers.Dense(128, activation="relu"),
        keras.layers.Dropout(0.2),  # regularization
        keras.layers.Dense(num_classes, activation="softmax", name="galaxy_classifier"),
    ],
    name="galaxy_mnist_head",
)

# Assemble the pipeline: input -> augmentation -> preprocessing -> backbone -> head.
model_input = keras.Input(shape=(*img_size, 3), name="input_layer")
features = augmentation(model_input)
features = keras.applications.resnet50.preprocess_input(features)
features = base_model(features, training=False)  # keep BatchNorm in inference mode
model_output = classifier_head(features)

# NOTE(review): "minst" typo kept intentionally — later phases reference this name.
galaxy_minst_resnet = keras.Model(model_input, model_output, name="ResNet50_GalaxyMNIST")

print(f"Model created with {num_classes} output classes")
print("Model summary:")
galaxy_minst_resnet.summary()

# ============================================================
# METRICS AND TRACKING SETUP
# ============================================================
# Sparse categorical accuracy: integer labels against softmax outputs.
metrics = [
    TensorflowMetric(
        SparseCategoricalAccuracy(), name="accuracy", greater_is_better=True
    )
]

# MLflow tracker; both phases below log runs under this experiment.
experiment_title = "GalaxyMNIST_ResNet50_FineTuning"
tracker = TensorflowMLflowTracker(
    run_name=experiment_title,
    experiment_name=experiment_title,
    extra_tags={"stage": "testing"},
)

# ============================================================
# PHASE 1: TRAIN CUSTOM HEAD ONLY
# ============================================================
banner = "=" * 60
print("\n" + banner)
print("PHASE 1: Training custom classification head")
print("Base ResNet50 layers are frozen")
print(banner + "\n")

# Relatively high learning rate: only the freshly initialized head will train.
phase1_config = {
    "model_class": galaxy_minst_resnet,
    "model_params": {},
    "loss_fn": keras.losses.SparseCategoricalCrossentropy,
    "optimizer": keras.optimizers.Adam,
    "optimizer_params": {"learning_rate": 5e-3},
    "epochs": 5,
    "batch_size": 32,
    "with_weight_init": True,
}
model_phase1 = TensorflowModel(**phase1_config)

# Wrap fit so training is logged to MLflow under its own run name.
model_phase1 = tracker.wrap_fit(
    model=model_phase1,
    dataset_val=galaxymnist_val,
    dataset_test=galaxymnist_test,
    metrics=metrics,
    log_model=False,
    run_name="phase_1_tensorflow",
)

# Freeze everything except the custom head: reset all layers to trainable,
# then freeze the input, augmentation, and backbone by name.
model_phase1.unfreeze_layers("all")
phase1_frozen = ["input_layer", "data_augmentation", "resnet50"]
model_phase1.freeze_layers(phase1_frozen)

print("\nPhase 1 layer trainability:")
for top_layer in model_phase1.model_.layers:
    print(f"Layer: {top_layer.name}, Trainable: {top_layer.trainable}")

print("\nStarting Phase 1 training...")
model_phase1.fit(
    dataset=galaxymnist_train,
    metrics=metrics,
    dataset_val=galaxymnist_val,
    fine_tune=True,
)

print("\n" + "=" * 60)
print("PHASE 1 COMPLETED")
print("=" * 60)

# ============================================================
# PHASE 2: FINE-TUNE CONV5 BLOCK OF RESNET50
# ============================================================

divider = "=" * 60
print("\n" + divider)
print("PHASE 2: Fine-tuning ResNet50 conv5 block")
print("Custom head + conv5 block trainable, rest frozen")
print(divider + "\n")

# Continue from the phase-1 model with a 10x lower learning rate.
# NOTE(review): with_weight_init=True mirrors phase 1, but if it re-initializes
# weights it would discard phase-1 training — confirm TensorflowModel semantics
# (fine_tune=True at fit time may be what preserves them).
phase2_config = {
    "model_class": model_phase1.model_,
    "model_params": {},
    "loss_fn": keras.losses.SparseCategoricalCrossentropy,
    "optimizer": keras.optimizers.Adam,
    "optimizer_params": {"learning_rate": 5e-4},
    "epochs": 15,
    "batch_size": 32,
    "with_weight_init": True,
}
model_phase2 = TensorflowModel(**phase2_config)

# Track phase 2 under its own MLflow run name.
model_phase2 = tracker.wrap_fit(
    model=model_phase2,
    dataset_val=galaxymnist_val,
    dataset_test=galaxymnist_test,
    metrics=metrics,
    log_model=False,
    run_name="phase_2_tensorflow",
)

# Unfreeze all layers, then selectively freeze input and augmentation
model_phase2.unfreeze_layers("all")
model_phase2.freeze_layers(["input_layer", "data_augmentation"])

# Look the backbone up by name instead of by position: `preprocess_input`
# inserts TFOpLambda layers between the augmentation block and the backbone,
# so `model_.layers[2]` is not guaranteed to be the ResNet50 sub-model
# (the name "resnet50" is the same one freeze_layers uses above/below).
resnet_backbone = model_phase2.model_.get_layer("resnet50")

# Freeze all ResNet50 layers except the conv5 block.
tofreeze = [
    inner_layer.name
    for inner_layer in resnet_backbone.layers
    if not inner_layer.name.startswith("conv5")
]

model_phase2.freeze_layers(tofreeze, "resnet50")

print("\nPhase 2 layer trainability (first 10 ResNet50 layers):")
for inner_layer in resnet_backbone.layers[0:10]:
    print(f"Layer: {inner_layer.name}, Trainable: {inner_layer.trainable}")

print("\nPhase 2 top-level layer trainability:")
for layer in model_phase2.model_.layers:
    print(f"Layer: {layer.name}, Trainable: {layer.trainable}")

print("\nStarting Phase 2 training...")
model_phase2.fit(
    dataset=galaxymnist_train,
    metrics=metrics,
    dataset_val=galaxymnist_val,
    fine_tune=True,
)

# ============================================================
# FINAL EVALUATION
# ============================================================

print("\n" + "=" * 60)
print("PHASE 2 COMPLETED - Final Evaluation")
print("=" * 60 + "\n")

# Score the fine-tuned model on the held-out test split.
final_scores = model_phase2.get_metrics(dataset=galaxymnist_test, metrics=metrics)
print(f"Phase 2 metrics: {final_scores}")