astrodata.ml.model_selection package

Submodules

astrodata.ml.model_selection.BaseMlModelSelector module

class astrodata.ml.model_selection.BaseMlModelSelector.BaseMlModelSelector

Bases: ABC

Abstract base class for model selection strategies.

Subclasses must implement methods for fitting the selector to data and for retrieving the best model, the best hyperparameters, and the best metrics.

abstractmethod fit(X, y, *args, **kwargs)

Fit the model selector to data.

Parameters:
  • X (Any) – Training data features.

  • y (Any) – Training data targets.

  • *args – Additional positional arguments for fitting.

  • **kwargs – Additional keyword arguments for fitting.

Returns:

Returns self.

Return type:

BaseMlModelSelector

abstractmethod get_best_metrics()

Return the best metrics found during selection.

Returns:

Dictionary of best metrics.

Return type:

dict

abstractmethod get_best_model()

Return the best model found during selection.

Returns:

The best model object.

Return type:

Any

abstractmethod get_best_params()

Return the best hyperparameters found during selection.

Returns:

Dictionary of best parameters.

Return type:

dict

abstractmethod get_params(**kwargs)

Return parameters of the selector. Can be optionally overridden.

Returns:

Dictionary of selector parameters.

Return type:

dict
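
The interface above can be illustrated with a minimal, self-contained sketch. It does not import astrodata: SelectorInterface is a hypothetical stand-in that only mirrors the documented method names, and ConstantSelector is a toy subclass invented for this example, not part of the package.

```python
from abc import ABC, abstractmethod
from typing import Any


class SelectorInterface(ABC):
    """Hypothetical stand-in mirroring the documented abstract methods."""

    @abstractmethod
    def fit(self, X: Any, y: Any, *args, **kwargs) -> "SelectorInterface": ...

    @abstractmethod
    def get_best_model(self) -> Any: ...

    @abstractmethod
    def get_best_params(self) -> dict: ...

    @abstractmethod
    def get_best_metrics(self) -> dict: ...

    @abstractmethod
    def get_params(self, **kwargs) -> dict: ...


class ConstantSelector(SelectorInterface):
    """Toy selector: picks the candidate constant with the lowest MSE on y."""

    def __init__(self, candidates):
        self.candidates = list(candidates)
        self._best = None
        self._mse = None

    def fit(self, X, y, *args, **kwargs):
        def mse(constant):
            return sum((target - constant) ** 2 for target in y) / len(y)

        self._best = min(self.candidates, key=mse)
        self._mse = mse(self._best)
        return self  # returning self allows call chaining

    def get_best_model(self):
        best = self._best
        # The "model" here is just a function predicting the chosen constant.
        return lambda X: [best for _ in X]

    def get_best_params(self):
        return {"constant": self._best}

    def get_best_metrics(self):
        return {"mse": self._mse}

    def get_params(self, **kwargs):
        return {"candidates": self.candidates}


selector = ConstantSelector([0.0, 1.0, 2.0]).fit([[0], [1], [2]], [1.0, 2.0, 3.0])
best_params = selector.get_best_params()
```

Because fit returns self, construction and fitting can be chained as in the last lines of the sketch.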

astrodata.ml.model_selection.GridSearchSelector module

class astrodata.ml.model_selection.GridSearchSelector.GridSearchCVSelector(model, param_grid, scorer=None, cv=5, random_state=2140864672, metrics=None, tracker=None, log_all_models=False)

Bases: BaseMlModelSelector

GridSearchCVSelector performs exhaustive grid search over a parameter grid using cross-validation.
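
As a rough illustration of what "exhaustive" means here, the sketch below enumerates every combination in a grid. It assumes param_grid is a dictionary mapping parameter names to lists of candidate values, which this page does not state explicitly; expand_grid and the parameter names are hypothetical and not part of astrodata.

```python
from itertools import product


def expand_grid(param_grid):
    """Yield one dict per combination in a {name: [values]} grid."""
    names = list(param_grid)
    for values in product(*(param_grid[name] for name in names)):
        yield dict(zip(names, values))


# Hypothetical grid: 2 x 3 = 6 combinations to evaluate.
param_grid = {"max_depth": [2, 4], "learning_rate": [0.1, 0.01, 0.001]}
combinations = list(expand_grid(param_grid))
```

The number of combinations is the product of the list lengths, so the cost of an exhaustive search grows quickly as parameters are added.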

fit(X, y, X_test=None, y_test=None, *args, **kwargs)

Run grid search with cross-validation.

Parameters:
  • X (array-like) – Training data features.

  • y (array-like) – Training data targets.

  • X_test (array-like, optional) – Test data features for tracking/logging (not used in selection).

  • y_test (array-like, optional) – Test data targets for tracking/logging (not used in selection).

Returns:

self – Fitted selector.

Return type:

object

get_best_metrics()

Get the metrics for the best model averaged over cross-validation folds.

Returns:

Averaged metrics, or None if no metrics were specified.

Return type:

dict or None
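
Averaging metrics over folds can be sketched as follows. This is an illustration under simplifying assumptions, not the selector's implementation: the folds are contiguous and unshuffled, the "model" is a mean predictor, and all function names are hypothetical.

```python
def kfold_indices(n_samples, n_folds):
    """Split range(n_samples) into n_folds contiguous (train, val) index pairs."""
    sizes = [n_samples // n_folds + (1 if i < n_samples % n_folds else 0)
             for i in range(n_folds)]
    folds, start = [], 0
    for size in sizes:
        val = list(range(start, start + size))
        train = [i for i in range(n_samples) if i < start or i >= start + size]
        folds.append((train, val))
        start += size
    return folds


def mean_absolute_error(y_true, y_pred):
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


def mean_squared_error(y_true, y_pred):
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def cross_validated_metrics(y, metrics, n_folds):
    """Fit a mean predictor on each training fold, then average each metric."""
    per_fold = {name: [] for name in metrics}
    for train, val in kfold_indices(len(y), n_folds):
        prediction = sum(y[i] for i in train) / len(train)
        y_val = [y[i] for i in val]
        y_pred = [prediction] * len(val)
        for name, metric in metrics.items():
            per_fold[name].append(metric(y_val, y_pred))
    return {name: sum(values) / len(values) for name, values in per_fold.items()}


metrics = {"mae": mean_absolute_error, "mse": mean_squared_error}
averaged = cross_validated_metrics([1.0, 2.0, 3.0, 4.0], metrics, n_folds=2)
```

Each metric is computed once per fold on the held-out part, and the reported value is the plain mean of those per-fold values.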

get_best_model()

Get the best model fitted on all data using the best found parameters.

Returns:

The best fitted model.

Return type:

BaseMlModel

get_best_params()

Get the best parameter combination found during grid search.

Returns:

Best parameters.

Return type:

dict

get_params(**kwargs)

Get parameters of this selector instance.

Returns:

Parameters used to initialize this object.

Return type:

dict

class astrodata.ml.model_selection.GridSearchSelector.GridSearchSelector(model, param_grid, scorer=None, val_size=None, random_state=4282729121, metrics=None, tracker=None, log_all_models=False)

Bases: BaseMlModelSelector

GridSearchSelector performs exhaustive grid search over a parameter grid using a single validation split.

fit(X_train, y_train, X_val=None, y_val=None, X_test=None, y_test=None, *args, **kwargs)

Run grid search using a single train/validation split.

Parameters:
  • X_train (array-like) – Training data features.

  • y_train (array-like) – Training data targets.

  • X_val (array-like, optional) – Validation data features. If None, a random split of the training data is performed using val_size.

  • y_val (array-like, optional) – Validation data targets. If None, a random split of the training data is performed using val_size.

  • X_test (array-like, optional) – Test data features for tracking/logging (not used in selection).

  • y_test (array-like, optional) – Test data targets for tracking/logging (not used in selection).

Returns:

self – Fitted selector.

Return type:

object

Raises:

ValueError – If neither validation data nor val_size is provided.
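
The fallback between explicit validation data and val_size can be sketched like this. It is a minimal illustration, assuming val_size is a fraction of the training set; resolve_validation_split is a hypothetical helper, not a function of astrodata.

```python
import random


def resolve_validation_split(X_train, y_train, X_val=None, y_val=None,
                             val_size=None, random_state=0):
    """Return (X_fit, y_fit, X_val, y_val) following the documented fallback."""
    # Explicit validation data takes precedence over a random split.
    if X_val is not None and y_val is not None:
        return X_train, y_train, X_val, y_val
    if val_size is None:
        raise ValueError("Provide X_val and y_val, or set val_size.")

    # Assumption: val_size is the fraction of samples held out for validation.
    indices = list(range(len(X_train)))
    random.Random(random_state).shuffle(indices)
    n_val = max(1, round(len(indices) * val_size))
    val_idx, fit_idx = indices[:n_val], indices[n_val:]

    def pick(data, idx):
        return [data[i] for i in idx]

    return (pick(X_train, fit_idx), pick(y_train, fit_idx),
            pick(X_train, val_idx), pick(y_train, val_idx))


X = [[i] for i in range(10)]
y = list(range(10))
X_fit, y_fit, X_held, y_held = resolve_validation_split(X, y, val_size=0.2)
```

Seeding the shuffle makes the split reproducible, which is the usual role of a random_state argument.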

get_best_metrics()

Get the metrics for the best model on validation data.

Returns:

Metrics for the best model, or None if no metrics were specified.

Return type:

dict or None

get_best_model()

Get the best model fitted on all data using the best found parameters.

Returns:

The best fitted model.

Return type:

BaseMlModel

get_best_params()

Get the best parameter combination found during grid search.

Returns:

Best parameters.

Return type:

dict

get_params(**kwargs)

Get parameters of this selector instance.

Returns:

Parameters used to initialize this object.

Return type:

dict

astrodata.ml.model_selection.HyperOptSelector module

class astrodata.ml.model_selection.HyperOptSelector.HyperOptSelector(param_space, scorer=None, use_cv=False, cv=2, val_size=0.2, max_evals=20, random_state=3885649030, metrics=None, tracker=None, log_all_models=False)

Bases: BaseMlModelSelector

HyperOptSelector performs hyperparameter optimization using hyperopt.
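
To show the role of max_evals and random_state without depending on hyperopt, the sketch below uses plain random search as a stand-in. This is deliberately a different, simpler technique than hyperopt's own search algorithms, and the (low, high) format of param_space is an assumption made for this example only.

```python
import random


def random_search(objective, param_space, max_evals=20, random_state=0):
    """Minimise objective over max_evals uniform draws from param_space."""
    rng = random.Random(random_state)
    best_params, best_loss = None, float("inf")
    for _ in range(max_evals):
        # Assumed format: {name: (low, high)} sampled uniformly.
        params = {name: rng.uniform(low, high)
                  for name, (low, high) in param_space.items()}
        loss = objective(params)
        if loss < best_loss:
            best_params, best_loss = params, loss
    return best_params, best_loss


def objective(params):
    """Hypothetical loss with its minimum at alpha = 0.3."""
    return (params["alpha"] - 0.3) ** 2


param_space = {"alpha": (0.0, 1.0)}
best_params, best_loss = random_search(objective, param_space,
                                       max_evals=20, random_state=0)
```

Unlike a grid search, the budget is fixed by max_evals rather than by the size of the search space, so a larger budget with the same seed can only match or improve the best loss found.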

fit(X, y, X_val=None, y_val=None, X_test=None, y_test=None, *args, **kwargs)

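Fit the model selector to data by running the hyperopt search.

Parameters:
  • X (Any) – Training data features.

  • y (Any) – Training data targets.

  • X_val (Any, optional) – Validation data features.

  • y_val (Any, optional) – Validation data targets.

  • X_test (Any, optional) – Test data features.

  • y_test (Any, optional) – Test data targets.

  • *args – Additional positional arguments for fitting.

  • **kwargs – Additional keyword arguments for fitting.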

Returns:

Returns self.

Return type:

BaseMlModelSelector

get_best_metrics()

Get the metrics for the best model found during the search.

Returns:

Metrics for the best model, or None if no metrics were specified.

Return type:

dict or None

get_best_model()

Get the best model fitted on all data using the best found parameters.

Returns:

The best fitted model.

Return type:

BaseMlModel

get_best_params()

Get the best parameter combination found during hyperparameter optimization.

Returns:

Best parameters.

Return type:

dict

get_params(**kwargs)

Get parameters of this selector instance.

Returns:

Parameters used to initialize this object.

Return type:

dict

Module contents