astrodata.ml.model_selection package
Submodules
astrodata.ml.model_selection.BaseMlModelSelector module
- class astrodata.ml.model_selection.BaseMlModelSelector.BaseMlModelSelector
Bases: ABC
Abstract base class for model selection strategies.
Subclasses must implement methods for fitting the selector to data and for retrieving the best model, the best hyperparameters, and the best metrics.
- abstractmethod fit(X, y, *args, **kwargs)
Fit the model selector to data.
- Parameters:
X (Any) – Training data features.
y (Any) – Training data targets.
*args – Additional arguments for fitting.
**kwargs – Additional arguments for fitting.
- Returns:
Returns self.
- Return type:
BaseMlModelSelector
- abstractmethod get_best_metrics()
Return the best metrics found during selection.
- Returns:
Dictionary of best metrics.
- Return type:
dict
- abstractmethod get_best_model()
Return the best model found during selection.
- Returns:
The best model object.
- Return type:
Any
- abstractmethod get_best_params()
Return the best hyperparameters found during selection.
- Returns:
Dictionary of best parameters.
- Return type:
dict
- abstractmethod get_params(**kwargs)
Return parameters of the selector. Can be optionally overridden.
- Returns:
Dictionary of selector parameters.
- Return type:
dict
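To illustrate the contract, the sketch below re-declares an interface with the same abstract methods using only the standard-library `abc` module. It does not import astrodata; `SelectorInterface` and `ConstantSelector` are hypothetical names. The toy subclass picks the constant prediction with the lowest squared error on the targets, which is enough to show that `fit` returns `self` and that an ABC cannot be instantiated until every abstract method is implemented.

```python
from abc import ABC, abstractmethod


class SelectorInterface(ABC):
    """Hypothetical stand-in mirroring the abstract methods documented above."""

    @abstractmethod
    def fit(self, X, y, *args, **kwargs): ...

    @abstractmethod
    def get_best_model(self): ...

    @abstractmethod
    def get_best_params(self): ...

    @abstractmethod
    def get_best_metrics(self): ...

    @abstractmethod
    def get_params(self, **kwargs): ...


class ConstantSelector(SelectorInterface):
    """Toy selector: chooses the constant prediction with the lowest squared error."""

    def __init__(self, candidates):
        self.candidates = list(candidates)
        self._best = None
        self._best_sse = None

    def fit(self, X, y, *args, **kwargs):
        # X is ignored: a constant model only looks at the targets.
        scored = [(sum((t - c) ** 2 for t in y), c) for c in self.candidates]
        self._best_sse, self._best = min(scored)
        return self  # returning self allows chaining, e.g. Selector(...).fit(X, y)

    def get_best_model(self):
        best = self._best
        return lambda X: [best] * len(X)

    def get_best_params(self):
        return {"constant": self._best}

    def get_best_metrics(self):
        return {"sse": self._best_sse}

    def get_params(self, **kwargs):
        return {"candidates": self.candidates}


selector = ConstantSelector([0, 1, 2]).fit([[0], [0], [0]], [1, 1, 2])
print(selector.get_best_params())  # {'constant': 1}

try:
    SelectorInterface()
except TypeError:
    print("abstract methods must be implemented before instantiation")
```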
astrodata.ml.model_selection.GridSearchSelector module
- class astrodata.ml.model_selection.GridSearchSelector.GridSearchCVSelector(model, param_grid, scorer=None, cv=5, random_state=2896307122, metrics=None, tracker=None, log_all_models=False)
Bases: BaseMlModelSelector
GridSearchCVSelector performs exhaustive grid search over a parameter grid using cross-validation.
- fit(X, y, X_test=None, y_test=None, *args, **kwargs)
Run grid search with cross-validation.
- Parameters:
X (array-like) – Training data features.
y (array-like) – Training data targets.
X_test (array-like, optional) – Test data features for tracking/logging (not used in selection).
y_test (array-like, optional) – Test data targets for tracking/logging (not used in selection).
- Returns:
self – Fitted selector.
- Return type:
object
- get_best_metrics()
Get the metrics for the best model averaged over cross-validation folds.
- Returns:
Averaged metrics, or None if no metrics were specified.
- Return type:
dict or None
- get_best_model()
Get the best model fitted on all data using the best parameters found.
- Returns:
The best fitted model.
- Return type:
- get_best_params()
Get the best parameter combination found during grid search.
- Returns:
Best parameters.
- Return type:
dict
- get_params(**kwargs)
Get parameters of this selector instance.
- Returns:
Parameters used to initialize this object.
- Return type:
dict
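The following is a minimal, dependency-free sketch of what an exhaustive grid search with cross-validation does: expand the parameter grid into every combination, score each combination on k folds, average the fold scores, and refit the winner on all of the data. It is not astrodata's implementation. The toy `RidgeThroughOrigin` model, the use of mean squared error as a lower-is-better score, and the interleaved fold assignment are assumptions made for illustration.

```python
import itertools
import random
from statistics import mean


class RidgeThroughOrigin:
    """Toy model y ~ w * x with an L2 penalty `lam` (hypothetical, for illustration)."""

    def __init__(self, lam=0.0):
        self.lam = lam
        self.w = 0.0

    def fit(self, X, y):
        sxy = sum(a * b for a, b in zip(X, y))
        sxx = sum(a * a for a in X)
        self.w = sxy / (sxx + self.lam)
        return self

    def predict(self, X):
        return [self.w * a for a in X]


def mse(y_true, y_pred):
    return mean((a - b) ** 2 for a, b in zip(y_true, y_pred))


def expand_grid(param_grid):
    """Yield one dict per combination of the values in param_grid."""
    keys = list(param_grid)
    for values in itertools.product(*(param_grid[k] for k in keys)):
        yield dict(zip(keys, values))


def kfold_indices(n, cv, random_state):
    """Shuffle row indices once, then deal them into cv interleaved folds."""
    idx = list(range(n))
    random.Random(random_state).shuffle(idx)
    folds = [idx[i::cv] for i in range(cv)]
    for k in range(cv):
        held_out = set(folds[k])
        yield [i for i in idx if i not in held_out], folds[k]


def grid_search_cv(model_cls, param_grid, X, y, cv=5, random_state=0):
    best_params, best_score = None, float("inf")
    for params in expand_grid(param_grid):
        fold_scores = []
        for train, val in kfold_indices(len(X), cv, random_state):
            model = model_cls(**params).fit([X[i] for i in train], [y[i] for i in train])
            fold_scores.append(mse([y[i] for i in val], model.predict([X[i] for i in val])))
        score = mean(fold_scores)  # metric averaged over the folds
        if score < best_score:  # lower MSE is better; ties keep the earlier combination
            best_params, best_score = params, score
    # Refit the winning combination on all of the data.
    best_model = model_cls(**best_params).fit(X, y)
    return best_model, best_params, {"mse": best_score}


X = [float(i) for i in range(1, 11)]
y = [2.0 * x for x in X]
model, params, metrics = grid_search_cv(RidgeThroughOrigin, {"lam": [0.0, 1.0, 10.0]}, X, y)
print(params, metrics)  # {'lam': 0.0} {'mse': 0.0}
```

Every combination is scored on the same folds (the shuffle is seeded with the same `random_state`), so differences in the averaged score reflect the parameters rather than the split.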
- class astrodata.ml.model_selection.GridSearchSelector.GridSearchSelector(model, param_grid, scorer=None, val_size=None, random_state=673715504, metrics=None, tracker=None, log_all_models=False)
Bases: BaseMlModelSelector
GridSearchSelector performs exhaustive grid search over a parameter grid using a single validation split.
- fit(X_train=None, y_train=None, X_val=None, y_val=None, X_test=None, y_test=None, dataset_train=None, dataset_val=None, dataset_test=None, *args, **kwargs)
Run grid search using a single train/validation split.
- Parameters:
X_train (array-like, optional) – Training data features. Ignored if dataset_train is provided.
y_train (array-like, optional) – Training data targets. Ignored if dataset_train is provided.
X_val (array-like, optional) – Validation data features. If None and dataset_val is None, a random split is performed.
y_val (array-like, optional) – Validation data targets. If None and dataset_val is None, a random split is performed.
X_test (array-like, optional) – Test data features for tracking/logging (not used in selection).
y_test (array-like, optional) – Test data targets for tracking/logging (not used in selection).
dataset_train (Dataset, optional) – Training dataset. If provided, X_train and y_train are ignored.
dataset_val (Dataset, optional) – Validation dataset. If provided, X_val and y_val are ignored.
dataset_test (Dataset, optional) – Test dataset for tracking/logging (not used in selection).
- Returns:
self – Fitted selector.
- Return type:
object
- Raises:
ValueError – If neither validation data nor val_size is provided, or if neither (X_train, y_train) nor dataset_train is provided.
- get_best_metrics()
Get the metrics for the best model on validation data.
- Returns:
Metrics for the best model, or None if no metrics were specified.
- Return type:
dict or None
- get_best_model()
Get the best model fitted on all data using the best parameters found.
- Returns:
The best fitted model.
- Return type:
- get_best_params()
Get the best parameter combination found during grid search.
- Returns:
Best parameters.
- Return type:
dict
- get_params(**kwargs)
Get parameters of this selector instance.
- Returns:
Parameters used to initialize this object.
- Return type:
dict
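The input precedence rules described for GridSearchSelector.fit can be summarized in a small standard-library sketch. `resolve_fit_inputs` is a hypothetical helper, not part of astrodata. Datasets are modelled as plain `(X, y)` tuples rather than astrodata's Dataset class, and `val_size` is assumed to be a fraction of the training rows.

```python
import random


def resolve_fit_inputs(X_train=None, y_train=None, X_val=None, y_val=None,
                       dataset_train=None, dataset_val=None,
                       val_size=None, random_state=0):
    """Hypothetical helper mirroring the documented precedence rules.

    Returns (X_train, y_train, X_val, y_val). Datasets are (X, y) tuples here.
    """
    # A training dataset takes precedence over X_train / y_train.
    if dataset_train is not None:
        X_train, y_train = dataset_train
    if X_train is None or y_train is None:
        raise ValueError("Provide (X_train, y_train) or dataset_train.")

    # A validation dataset takes precedence over X_val / y_val.
    if dataset_val is not None:
        X_val, y_val = dataset_val
    if X_val is not None and y_val is not None:
        return X_train, y_train, X_val, y_val

    # No validation data: fall back to a seeded random split of the training rows.
    if val_size is None:
        raise ValueError("Provide validation data or set val_size.")
    idx = list(range(len(X_train)))
    random.Random(random_state).shuffle(idx)
    n_val = max(1, int(round(len(idx) * val_size)))  # val_size assumed to be a fraction
    val_idx, train_idx = idx[:n_val], idx[n_val:]
    return ([X_train[i] for i in train_idx], [y_train[i] for i in train_idx],
            [X_train[i] for i in val_idx], [y_train[i] for i in val_idx])


Xtr, ytr, Xv, yv = resolve_fit_inputs(X_train=list(range(10)), y_train=list(range(10)),
                                      val_size=0.2, random_state=1)
print(len(Xtr), len(Xv))  # 8 2
```

Once the split is resolved, the search itself loops over every grid combination, fits on the training part, and scores once on the validation part, instead of averaging over folds as a cross-validated search does.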
astrodata.ml.model_selection.HyperOptSelector module
- class astrodata.ml.model_selection.HyperOptSelector.HyperOptSelector(param_space, scorer=None, use_cv=False, cv=2, val_size=0.2, max_evals=20, random_state=3424158792, metrics=None, tracker=None, log_all_models=False)
Bases: BaseMlModelSelector
HyperOptSelector performs hyperparameter optimization using hyperopt.
- fit(X=None, y=None, X_val=None, y_val=None, X_test=None, y_test=None, dataset_train=None, dataset_val=None, dataset_test=None, *args, **kwargs)
Fit the model selector to data.
- Parameters:
X (Any) – Training data features.
y (Any) – Training data targets.
*args – Additional arguments for fitting.
**kwargs – Additional arguments for fitting.
- Returns:
Returns self.
- Return type:
BaseMlModelSelector
- get_best_metrics()
Get the metrics for the best model averaged over cross-validation folds.
- Returns:
Averaged metrics, or None if no metrics were specified.
- Return type:
dict or None
- get_best_model()
Get the best model fitted on all data using the best parameters found.
- Returns:
The best fitted model.
- Return type:
- get_best_params()
Get the best parameter combination found during the hyperopt search.
- Returns:
Best parameters.
- Return type:
dict
- get_params(**kwargs)
Get parameters of this selector instance.
- Returns:
Parameters used to initialize this object.
- Return type:
dict
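hyperopt's own search (its `fmin` function with the TPE algorithm) is not reproduced here. The sketch below swaps in plain random search to show the shape of a budgeted search over a parameter space: draw a configuration, evaluate a loss, keep the best, and stop after `max_evals` evaluations. `random_search`, the one-sampler-per-parameter representation of the space, and `toy_loss` are hypothetical; hyperopt itself describes spaces with `hp` expressions.

```python
import random


def random_search(objective, param_space, max_evals=20, random_state=0):
    """Minimise `objective` by sampling `param_space` `max_evals` times.

    Stand-in for a hyperopt search: values are drawn uniformly at random rather
    than with TPE. Each entry of `param_space` is a function that draws one
    value from a random.Random instance (an assumption for this sketch).
    """
    rng = random.Random(random_state)
    best_params, best_loss, history = None, float("inf"), []
    for _ in range(max_evals):
        params = {name: draw(rng) for name, draw in param_space.items()}
        loss = objective(params)
        history.append((params, loss))
        if loss < best_loss:
            best_params, best_loss = params, loss
    return best_params, best_loss, history


# Hypothetical search space: one continuous value and one categorical choice.
space = {
    "alpha": lambda rng: rng.uniform(0.0, 1.0),
    "kind": lambda rng: rng.choice(["a", "b"]),
}


def toy_loss(params):
    # Lowest when alpha is close to 0.3 and kind == "b".
    return (params["alpha"] - 0.3) ** 2 + (0.0 if params["kind"] == "b" else 1.0)


best, loss, history = random_search(toy_loss, space, max_evals=50, random_state=7)
print(best, loss)
```

In a selector like the one documented above, the objective would presumably fit a model on the training split and return a validation loss, or an average over folds when `use_cv=True`; a fixed `random_state` makes the sequence of draws, and therefore the result, reproducible.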