Model Selection¶
The astrodata.ml.model_selection module provides tools for systematically finding the best machine learning model and hyperparameters for a given task using different heuristics or strategies.
Abstract Class¶
BaseMlModelSelector is the abstract base class for all model selection strategies. Subclasses must implement the following methods (a toy implementation is sketched after the list):

- fit(X, y, *args, **kwargs): Runs the model selection process on the data.
- get_best_model(): Returns the best BaseMlModel instance found.
- get_best_params(): Returns the hyperparameters of the best model.
- get_best_metrics(): Returns the evaluation metrics for the best model.
- get_params(): Returns the parameters of the selector itself.
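As an illustration, here is a toy selector that "searches" over a single fixed parameter set. The import path and the scikit-learn-style set_params()/fit() calls on the model are assumptions for the sketch, not part of the documented API:

# Hypothetical sketch of a custom selector. Assumptions: the import path, and
# that BaseMlModel exposes scikit-learn-style set_params()/fit().
from astrodata.ml.model_selection.BaseMlModelSelector import BaseMlModelSelector

class FixedParamsSelector(BaseMlModelSelector):
    """Trivial strategy: evaluate a single, fixed parameter set."""

    def __init__(self, model, params):
        self._model = model
        self._params = params
        self._best_model = None

    def fit(self, X, y, *args, **kwargs):
        # The "search" degenerates to fitting one candidate.
        self._best_model = self._model.set_params(**self._params)
        self._best_model.fit(X, y)
        return self._best_model

    def get_best_model(self):
        return self._best_model

    def get_best_params(self):
        return self._params

    def get_best_metrics(self):
        return {}  # no metrics computed in this toy example

    def get_params(self):
        return {"model": self._model, "params": self._params}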
How to Use¶
Initializing¶
Initialization depends on the selector that is being used; generally, a model selector is initialized with a model to perform the search on and a grid of parameters to test.
from astrodata.ml.model_selection.GridSearchSelector import GridSearchCVSelector

# `model` is a BaseMlModel instance and `accuracy` a BaseMetric,
# both defined beforehand.
gss = GridSearchCVSelector(
    model=model,
    # tracker=tracker,  # optional ModelTracker for logging runs
    param_grid={
        "C": [0.1, 1, 10],
        "max_iter": [1000, 2000],
        "tol": [1e-3, 1e-4],
    },
    scorer=accuracy,  # metric used to rank candidate models
    cv=5,
    random_state=42,
    metrics=None,
)
The scorer parameter of the selector is a BaseMetric: it is computed for every candidate and used to decide which model is best. Optionally, a list of metrics can be passed to compute additional metrics at each step and at the end; this is particularly relevant when a tracker is added, since those metrics are saved to MLflow (see the section on trackers for more info).
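For example, a selector can rank candidates by accuracy while also recording further metrics (f1 below is a hypothetical second BaseMetric instance; its construction depends on your metrics setup):

# Sketch: `f1` stands for a hypothetical second BaseMetric instance.
gss = GridSearchCVSelector(
    model=model,
    param_grid={"C": [0.1, 1, 10]},
    scorer=accuracy,          # decides which candidate wins
    metrics=[accuracy, f1],   # also computed at each step and at the end
    cv=5,
)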
Attention
Depending on the chosen model selector, the expected format of param_grid may differ.
Once a selector is initialized, the next step is to fit it to a set of data: the selector tries all the required parameter combinations and finally fits the model whose parameters produced the best results.
best_model = gss.fit(X_train, y_train)
print(f"Best parameters found: {gss.get_best_params()}")
print(f"Best metrics: {gss.get_best_metrics()}")
print(f"Best model: {best_model.get_params()}")
GridSearchSelector¶
Implements an exhaustive search over a specified parameter grid: it trains and evaluates a model for every combination of hyperparameters and selects the one that performs best according to the given scorer. GridSearchSelector evaluates on a single validation split, while GridSearchCVSelector uses cross-validation.
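A sketch of the single-split variant follows; importing GridSearchSelector from the same module as GridSearchCVSelector is an assumption:

# Assumption: GridSearchSelector lives in the same module as GridSearchCVSelector.
from astrodata.ml.model_selection.GridSearchSelector import GridSearchSelector

gs = GridSearchSelector(
    model=model,
    param_grid={"C": [0.1, 1, 10]},
    scorer=accuracy,
    val_size=0.2,       # hold out 20% of the training data for validation
    random_state=42,
)
gs.fit(X_train, y_train)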
Parameters¶
model : BaseMlModel
The model to optimize.
param_grid : dict
Dictionary with parameter names (str) as keys and lists of parameter settings to try as values.
scorer : BaseMetric, optional
The metric used to select the best model. If None, the model’s default score method is used.
val_size (for GridSearchSelector) : float, optional (default=None)
Fraction of training data to use as validation split.
cv (for GridSearchCVSelector) : int or cross-validation splitter (default=5)
Number of folds (int) or an object that yields train/test splits (see the sketch after this list).
random_state : int, optional
Random seed for reproducibility.
metrics : list of BaseMetric, optional
Additional metrics to evaluate on the validation set.
tracker : ModelTracker, optional
Optional experiment/model tracker for logging.
log_all_models : bool, optional
If True, logs all models to the tracker, not just the best one.
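As a sketch combining these options, the example below passes a custom splitter and a tracker. It assumes a scikit-learn splitter is accepted wherever "an object that yields train/test splits" is allowed, and that tracker is an already-configured ModelTracker:

# Assumptions: scikit-learn splitters are accepted as `cv`, and `tracker`
# is an already-configured ModelTracker.
from sklearn.model_selection import KFold

gss = GridSearchCVSelector(
    model=model,
    param_grid={"C": [0.1, 1, 10]},
    scorer=accuracy,
    cv=KFold(n_splits=5, shuffle=True, random_state=42),  # custom splitter
    tracker=tracker,       # log runs to the tracker
    log_all_models=True,   # log every candidate model, not just the best
)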
HyperOptSelector¶
Utilizes the hyperopt library for efficient hyperparameter optimization. Instead of exhaustive search, hyperopt uses Bayesian optimization (Tree-structured Parzen Estimator, TPE) to intelligently explore the parameter space, often finding better results with fewer evaluations compared to traditional grid search. It requires a param_space defined using hyperopt.hp functions.
from hyperopt import hp

# Define the hyperopt search space; `model` is a BaseMlModel defined earlier.
param_space = {
    "model": hp.choice("model", [model]),
    "C": hp.choice("C", [0.1, 1, 10]),
    "max_iter": hp.choice("max_iter", [1000, 2000]),
    "tol": hp.choice("tol", [1e-3, 1e-4]),
}
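A minimal end-to-end sketch follows; the import path mirrors the grid search example and is an assumption, as is passing the space through the param_grid argument listed below:

# Sketch: the import path and exact constructor signature are assumptions
# based on the parameter list below.
from astrodata.ml.model_selection.HyperOptSelector import HyperOptSelector

hos = HyperOptSelector(
    param_grid=param_space,  # the hyperopt search space defined above
    scorer=accuracy,
    use_cv=True,
    cv=5,
    max_evals=50,        # cap on the number of TPE evaluations
    random_state=42,
)
best_model = hos.fit(X_train, y_train)
print(f"Best parameters found: {hos.get_best_params()}")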
Parameters¶
param_grid : dict
Dictionary defining the hyperparameter search space with hyperopt.hp functions, as shown above (a richer space is sketched after this list).
scorer : BaseMetric, optional
The metric used to select the best model. If None, the model’s default score method is used.
use_cv : bool
Whether to use cross-validation or a regular validation split.
cv : int or cross-validation splitter (default=5)
Number of folds (int) or an object that yields train/test splits.
max_evals : int
Maximum number of evaluations hyperopt is allowed to run.
random_state : int, optional
Random seed for reproducibility.
metrics : list of BaseMetric, optional
Additional metrics to evaluate on the validation folds.
tracker : ModelTracker, optional
Optional experiment/model tracker for logging.
log_all_models : bool, optional
If True, logs all models, not just the best one.
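Beyond hp.choice, hyperopt also provides continuous and quantized distributions, which lets TPE explore spaces richer than any finite grid. A sketch (model is assumed defined as before):

import numpy as np
from hyperopt import hp

# A richer search space: continuous log-scale C, quantized max_iter.
param_space = {
    "model": hp.choice("model", [model]),                 # `model` defined earlier
    "C": hp.loguniform("C", np.log(1e-2), np.log(1e2)),   # continuous, log-uniform
    "max_iter": hp.quniform("max_iter", 500, 3000, 500),  # multiples of 500
    "tol": hp.choice("tol", [1e-3, 1e-4]),
}
# Note: hp.quniform yields floats; cast to int inside the model if required.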