astrodata.ml.model_selection package
Submodules
astrodata.ml.model_selection.BaseMlModelSelector module
- class astrodata.ml.model_selection.BaseMlModelSelector.BaseMlModelSelector
Bases: ABC
Abstract base class for model selection strategies.
Subclasses must implement methods for fitting the selector to data and for retrieving the best model, the best hyperparameters, the best metrics, and the selector's own parameters.
- abstractmethod fit(X, y, *args, **kwargs)
Fit the model selector to data.
- Parameters:
X (Any) – Training data features.
y (Any) – Training data targets.
*args – Additional positional arguments for fitting.
**kwargs – Additional keyword arguments for fitting.
- Returns:
Returns self.
- Return type:
BaseMlModelSelector
- abstractmethod get_best_metrics()
Return the best metrics found during selection.
- Returns:
Dictionary of best metrics.
- Return type:
dict
- abstractmethod get_best_model()
Return the best model found during selection.
- Returns:
The best model object.
- Return type:
Any
- abstractmethod get_best_params()
Return the best hyperparameters found during selection.
- Returns:
Dictionary of best parameters.
- Return type:
dict
- abstractmethod get_params(**kwargs)
Return the parameters of the selector. Because this method is declared abstract, concrete subclasses must implement it.
- Returns:
Dictionary of selector parameters.
- Return type:
dict
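As a rough illustration of the contract above, the sketch below declares a local, hypothetical stand-in (`SelectorInterface`) with the same five abstract methods instead of importing astrodata, then implements it with a toy `ThresholdSelector` that keeps the decision threshold with the highest accuracy. The class names and the accuracy metric are assumptions made for the example; they are not part of astrodata.

```python
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional


class SelectorInterface(ABC):
    """Hypothetical local stand-in that mirrors the BaseMlModelSelector methods."""

    @abstractmethod
    def fit(self, X: Any, y: Any, *args: Any, **kwargs: Any) -> "SelectorInterface": ...

    @abstractmethod
    def get_best_model(self) -> Any: ...

    @abstractmethod
    def get_best_params(self) -> Dict[str, Any]: ...

    @abstractmethod
    def get_best_metrics(self) -> Optional[Dict[str, float]]: ...

    @abstractmethod
    def get_params(self, **kwargs: Any) -> Dict[str, Any]: ...


class ThresholdSelector(SelectorInterface):
    """Toy selector: picks the decision threshold with the highest accuracy."""

    def __init__(self, thresholds: List[float]) -> None:
        self.thresholds = thresholds
        self._best_params: Dict[str, Any] = {}
        self._best_score: Optional[float] = None

    def fit(self, X, y, *args, **kwargs):
        self._best_params, self._best_score = {}, None  # reset between calls
        for t in self.thresholds:
            preds = [1 if x >= t else 0 for x in X]
            acc = sum(p == target for p, target in zip(preds, y)) / len(y)
            if self._best_score is None or acc > self._best_score:
                self._best_params, self._best_score = {"threshold": t}, acc
        return self  # fit returns self, as the interface documents

    def get_best_model(self) -> Callable[[float], int]:
        t = self._best_params["threshold"]
        return lambda x: 1 if x >= t else 0

    def get_best_params(self):
        return dict(self._best_params)

    def get_best_metrics(self):
        return {"accuracy": self._best_score}

    def get_params(self, **kwargs):
        return {"thresholds": self.thresholds}


sel = ThresholdSelector([0.0, 0.5, 1.0]).fit([0.1, 0.4, 0.6, 0.9], [0, 0, 1, 1])
print(sel.get_best_params(), sel.get_best_metrics())  # {'threshold': 0.5} {'accuracy': 1.0}
```

Because every method is abstract, instantiating `SelectorInterface` directly raises `TypeError`; a subclass becomes instantiable only once it implements all five.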
astrodata.ml.model_selection.GridSearchSelector module
- class astrodata.ml.model_selection.GridSearchSelector.GridSearchCVSelector(model, param_grid, scorer=None, cv=5, random_state=2140864672, metrics=None, tracker=None, log_all_models=False)
Bases: BaseMlModelSelector
GridSearchCVSelector performs exhaustive grid search over a parameter grid using cross-validation.
- fit(X, y, X_test=None, y_test=None, *args, **kwargs)
Run grid search with cross-validation.
- Parameters:
X (array-like) – Training data features.
y (array-like) – Training data targets.
X_test (array-like, optional) – Test data features for tracking/logging (not used in selection).
y_test (array-like, optional) – Test data targets for tracking/logging (not used in selection).
- Returns:
self – Fitted selector.
- Return type:
object
- get_best_metrics()
Get the metrics for the best model averaged over cross-validation folds.
- Returns:
Averaged metrics, or None if no metrics were specified.
- Return type:
dict or None
- get_best_model()
Get the best model, fitted on all data using the best parameters found.
- Returns:
The best fitted model.
- Return type:
- get_best_params()
Get the best parameter combination found during grid search.
- Returns:
Best parameters.
- Return type:
dict
- get_params(**kwargs)
Get parameters of this selector instance.
- Returns:
Parameters used to initialize this object.
- Return type:
dict
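The cross-validated grid search can be sketched in plain Python as follows. This is not astrodata's implementation: `grid_search_cv`, `fit_knn_1d`, and `neg_mse` are hypothetical helpers, the scorer is assumed to be higher-is-better, and the final refit on all data mirrors what `get_best_model()` describes. Each parameter combination is scored on `cv` shuffled folds, the fold scores are averaged, and the combination with the best average wins.

```python
import itertools
import random
from statistics import mean
from typing import Any, Callable, Dict, Iterator, List, Sequence, Tuple

Model = Callable[[float], float]


def iter_grid(param_grid: Dict[str, Sequence[Any]]) -> Iterator[Dict[str, Any]]:
    """Yield every parameter combination in the grid (exhaustive search)."""
    keys = sorted(param_grid)
    for values in itertools.product(*(param_grid[k] for k in keys)):
        yield dict(zip(keys, values))


def kfold_splits(n: int, cv: int, seed: int) -> Iterator[Tuple[List[int], List[int]]]:
    """Shuffle indices once, then yield (train, val) index lists for each of cv folds."""
    idx = list(range(n))
    random.Random(seed).shuffle(idx)
    folds = [idx[i::cv] for i in range(cv)]
    for i in range(cv):
        train = [j for k, fold in enumerate(folds) if k != i for j in fold]
        yield train, folds[i]


def grid_search_cv(fit_fn, score_fn, param_grid, X, y, cv=5, seed=0):
    """Return (best_model, best_params, metrics); score_fn is assumed higher-is-better."""
    best_params, best_score = None, float("-inf")
    for params in iter_grid(param_grid):
        fold_scores = []
        for train, val in kfold_splits(len(X), cv, seed):
            model = fit_fn([X[i] for i in train], [y[i] for i in train], **params)
            fold_scores.append(score_fn([y[i] for i in val], [model(X[i]) for i in val]))
        avg = mean(fold_scores)  # averaged over folds, as get_best_metrics() describes
        if avg > best_score:
            best_params, best_score = params, avg
    # Refit the winning configuration on all the data.
    return fit_fn(X, y, **best_params), best_params, {"score": best_score}


def fit_knn_1d(X: List[float], y: List[float], k: int) -> Model:
    """Toy 1-D k-nearest-neighbours regressor."""
    pairs = list(zip(X, y))
    return lambda x: mean(t for _, t in sorted(pairs, key=lambda p: abs(p[0] - x))[:k])


def neg_mse(y_true: List[float], y_pred: List[float]) -> float:
    return -mean((t - p) ** 2 for t, p in zip(y_true, y_pred))


X = [float(i) for i in range(10)]
y = [2.0 * x for x in X]
best_model, best_params, metrics = grid_search_cv(fit_knn_1d, neg_mse, {"k": [1, 10]}, X, y, cv=5)
print(best_params)  # {'k': 1}
```

With only 8 training points per fold, `k=10` degenerates into predicting the training mean, so the nearest-neighbour setting `k=1` wins on this linear toy data.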
- class astrodata.ml.model_selection.GridSearchSelector.GridSearchSelector(model, param_grid, scorer=None, val_size=None, random_state=4282729121, metrics=None, tracker=None, log_all_models=False)
Bases: BaseMlModelSelector
GridSearchSelector performs exhaustive grid search over a parameter grid using a single validation split.
- fit(X_train, y_train, X_val=None, y_val=None, X_test=None, y_test=None, *args, **kwargs)
Run grid search using a single train/validation split.
- Parameters:
X_train (array-like) – Training data features.
y_train (array-like) – Training data targets.
X_val (array-like, optional) – Validation data features. If None, a random validation split of size val_size is taken from the training data.
y_val (array-like, optional) – Validation data targets. If None, a random validation split of size val_size is taken from the training data.
X_test (array-like, optional) – Test data features for tracking/logging (not used in selection).
y_test (array-like, optional) – Test data targets for tracking/logging (not used in selection).
- Returns:
self – Fitted selector.
- Return type:
object
- Raises:
ValueError – If neither validation data nor val_size is provided.
- get_best_metrics()
Get the metrics for the best model on validation data.
- Returns:
Metrics for the best model, or None if no metrics were specified.
- Return type:
dict or None
- get_best_model()
Get the best model, fitted on all data using the best parameters found.
- Returns:
The best fitted model.
- Return type:
- get_best_params()
Get the best parameter combination found during grid search.
- Returns:
Best parameters.
- Return type:
dict
- get_params(**kwargs)
Get parameters of this selector instance.
- Returns:
Parameters used to initialize this object.
- Return type:
dict
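A single-split variant, again a hypothetical sketch rather than the library's code, differs only in how candidates are scored: it uses the supplied validation data, or carves a random `val_size` fraction out of the training data, and raises `ValueError` when neither is available. The helper names, the toy one-dimensional ridge model, the higher-is-better scorer, and the reading of "all data" as training plus validation data are all illustrative assumptions.

```python
import itertools
import random
from statistics import mean
from typing import Any, Callable, Dict, List, Optional, Sequence


def fit_ridge_1d(X: List[float], y: List[float], alpha: float) -> Callable[[float], float]:
    """Toy ridge regression through the origin: slope = sum(xy) / (sum(x^2) + alpha)."""
    slope = sum(a * b for a, b in zip(X, y)) / (sum(a * a for a in X) + alpha)
    return lambda x: slope * x


def neg_mse(y_true: List[float], y_pred: List[float]) -> float:
    return -mean((t - p) ** 2 for t, p in zip(y_true, y_pred))


def grid_search_holdout(fit_fn, score_fn, param_grid: Dict[str, Sequence[Any]],
                        X_train, y_train, X_val=None, y_val=None,
                        val_size: Optional[float] = None, seed: int = 0):
    """Exhaustive grid search scored on one validation split (higher score is better)."""
    if X_val is None or y_val is None:
        if val_size is None:
            raise ValueError("Provide X_val/y_val or set val_size.")
        idx = list(range(len(X_train)))
        random.Random(seed).shuffle(idx)
        n_val = max(1, int(len(idx) * val_size))
        val, train = idx[:n_val], idx[n_val:]
        X_val, y_val = [X_train[i] for i in val], [y_train[i] for i in val]
        X_train, y_train = [X_train[i] for i in train], [y_train[i] for i in train]
    keys = sorted(param_grid)
    best_params, best_score = None, float("-inf")
    for values in itertools.product(*(param_grid[k] for k in keys)):
        params = dict(zip(keys, values))
        model = fit_fn(X_train, y_train, **params)
        score = score_fn(y_val, [model(x) for x in X_val])
        if score > best_score:
            best_params, best_score = params, score
    # Assumption: "all data" means the training and validation data combined.
    best_model = fit_fn(list(X_train) + list(X_val), list(y_train) + list(y_val), **best_params)
    return best_model, best_params, {"score": best_score}


X = [float(i) for i in range(1, 11)]
y = [2.0 * x for x in X]
model, params, metrics = grid_search_holdout(
    fit_ridge_1d, neg_mse, {"alpha": [0.0, 10.0, 1000.0]}, X, y, val_size=0.2)
print(params)  # {'alpha': 0.0}
```

On noise-free data `y = 2x`, any positive `alpha` only shrinks the fitted slope, so the unregularized setting scores best.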
astrodata.ml.model_selection.HyperOptSelector module
- class astrodata.ml.model_selection.HyperOptSelector.HyperOptSelector(param_space, scorer=None, use_cv=False, cv=2, val_size=0.2, max_evals=20, random_state=3885649030, metrics=None, tracker=None, log_all_models=False)
Bases: BaseMlModelSelector
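HyperOptSelector performs hyperparameter optimization with the hyperopt library, evaluating up to max_evals candidates drawn from param_space and scoring each either by cross-validation (use_cv=True, with cv folds) or on a single validation split (val_size).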
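- fit(X, y, X_val=None, y_val=None, X_test=None, y_test=None, *args, **kwargs)
Fit the model selector to data.
- Parameters:
X (Any) – Training data features.
y (Any) – Training data targets.
X_val (Any, optional) – Validation data features.
y_val (Any, optional) – Validation data targets.
X_test (Any, optional) – Test data features.
y_test (Any, optional) – Test data targets.
*args – Additional positional arguments for fitting.
**kwargs – Additional keyword arguments for fitting.
- Returns:
Returns self.
- Return type:
HyperOptSelector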
- get_best_metrics()
Get the metrics for the best model, averaged over cross-validation folds when use_cv is enabled.
- Returns:
Averaged metrics, or None if no metrics were specified.
- Return type:
dict or None
- get_best_model()
Get the best model, fitted on all data using the best parameters found.
- Returns:
The best fitted model.
- Return type:
- get_best_params()
Get the best parameter combination found during the hyperopt search.
- Returns:
Best parameters.
- Return type:
dict
- get_params(**kwargs)
Get parameters of this selector instance.
- Returns:
Parameters used to initialize this object.
- Return type:
dict
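hyperopt offers adaptive search algorithms such as TPE; which one HyperOptSelector uses is not stated here. To keep the example dependency-free, the sketch below swaps in plain random search instead: it draws `max_evals` candidates from a hypothetical sampler-based search space and scores each with either k-fold cross-validation (`use_cv=True`) or a single hold-out split of size `val_size`, seeded by `random_state`. The helper names and the space format are assumptions, not hyperopt's `hp.*` expressions or astrodata's API.

```python
import random
from statistics import mean
from typing import Any, Callable, Dict, List, Optional, Tuple

# Hypothetical search space: each entry maps a name to a sampler that draws one
# value from a random.Random. This is a plain-Python analogue, not hyperopt's hp.* API.
Space = Dict[str, Callable[[random.Random], Any]]


def fit_ridge_1d(X: List[float], y: List[float], alpha: float) -> Callable[[float], float]:
    """Toy ridge regression through the origin."""
    slope = sum(a * b for a, b in zip(X, y)) / (sum(a * a for a in X) + alpha)
    return lambda x: slope * x


def neg_mse(y_true: List[float], y_pred: List[float]) -> float:
    return -mean((t - p) ** 2 for t, p in zip(y_true, y_pred))


def evaluate(params: Dict[str, Any], X, y, use_cv: bool, cv: int,
             val_size: float, seed: int) -> float:
    """Score one candidate with k-fold CV or a single hold-out split (higher is better)."""
    idx = list(range(len(X)))
    random.Random(seed).shuffle(idx)  # same split(s) for every candidate
    if use_cv:
        val_sets = [idx[i::cv] for i in range(cv)]
    else:
        val_sets = [idx[: max(1, int(len(idx) * val_size))]]
    scores = []
    for val in val_sets:
        held_out = set(val)
        train = [i for i in idx if i not in held_out]
        model = fit_ridge_1d([X[i] for i in train], [y[i] for i in train], **params)
        scores.append(neg_mse([y[i] for i in val], [model(X[i]) for i in val]))
    return mean(scores)


def random_search(space: Space, X, y, max_evals: int = 20, use_cv: bool = False,
                  cv: int = 2, val_size: float = 0.2,
                  random_state: int = 0) -> Tuple[Optional[Dict[str, Any]], Dict[str, float]]:
    """Draw max_evals candidates at random from space and keep the best-scoring one."""
    rng = random.Random(random_state)
    best_params, best_score = None, float("-inf")
    for _ in range(max_evals):
        params = {name: sample(rng) for name, sample in space.items()}
        score = evaluate(params, X, y, use_cv, cv, val_size, random_state)
        if score > best_score:
            best_params, best_score = params, score
    return best_params, {"score": best_score}


X = [float(i) for i in range(1, 21)]
y = [2.0 * x for x in X]
space = {"alpha": lambda r: 10 ** r.uniform(-3, 3)}  # log-uniform in [1e-3, 1e3]
best_params, metrics = random_search(space, X, y, max_evals=20, random_state=0)
print(best_params, metrics)
```

With noise-free linear data every positive `alpha` only shrinks the slope, so the best candidate is simply the smallest `alpha` drawn; a real search space would combine several hyperparameters, which is where an adaptive algorithm pays off over random sampling.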