models

The astrodata.ml.models module provides a unified interface for various machine learning models, abstracting away framework-specific details. It includes a base model class and specialized wrappers for the most common model frameworks.

Abstract Class

BaseMlModel is the abstract base class of any model and it defines the standard interface for all machine learning models within astrodata. Subclasses must implement:

  • fit(X, y, **kwargs): Trains the model.

  • predict(X, **kwargs): Generates predictions.

  • score(X, y, **kwargs): Computes a default score for the model.

  • get_metrics(X, y, metrics, **kwargs): Evaluates a list of specified metrics.

  • get_loss_history_metrics(X, y, metrics, **kwargs): Retrieves metric history during training, if supported.

  • save(filepath, **kwargs): Saves the trained model.

  • load(filepath, **kwargs): Loads a model from a file.

  • get_params(): Returns model hyperparameters.

  • set_params(**kwargs): Sets model hyperparameters.

  • clone(): Creates a shallow copy of the model.

How to Use

Initializing a Model

How a model is initialized depends on the framework of reference. In general, the goal should be to pass an existing model of the chosen framework and then let the class handle the generalization to the astrodata framework.

from sklearn.svm import LinearSVR
from astrodata.ml.models.SklearnModel import SklearnModel

model = SklearnModel(model_class=LinearSVR, random_state=42)

Looking at this example which initializes a scikit-learn model we can see that the model_class is passed along with any extra class specific argument (in this case random_state=42) to initialize the model.

Hint

Model initialization is the only framework-specific part of a model, all other methods are framework-agnostic, examples will be shown for a generic BaseMlModel.

Fitting a model

Once a model is initialized correctly, fitting it requires an X_train and a y_train to be passed through its fit method. The internal logic of the model should handle the rest of the training.

model.fit(X_train, y_train)

This will result in a set of weights to be computed for the model, which will be later used for predicting new values.

Predicting with a fitted model

A model that has been correctly fitted can be used for predictions by invoking its predict method.

y_pred = model.predict(x_test)

The output of the preidct method is an array containing the predicted lables for the given input.

Computing metrics

Given an array of BaseMetric that has been previously created, we can compute the metrics for our model by invoking its get_metrics method, this will output a dictionary with the computed values for each metric.

metrics = model.get_metrics(
        X_test,
        y_test,
        metrics=metrics,
    )

SklearnModel

Can be initialized using any of the models offered by the scikit-learn library, refer to the library documentation for more information on available models and model-specific hyperparameters.

XgboostModel

Can be initialized using any of the models offered by the xgboost library, refer to the library documentation for more information on available models and model-specific hyperparameters.

PytorchModel

Attention

To be implemented in future releases.

TensorflowModel

Attention

To be implemented in future releases.

Examples