ml/18_tensorflow_mlflow_example.py¶

import keras as K
import tensorflow as tf
from keras.metrics import SparseCategoricalAccuracy
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

from astrodata.ml.metrics import TensorflowMetric
from astrodata.ml.models import TensorflowModel
from astrodata.tracking.MLFlowTracker import TensorflowMLflowTracker

if __name__ == "__main__":
    # Load the breast-cancer data as a (features, labels) pair.
    features, labels = load_breast_cancer(return_X_y=True)

    # Two-stage split with a fixed seed: first hold out 20% of all samples as
    # the validation set, then carve 10% of what remains off as the test set.
    X_train, X_val, y_train, y_val = train_test_split(
        features, labels, test_size=0.2, random_state=42
    )
    X_train, X_test, y_train, y_test = train_test_split(
        X_train, y_train, test_size=0.1, random_state=42
    )

    # Cast every split for TensorFlow: float32 features, int32 labels.
    X_train, X_val, X_test = (
        part.astype("float32") for part in (X_train, X_val, X_test)
    )
    y_train, y_val, y_test = (
        part.astype("int32") for part in (y_train, y_val, y_test)
    )

    # One (features, label) tf.data pipeline per split; no batching is
    # applied at this point.
    dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train))
    dataset_val = tf.data.Dataset.from_tensor_slices((X_val, y_val))
    dataset_test = tf.data.Dataset.from_tensor_slices((X_test, y_test))

    def create_classifier(input_dim, output_dim):
        """Build a small fully connected Keras classifier.

        Args:
            input_dim: Number of input features per sample.
            output_dim: Number of output units. Exactly one unit gets a
                sigmoid activation; any other count gets a softmax.

        Returns:
            An uncompiled ``K.Sequential`` model made of
            Dense(64, relu) -> BatchNormalization -> Dense(output_dim).
        """
        # A single unit is read as one probability (sigmoid); several units
        # form a distribution over the classes (softmax).
        output_activation = "sigmoid" if output_dim == 1 else "softmax"
        return K.Sequential(
            [
                # Declare the input shape with an explicit Input object rather
                # than the ``input_shape`` layer argument, which Keras 3
                # deprecates and warns about.
                K.Input(shape=(input_dim,)),
                K.layers.Dense(64, activation="relu"),
                K.layers.BatchNormalization(),
                K.layers.Dense(output_dim, activation=output_activation),
            ]
        )

    # Class count derived from the training labels (assumes labels are the
    # integers 0..k-1). Using ndarray.max() avoids a Python-level loop over the
    # array, and int() keeps a NumPy scalar out of the model parameters that
    # are printed below.
    num_classes = int(y_train.max()) + 1

    # The builder function, loss and optimizer are handed over uninstantiated,
    # with their arguments in the *_params dicts.
    # NOTE(review): presumably TensorflowModel instantiates them itself -
    # confirm against astrodata.ml.models.
    model = TensorflowModel(
        model_class=create_classifier,
        model_params={
            "input_dim": X_train.shape[1],
            "output_dim": num_classes,
        },
        loss_fn=K.losses.SparseCategoricalCrossentropy,
        optimizer=K.optimizers.Adam,
        optimizer_params={"learning_rate": 1e-3},
        epochs=10,
        batch_size=32,
        device=None,  # TensorFlow handles device automatically
    )

    # Accuracy against integer (sparse) class labels, wrapped for astrodata.
    accuracy = TensorflowMetric(SparseCategoricalAccuracy())

    metrics = [accuracy]

    print(model.get_params())

    tracker = TensorflowMLflowTracker(
        run_name="MlFlowWithVal",
        experiment_name="18_tensorflow_mlflow_example.py",
        extra_tags={"stage": "testing"},
    )

    # wrap_fit returns the object that is fitted and queried below.
    # NOTE(review): presumably its fit() reports the given metrics on the
    # validation and test datasets to MLflow and, with log_model=True, stores
    # the trained model - confirm against the tracker implementation.
    tracked_model = tracker.wrap_fit(
        model,
        dataset_val=dataset_val,
        dataset_test=dataset_test,
        metrics=metrics,
        log_model=True,
    )

    tracked_model.fit(dataset=dataset)

    # Inference on the raw test features; the result is kept only to
    # demonstrate the predict() call and is not used further.
    y_pred = tracked_model.predict(
        data=X_test,
        batch_size=32,
    )

    print(
        "Test metrics:",
        tracked_model.get_metrics(dataset=dataset_test, metrics=metrics),
    )