Extending astrodata¶
The astrodata library is designed with extensibility in mind, leveraging Python's Abstract Base Classes (ABCs) to define clear interfaces for different components. This tutorial explains what ABCs are, why they are used in astrodata, and how you can extend the library by implementing your own custom metrics, models, and model selection strategies.
Understanding Abstract Base Classes (ABCs) in Python¶
An Abstract Base Class (ABC) in Python defines a blueprint for other classes. It allows you to specify methods that must be implemented by any concrete (non-abstract) subclass. This enforces a consistent structure and behavior across different implementations of a common concept.
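For readers new to ABCs, here is a minimal, astrodata-independent sketch using only the standard-library abc module (the Serializer and JsonSerializer names are purely illustrative):

from abc import ABC, abstractmethod
import json


class Serializer(ABC):
    """Toy interface: every concrete serializer must implement serialize()."""

    @abstractmethod
    def serialize(self, obj: object) -> str:
        """Turn an object into a string."""


class JsonSerializer(Serializer):
    def serialize(self, obj: object) -> str:
        return json.dumps(obj)


# Serializer() would raise TypeError (you cannot instantiate a class with
# unimplemented abstract methods); JsonSerializer implements them all, so it works:
print(JsonSerializer().serialize({"a": 1}))  # {"a": 1}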
Why Use ABCs?¶
Enforce Interfaces: ABCs ensure that all subclasses adhere to a specific interface. If a method is marked as abstract in the base class, any concrete subclass must provide an implementation for that method; otherwise, it cannot be instantiated.
Consistency and Predictability: By defining a common interface, ABCs make your code more predictable. Users of the library know what methods to expect from any class that implements a specific ABC, regardless of its underlying implementation.
Extensibility: They provide clear extension points. Developers know exactly what they need to implement to add new functionality (e.g., a new metric or a new model type) while remaining compatible with the rest of the framework.
Type Hinting and Static Analysis: ABCs work well with type hinting, allowing for more robust static analysis and clearer code documentation.
The sections below walk through a series of examples showing how to extend some of the astrodata components.
Example 1: Extending astrodata.ml.metrics.BaseMetric¶
The BaseMetric abstract class defines the interface for all evaluation metrics within astrodata. If you want to use a custom metric that isn't covered by SklearnMetric, you can create your own by inheriting from BaseMetric.
BaseMetric.py
Abstract Methods:
__init__(self): The constructor.
__call__(self, y_true: Any, y_pred: Any, **kwargs) -> float: Computes the metric value.
get_name(self) -> str: Returns the name of the metric.
greater_is_better(self) -> bool (property): Indicates if a higher value of the metric is better.
Example: Creating a Custom Precision Metric
Let's say you want a simple precision metric for binary classification where greater_is_better is True.
from astrodata.ml.metrics.BaseMetric import BaseMetric
from typing import Any


class CustomPrecisionMetric(BaseMetric):
    def __init__(self, positive_label: Any = 1):
        super().__init__()
        self.positive_label = positive_label

    def __call__(self, y_true: Any, y_pred: Any, **kwargs) -> float:
        true_positives = 0
        predicted_positives = 0
        for true, pred in zip(y_true, y_pred):
            if pred == self.positive_label:
                predicted_positives += 1
                if true == self.positive_label:
                    true_positives += 1
        if predicted_positives == 0:
            return 0.0  # Avoid division by zero
        return true_positives / predicted_positives

    def get_name(self) -> str:
        return f"Precision_for_{self.positive_label}"

    @property
    def greater_is_better(self) -> bool:
        return True


# Example Usage:
y_true = [0, 1, 0, 1, 0, 1]
y_pred = [0, 1, 1, 0, 0, 1]

precision_metric = CustomPrecisionMetric(positive_label=1)
score = precision_metric(y_true, y_pred)

print(f"Custom Precision Score: {score}")
print(f"Metric Name: {precision_metric.get_name()}")
print(f"Greater is better: {precision_metric.greater_is_better}")
Example 2: Extending astrodata.ml.models.BaseMlModel¶
The BaseMlModel abstract class defines the fundamental operations for any machine learning model in astrodata. This allows the astrodata.ml.model_selection module to work seamlessly with various model types, whether they are scikit-learn models, XGBoost models, or your own custom implementations.
BaseMlModel.py
Abstract Methods:
fit(self, X: Any, y: Any, **kwargs) -> "BaseMlModel": Trains the model.
predict(self, X: Any, **kwargs) -> Any: Generates predictions.
score(self, X: Any, y: Any, **kwargs) -> float: Computes a default score.
get_metrics(self, X_test: Any, y_test: Any, metrics: List[BaseMetric], **kwargs) -> Dict[str, Any]: Evaluates multiple metrics.
get_loss_history_metrics(self, X_test: Any, y_test: Any, metrics: List[BaseMetric], **kwargs) -> Dict[str, Any]: Retrieves metric history during training (optional, but part of the interface).
save(self, filepath: str, **kwargs): Saves the model.
load(self, filepath: str, **kwargs) -> "BaseMlModel": Loads a model.
get_params(self, **kwargs) -> Dict[str, Any]: Returns model hyperparameters.
set_params(self, **kwargs) -> None: Sets model hyperparameters.
clone(self) -> "BaseMlModel": Creates a shallow copy.
Example: Creating a Simple Custom Majority Class Classifier Model
This example is simplified for illustration purposes. A real custom model would involve more complex machine learning logic.
from astrodata.ml.models.BaseMlModel import BaseMlModel
from astrodata.ml.metrics.BaseMetric import BaseMetric
from astrodata.ml.metrics.SklearnMetric import SklearnMetric
from sklearn.metrics import accuracy_score
from typing import Any, Dict, List, Optional
import collections
import joblib


class MajorityClassClassifier(BaseMlModel):
    def __init__(self, random_state: int = 42):
        super().__init__()
        self.majority_class = None
        # Although not used for randomness in this simple model, keeping random_state is good practice for consistency
        self.random_state = random_state
        self.model_class = self.__class__  # For clone and load methods
        self.model_params = {}

    def fit(self, X: Any, y: Any, **kwargs) -> "MajorityClassClassifier":
        # Find the most frequent class in y
        counts = collections.Counter(y)
        self.majority_class = counts.most_common(1)[0][0]
        return self

    def predict(self, X: Any, **kwargs) -> Any:
        if self.majority_class is None:
            raise RuntimeError("Model has not been fitted yet.")
        # Predict the majority class for all inputs
        return [self.majority_class] * len(X)

    def score(self, X: Any, y: Any, scorer: Optional[BaseMetric] = None, **kwargs) -> float:
        if self.majority_class is None:
            raise RuntimeError("Model has not been fitted yet.")
        predictions = self.predict(X)
        if scorer is None:
            # Default scorer
            scorer = SklearnMetric(metric=accuracy_score, name="accuracy", greater_is_better=True)
        return scorer(y, predictions)

    def get_metrics(self, X: Any, y: Any, metrics: List[BaseMetric], **kwargs) -> Dict[str, Any]:
        if self.majority_class is None:
            raise RuntimeError("Model has not been fitted yet.")
        predictions = self.predict(X)
        results = {}
        for metric in metrics:
            results[metric.get_name()] = metric(y, predictions)
        return results

    def get_loss_history_metrics(self, X: Any, y: Any, metrics: List[BaseMetric], **kwargs) -> Dict[str, Any]:
        # This simple model does not have a loss history, so we raise an error.
        # In a real model, this would track performance over epochs/iterations.
        raise AttributeError("MajorityClassClassifier does not support loss history.")

    @property
    def has_loss_history(self) -> bool:
        return False

    def save(self, filepath: str, **kwargs):
        joblib.dump(self.majority_class, filepath)

    def load(self, filepath: str, **kwargs) -> "MajorityClassClassifier":
        self.majority_class = joblib.load(filepath)
        return self

    def get_params(self, **kwargs) -> Dict[str, Any]:
        return {"random_state": self.random_state}

    def set_params(self, **kwargs) -> None:
        if "random_state" in kwargs:
            self.random_state = kwargs["random_state"]

    def clone(self) -> "MajorityClassClassifier":
        new_instance = MajorityClassClassifier(random_state=self.random_state)
        return new_instance


# Example Usage:
X_train = [[1], [2], [3], [4], [5]]
y_train = [0, 0, 1, 0, 1]  # Majority class is 0
X_test = [[6], [7]]
y_test = [1, 0]

model = MajorityClassClassifier()
model.fit(X_train, y_train)

predictions = model.predict(X_test)
print(f"Predictions: {predictions}")  # Expected: [0, 0]

accuracy = SklearnMetric(metric=accuracy_score)
score = model.score(X_test, y_test, scorer=accuracy)
print(f"Accuracy: {score}")  # Expected: 0.5 (one correct, one incorrect)

model.save("majority_classifier.joblib")
loaded_model = MajorityClassClassifier().load("majority_classifier.joblib")
print(f"Loaded model majority class: {loaded_model.majority_class}")
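Because MajorityClassClassifier implements the full BaseMlModel interface, it can be evaluated with any BaseMetric, including the CustomPrecisionMetric from Example 1. A minimal sketch, assuming both classes are defined as above and model is the fitted instance from the usage example:

# Any BaseMetric can be passed to get_metrics; here we reuse the
# CustomPrecisionMetric from Example 1 on the fitted majority-class model.
precision = CustomPrecisionMetric(positive_label=1)
results = model.get_metrics(X_test, y_test, metrics=[precision])
print(results)  # {'Precision_for_1': 0.0} -- the model never predicts class 1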
Example 3: Extending astrodata.ml.model_selection.BaseMlModelSelector¶
The BaseMlModelSelector abstract class provides the foundation for any model selection strategy (e.g., Grid Search, Hyperparameter Optimization). To create a new model selection algorithm, you would inherit from this class and implement its abstract methods.
BaseMlModelSelector.py
Abstract Methods:
fit(self, X: Any, y: Any, *args, **kwargs) -> "BaseMlModelSelector": Runs the model selection process.
get_best_model(self) -> BaseMlModel: Returns the best model found.
get_best_params(self) -> Dict[str, Any]: Returns the hyperparameters of the best model.
get_best_metrics(self) -> Dict[str, Any]: Returns the evaluation metrics for the best model.
get_params(self, **kwargs) -> Dict[str, Any]: Returns the parameters of the selector itself.
Conceptual Example: Implementing a Simple Random Search Selector
Instead of a full code example (which would be quite extensive), let's outline the conceptual approach for a RandomSearchSelector:
from astrodata.ml.model_selection.BaseMlModelSelector import BaseMlModelSelector
from astrodata.ml.models.BaseMlModel import BaseMlModel
from astrodata.ml.metrics.BaseMetric import BaseMetric
from typing import Any, Dict, List, Optional
import random


class RandomSearchSelector(BaseMlModelSelector):
    def __init__(
        self,
        model: BaseMlModel,
        param_distributions: dict,  # Dictionary with parameter names as keys and distributions/lists as values
        scorer: BaseMetric,
        n_iter: int = 10,
        val_size: float = 0.2,
        random_state: int = 42,
        metrics: Optional[List[BaseMetric]] = None,
        tracker: Any = None,  # Placeholder for a tracking object
        log_all_models: bool = False,
    ):
        super().__init__()
        self.model = model
        self.param_distributions = param_distributions
        self.scorer = scorer
        self.n_iter = n_iter
        self.val_size = val_size
        self.random_state = random_state
        self.metrics = metrics if metrics is not None else []
        self.tracker = tracker
        self.log_all_models = log_all_models

        self._best_model = None
        self._best_params = None
        self._best_metrics = None
        self._best_score = float('-inf') if scorer.greater_is_better else float('inf')

    def fit(self, X: Any, y: Any, *args, **kwargs) -> "RandomSearchSelector":
        # Split data into training and validation sets
        # X_train_split, X_val_split, y_train_split, y_val_split = train_test_split(...)
        random.seed(self.random_state)

        for i in range(self.n_iter):
            # 1. Sample hyperparameters randomly from param_distributions
            # sampled_params = self._sample_params(self.param_distributions)

            # 2. Clone the base model and set sampled parameters
            # current_model = self.model.clone()
            # current_model.set_params(**sampled_params)

            # 3. Fit the model on the training split and evaluate on the validation split
            # current_model.fit(X_train_split, y_train_split)
            # current_score = current_model.score(X_val_split, y_val_split, scorer=self.scorer)
            # current_metrics = current_model.get_metrics(X_val_split, y_val_split, metrics=self.metrics)

            # 4. Log results using the tracker if available (similar to the HyperOptSelector example)
            # if self.tracker:
            #     with self.tracker.start_run(nested=True, tags={"hp_iteration": i}):
            #         self.tracker.log_params(sampled_params)
            #         self.tracker.log_metrics({self.scorer.get_name(): current_score})
            #         self.tracker.log_metrics(current_metrics)
            #         if self.log_all_models:
            #             self.tracker.log_model(current_model, "model")

            # 5. Update the best model if the current model is better
            # if (self.scorer.greater_is_better and current_score > self._best_score) or \
            #    (not self.scorer.greater_is_better and current_score < self._best_score):
            #     self._best_score = current_score
            #     self._best_params = sampled_params
            #     self._best_metrics = current_metrics
            #     self._best_model = current_model.clone()  # Keep a clone of the best model
            pass

        # After all iterations, if a best model was found, fit it on the full training data
        # if self._best_model:
        #     self._best_model.fit(X, y)
        return self

    def get_best_model(self) -> BaseMlModel:
        return self._best_model

    def get_best_params(self) -> Dict[str, Any]:
        return self._best_params

    def get_best_metrics(self) -> Dict[str, Any]:
        return self._best_metrics

    def get_params(self, **kwargs) -> Dict[str, Any]:
        return {
            "model": self.model,
            "param_distributions": self.param_distributions,
            "scorer": self.scorer,
            "n_iter": self.n_iter,
            "val_size": self.val_size,
            "random_state": self.random_state,
            "metrics": self.metrics,
            "tracker": self.tracker,
            "log_all_models": self.log_all_models,
        }

    # Helper method for sampling parameters (would be implemented internally)
    def _sample_params(self, param_distributions: dict) -> dict:
        sampled = {}
        for param, dist in param_distributions.items():
            if isinstance(dist, list):
                sampled[param] = random.choice(dist)
            # Add logic for different distribution types (e.g., uniform, normal)
            # For example: elif isinstance(dist, tuple) and dist[0] == 'uniform':
            #                  sampled[param] = random.uniform(dist[1], dist[2])
            else:
                sampled[param] = dist  # If it's a fixed value
        return sampled
This conceptual example shows how you would structure a new model selector by implementing the abstract methods and incorporating your specific search logic (in this case, random sampling). The general flow involves iterating through different parameter combinations, training and evaluating models, and keeping track of the best-performing one.
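Once the commented steps in fit are implemented, such a selector could be used like any other astrodata selector. Below is a hypothetical usage sketch that reuses MajorityClassClassifier and CustomPrecisionMetric from the earlier examples; the parameter grid and argument values are purely illustrative:

# Hypothetical usage of the RandomSearchSelector outlined above, assuming its
# fit() body has been filled in. The parameter grid is illustrative only.
selector = RandomSearchSelector(
    model=MajorityClassClassifier(),
    param_distributions={"random_state": [0, 1, 42]},  # lists are sampled via random.choice
    scorer=SklearnMetric(metric=accuracy_score, name="accuracy", greater_is_better=True),
    n_iter=3,
    metrics=[CustomPrecisionMetric(positive_label=1)],
)

selector.fit(X_train, y_train)

best_model = selector.get_best_model()
best_params = selector.get_best_params()
best_metrics = selector.get_best_metrics()

Because the selector only relies on the BaseMlModel and BaseMetric interfaces, any compliant model and metric combination can be plugged in without changing the search logic.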