Extending astrodata

The astrodata library is designed with extensibility in mind, leveraging Python’s Abstract Base Classes (ABCs) to define clear interfaces for different components. This tutorial will explain what ABCs are, why they are used in astrodata, and how you can extend the library by implementing your own custom metrics, models, and model selection strategies.

Understanding Abstract Base Classes (ABCs) in Python

An Abstract Base Class (ABC) in Python defines a blueprint for other classes. It allows you to specify methods that must be implemented by any concrete (non-abstract) subclass. This enforces a consistent structure and behavior across different implementations of a common concept.
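
For instance, here is a minimal, generic sketch using only Python's built-in abc module (the BaseGreeter and FriendlyGreeter names are purely illustrative, not part of astrodata):

from abc import ABC, abstractmethod

class BaseGreeter(ABC):
    """A toy ABC that requires a single greet() method."""

    @abstractmethod
    def greet(self, name: str) -> str:
        """Return a greeting for the given name."""
        ...

class FriendlyGreeter(BaseGreeter):
    def greet(self, name: str) -> str:
        return f"Hello, {name}!"

class IncompleteGreeter(BaseGreeter):
    pass  # Forgot to implement greet()

print(FriendlyGreeter().greet("Ada"))  # Works: prints "Hello, Ada!"
# IncompleteGreeter()  # TypeError: can't instantiate an abstract class with an unimplemented abstract method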

Why Use ABCs?

  1. Enforce Interfaces: ABCs ensure that all subclasses adhere to a specific interface. If a method is marked as abstract in the base class, any concrete subclass must provide an implementation for that method; otherwise, it cannot be instantiated.

  2. Consistency and Predictability: By defining a common interface, ABCs make your code more predictable. Users of the library know what methods to expect from any class that implements a specific ABC, regardless of its underlying implementation.

  3. Extensibility: They provide clear extension points. Developers know exactly what they need to implement to add new functionality (e.g., a new metric or a new model type) while remaining compatible with the rest of the framework.

  4. Type Hinting and Static Analysis: ABCs work well with type hinting, allowing for more robust static analysis and clearer code documentation, as the sketch below illustrates.
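
Points 2-4 show up in practice whenever a function is annotated with the ABC type: any conforming implementation, present or future, can be passed in, and static type checkers can verify the call. A small sketch building on the BaseGreeter example above:

from typing import List

def greet_everyone(greeter: BaseGreeter, names: List[str]) -> List[str]:
    # Works with any concrete BaseGreeter, because the ABC
    # guarantees that greet() exists and returns a string.
    return [greeter.greet(name) for name in names]

print(greet_everyone(FriendlyGreeter(), ["Ada", "Grace"]))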

The following sections walk through a series of examples showing how to extend some of the astrodata components.

Example 1: Extending astrodata.ml.metrics.BaseMetric

The BaseMetric abstract class defines the interface for all evaluation metrics within astrodata. If you want to use a custom metric that isn’t covered by SklearnMetric, you can create your own by inheriting from BaseMetric.

BaseMetric.py Abstract Methods:

  • __init__(self): The constructor.

  • __call__(self, y_true: Any, y_pred: Any, **kwargs) -> float: Computes the metric value.

  • get_name(self) -> str: Returns the name of the metric.

  • greater_is_better(self) -> bool (property): Indicates if a higher value of the metric is better.

Example: Creating a Custom Precision Metric

Let’s say you want a simple precision metric for binary classification where greater_is_better is True.

from astrodata.ml.metrics.BaseMetric import BaseMetric
from typing import Any

class CustomPrecisionMetric(BaseMetric):
    def __init__(self, positive_label: Any = 1):
        super().__init__()
        self.positive_label = positive_label

    def __call__(self, y_true: Any, y_pred: Any, **kwargs) -> float:
        true_positives = 0
        predicted_positives = 0
        for true, pred in zip(y_true, y_pred):
            if pred == self.positive_label:
                predicted_positives += 1
                if true == self.positive_label:
                    true_positives += 1
        if predicted_positives == 0:
            return 0.0 # Avoid division by zero
        return true_positives / predicted_positives

    def get_name(self) -> str:
        return f"Precision_for_{self.positive_label}"

    @property
    def greater_is_better(self) -> bool:
        return True

# Example Usage:
y_true = [0, 1, 0, 1, 0, 1]
y_pred = [0, 1, 1, 0, 0, 1]

precision_metric = CustomPrecisionMetric(positive_label=1)
score = precision_metric(y_true, y_pred)
print(f"Custom Precision Score: {score}")
print(f"Metric Name: {precision_metric.get_name()}")
print(f"Greater is better: {precision_metric.greater_is_better}")

Example 2: Extending astrodata.ml.models.BaseMlModel

The BaseMlModel abstract class defines the fundamental operations for any machine learning model in astrodata. This allows the astrodata.ml.model_selection module to work seamlessly with various model types, whether they are scikit-learn models, XGBoost models, or your own custom implementations.

BaseMlModel.py Abstract Methods:

  • fit(self, X: Any, y: Any, **kwargs) -> "BaseMlModel": Trains the model.

  • predict(self, X: Any, **kwargs) -> Any: Generates predictions.

  • score(self, X: Any, y: Any, **kwargs) -> float: Computes a default score.

  • get_metrics(self, X_test: Any, y_test: Any, metrics: List[BaseMetric], **kwargs) -> Dict[str, Any]: Evaluates multiple metrics.

  • get_loss_history_metrics(self, X_test: Any, y_test: Any, metrics: List[BaseMetric], **kwargs) -> Dict[str, Any]: Retrieves metric history during training (optional, but part of the interface).

  • save(self, filepath: str, **kwargs): Saves the model.

  • load(self, filepath: str, **kwargs) -> "BaseMlModel": Loads a model.

  • get_params(self, **kwargs) -> Dict[str, Any]: Returns model hyperparameters.

  • set_params(self, **kwargs) -> None: Sets model hyperparameters.

  • clone(self) -> "BaseMlModel": Creates a shallow copy.

Example: Creating a Simple Custom Majority Class Classifier Model

This example is simplified for illustration purposes. A real custom model would involve more complex machine learning logic.

from astrodata.ml.models.BaseMlModel import BaseMlModel
from astrodata.ml.metrics.BaseMetric import BaseMetric
from astrodata.ml.metrics.SklearnMetric import SklearnMetric
from sklearn.metrics import accuracy_score
from typing import Any, Dict, List, Optional
import collections
import joblib

class MajorityClassClassifier(BaseMlModel):
    def __init__(self, random_state: int = 42):
        super().__init__()
        self.majority_class = None
        self.random_state = random_state  # Not used for randomness in this simple model, but kept for a consistent interface

    def fit(self, X: Any, y: Any, **kwargs) -> "MajorityClassClassifier":
        # Find the most frequent class in y
        counts = collections.Counter(y)
        self.majority_class = counts.most_common(1)[0][0]
        return self

    def predict(self, X: Any, **kwargs) -> Any:
        if self.majority_class is None:
            raise RuntimeError("Model has not been fitted yet.")
        # Predict the majority class for all inputs
        return [self.majority_class] * len(X)

    def score(self, X: Any, y: Any, scorer: Optional[BaseMetric] = None, **kwargs) -> float:
        if self.majority_class is None:
            raise RuntimeError("Model has not been fitted yet.")
        predictions = self.predict(X)
        if scorer is None:
            # Default scorer
            scorer = SklearnMetric(metric=accuracy_score, name="accuracy", greater_is_better=True)
        return scorer(y, predictions)

    def get_metrics(self, X: Any, y: Any, metrics: List[BaseMetric], **kwargs) -> Dict[str, Any]:
        if self.majority_class is None:
            raise RuntimeError("Model has not been fitted yet.")
        predictions = self.predict(X)
        results = {}
        for metric in metrics:
            results[metric.get_name()] = metric(y, predictions)
        return results

    def get_loss_history_metrics(self, X: Any, y: Any, metrics: List[BaseMetric], **kwargs) -> Dict[str, Any]:
        # This simple model does not track a loss history, so we raise an error.
        # In a real model, this would report metric values over epochs/iterations.
        raise AttributeError("MajorityClassClassifier does not support loss history.")

    @property
    def has_loss_history(self) -> bool:
        return False

    def save(self, filepath: str, **kwargs):
        joblib.dump(self.majority_class, filepath)

    def load(self, filepath: str, **kwargs) -> "MajorityClassClassifier":
        self.majority_class = joblib.load(filepath)
        return self

    def get_params(self, **kwargs) -> Dict[str, Any]:
        return {"random_state": self.random_state}

    def set_params(self, **kwargs) -> None:
        if "random_state" in kwargs:
            self.random_state = kwargs["random_state"]

    def clone(self) -> "MajorityClassClassifier":
        new_instance = MajorityClassClassifier(random_state=self.random_state)
        return new_instance

# Example Usage:
X_train = [[1], [2], [3], [4], [5]]
y_train = [0, 0, 1, 0, 1] # Majority class is 0

X_test = [[6], [7]]
y_test = [1, 0]

model = MajorityClassClassifier()
model.fit(X_train, y_train)

predictions = model.predict(X_test)
print(f"Predictions: {predictions}") # Expected: [0, 0]

accuracy = SklearnMetric(metric=accuracy_score)
score = model.score(X_test, y_test, scorer=accuracy)
print(f"Accuracy: {score}") # Expected: 0.5 (one correct, one incorrect)

model.save("majority_classifier.joblib")
loaded_model = MajorityClassClassifier().load("majority_classifier.joblib")
print(f"Loaded model majority class: {loaded_model.majority_class}")

Example 3: Extending astrodata.ml.model_selection.BaseMlModelSelector

The BaseMlModelSelector abstract class provides the foundation for any model selection strategy (e.g., Grid Search, Hyperparameter Optimization). To create a new model selection algorithm, you would inherit from this class and implement its abstract methods.

BaseMlModelSelector.py Abstract Methods:

  • fit(self, X: Any, y: Any, *args, **kwargs) -> "BaseMlModelSelector": Runs the model selection process.

  • get_best_model(self) -> BaseMlModel: Returns the best model found.

  • get_best_params(self) -> Dict[str, Any]: Returns the hyperparameters of the best model.

  • get_best_metrics(self) -> Dict[str, Any]: Returns the evaluation metrics for the best model.

  • get_params(self, **kwargs) -> Dict[str, Any]: Returns the parameters of the selector itself.

Conceptual Example: Implementing a Simple Random Search Selector

Instead of a full code example (which would be quite extensive), let’s outline the conceptual approach for a RandomSearchSelector:

from astrodata.ml.model_selection.BaseMlModelSelector import BaseMlModelSelector
from astrodata.ml.models.BaseMlModel import BaseMlModel
from astrodata.ml.metrics.BaseMetric import BaseMetric
from typing import Any, Dict, List, Optional
import random

class RandomSearchSelector(BaseMlModelSelector):
    def __init__(
        self,
        model: BaseMlModel,
        param_distributions: dict, # Dictionary with parameter names as keys and distributions/lists as values
        scorer: BaseMetric,
        n_iter: int = 10,
        val_size: float = 0.2,
        random_state: int = 42,
        metrics: Optional[List[BaseMetric]] = None,
        tracker: Any = None, # Placeholder for a tracking object
        log_all_models: bool = False,
    ):
        super().__init__()
        self.model = model
        self.param_distributions = param_distributions
        self.scorer = scorer
        self.n_iter = n_iter
        self.val_size = val_size
        self.random_state = random_state
        self.metrics = metrics if metrics is not None else []
        self.tracker = tracker
        self.log_all_models = log_all_models

        self._best_model = None
        self._best_params = None
        self._best_metrics = None
        self._best_score = float('-inf') if scorer.greater_is_better else float('inf')


    def fit(self, X: Any, y: Any, *args, **kwargs) -> "RandomSearchSelector":
        # Split data into training and validation sets
        # X_train_split, X_val_split, y_train_split, y_val_split = train_test_split(...)

        random.seed(self.random_state)

        for i in range(self.n_iter):
            # 1. Sample hyperparameters randomly from param_distributions
            # sampled_params = self._sample_params(self.param_distributions)

            # 2. Clone the base model and set sampled parameters
            # current_model = self.model.clone()
            # current_model.set_params(**sampled_params)

            # 3. Fit the model on training split and evaluate on validation split
            # current_model.fit(X_train_split, y_train_split)
            # current_score = current_model.score(X_val_split, y_val_split, scorer=self.scorer)
            # current_metrics = current_model.get_metrics(X_val_split, y_val_split, metrics=self.metrics)

            # 4. Log results using tracker if available (similar to HyperOptSelector example)
            # if self.tracker:
            #     with self.tracker.start_run(nested=True, tags={"hp_iteration": i}):
            #         self.tracker.log_params(sampled_params)
            #         self.tracker.log_metrics({self.scorer.get_name(): current_score})
            #         self.tracker.log_metrics(current_metrics)
            #         if self.log_all_models:
            #             self.tracker.log_model(current_model, "model")

            # 5. Update best model if current model is better
            # if (self.scorer.greater_is_better and current_score > self._best_score) or \
            #    (not self.scorer.greater_is_better and current_score < self._best_score):
            #     self._best_score = current_score
            #     self._best_params = sampled_params
            #     self._best_metrics = current_metrics
            #     self._best_model = current_model.clone() # Keep a clone of the best model

        # After all iterations, if a best model was found, fit it on the full training data
        # if self._best_model:
        #     self._best_model.fit(X, y)

        return self

    def get_best_model(self) -> BaseMlModel:
        return self._best_model

    def get_best_params(self) -> Dict[str, Any]:
        return self._best_params

    def get_best_metrics(self) -> Dict[str, Any]:
        return self._best_metrics

    def get_params(self, **kwargs) -> Dict[str, Any]:
        return {
            "model": self.model,
            "param_distributions": self.param_distributions,
            "scorer": self.scorer,
            "n_iter": self.n_iter,
            "val_size": self.val_size,
            "random_state": self.random_state,
            "metrics": self.metrics,
            "tracker": self.tracker,
            "log_all_models": self.log_all_models,
        }

    # Helper method for sampling parameters (would be implemented internally)
    def _sample_params(self, param_distributions: dict) -> dict:
        sampled = {}
        for param, dist in param_distributions.items():
            if isinstance(dist, list):
                sampled[param] = random.choice(dist)
            # Add logic for different distribution types (e.g., uniform, normal)
            # For example: if isinstance(dist, tuple) and dist[0] == 'uniform':
            # sampled[param] = random.uniform(dist[1], dist[2])
            else:
                sampled[param] = dist # If it's a fixed value
        return sampled

This conceptual example shows how you would structure a new model selector by implementing the abstract methods and incorporating your specific search logic (in this case, random sampling). The general flow involves iterating through different parameter combinations, training and evaluating models, and keeping track of the best-performing one.
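
To make the outline concrete, the sketch below shows how such a selector might be wired up once the commented steps inside fit() are filled in. It reuses the MajorityClassClassifier and training data from Example 2, and the constructor arguments simply mirror the __init__ signature above, so treat it as an illustration rather than working code:

from sklearn.metrics import accuracy_score

scorer = SklearnMetric(metric=accuracy_score, name="accuracy", greater_is_better=True)

selector = RandomSearchSelector(
    model=MajorityClassClassifier(),
    param_distributions={"random_state": [0, 1, 42]},  # lists are sampled with random.choice
    scorer=scorer,
    n_iter=5,
    val_size=0.2,
    random_state=42,
)

selector.fit(X_train, y_train)          # runs the random search
best_model = selector.get_best_model()  # best candidate, refit on the full training data
print(selector.get_best_params())
print(selector.get_best_metrics())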