ml/3_2_gridsearch_parallel_comparison.pyΒΆ

"""Example 3.2: Compare serial and parallel grid search selectors."""

import argparse
import math
import time
from typing import Dict, Tuple

import numpy as np
import pandas as pd
from sklearn.datasets import load_breast_cancer, make_classification
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.svm import LinearSVC

from astrodata.ml.metrics import SklearnMetric
from astrodata.ml.model_selection import GridSearchSelector
from astrodata.ml.model_selection.GridSearchSelector_parallel import (
    GridSearchSelectorParallel,
)
from astrodata.ml.models import SklearnModel


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description=(
            "Compare the standard and parallel GridSearch selectors using the same "
            "random seed and an adjustable training set size."
        )
    )
    parser.add_argument(
        "--train-size",
        type=float,
        default=0.75,
        help="Fraction of the dataset reserved for training (0 < train_size < 1).",
    )
    parser.add_argument(
        "--random-state",
        type=int,
        default=42,
        help="Random seed used for reproducible data splits and selectors.",
    )
    parser.add_argument(
        "--n-jobs",
        type=int,
        default=None,
        help=(
            "Number of parallel jobs for the parallel selector. "
            "Use a positive integer, or <= 0 to use (CPU count + n_jobs)."
        ),
    )
    parser.add_argument(
        "--val-size",
        type=float,
        default=0.2,
        help=(
            "Fraction of the training data used as validation inside the selectors. "
            "Must be in (0, 1)."
        ),
    )
    parser.add_argument(
        "--toy-size",
        type=int,
        default=None,
        help=(
            "If set, generate a synthetic classification dataset with the given number "
            "of samples instead of using the breast cancer dataset."
        ),
    )
    parser.add_argument(
        "--toy-features",
        type=int,
        default=30,
        help=(
            "Number of features for the synthetic dataset (only used with --toy-size). "
            "Must be >= 4 to ensure informative and redundant features can be created."
        ),
    )
    parser.add_argument(
        "--toy-classes",
        type=int,
        default=2,
        help="Number of target classes for the synthetic dataset (only used with --toy-size).",
    )

    args = parser.parse_args()
    if not 0 < args.train_size < 1:
        raise ValueError("train_size must be between 0 and 1 (exclusive).")
    if not 0 < args.val_size < 1:
        raise ValueError("val_size must be between 0 and 1 (exclusive).")
    if args.toy_size is not None and args.toy_size <= 0:
        raise ValueError("toy_size must be a positive integer when provided.")
    if args.toy_size is not None and args.toy_features < 4:
        raise ValueError(
            "toy_features must be at least 4 when generating synthetic data."
        )
    if args.toy_size is not None and args.toy_classes < 2:
        raise ValueError(
            "toy_classes must be at least 2 when generating synthetic data."
        )
    return args


def run_selector(selector, X_train, y_train, X_test, y_test) -> Dict[str, float]:
    start = time.perf_counter()
    selector.fit(X_train, y_train, X_test=X_test, y_test=y_test)
    duration = time.perf_counter() - start
    metrics = selector.get_best_metrics() or {}
    return {
        "selector": selector,
        "best_params": selector.get_best_params(),
        "metrics": metrics,
        "duration": duration,
    }


def prepare_dataset(args: argparse.Namespace) -> Tuple[pd.DataFrame, pd.Series, str]:
    if args.toy_size is None:
        dataset = load_breast_cancer()
        X = pd.DataFrame(dataset.data, columns=dataset.feature_names)
        y = pd.Series(dataset.target)
        return X, y, "breast_cancer"

    informative = max(2, int(math.ceil(0.6 * args.toy_features)))
    informative = min(informative, args.toy_features - 2)
    redundant = max(0, args.toy_features - informative)

    X, y = make_classification(
        n_samples=args.toy_size,
        n_features=args.toy_features,
        n_informative=informative,
        n_redundant=redundant,
        n_repeated=0,
        n_classes=args.toy_classes,
        random_state=args.random_state,
    )

    feature_names = [f"feature_{i}" for i in range(args.toy_features)]
    X_df = pd.DataFrame(X, columns=feature_names)
    y_series = pd.Series(y, name="target")
    return X_df, y_series, f"toy(size={args.toy_size}, features={args.toy_features})"


def main() -> None:
    args = parse_args()

    np.random.seed(args.random_state)

    X, y, dataset_label = prepare_dataset(args)

    X_train, X_test, y_train, y_test = train_test_split(
        X,
        y,
        train_size=args.train_size,
        random_state=args.random_state,
        stratify=y,
    )

    model = SklearnModel(model_class=LinearSVC, penalty="l2", loss="squared_hinge")
    accuracy = SklearnMetric(accuracy_score, greater_is_better=True)
    param_grid = {
        "C": [0.1, 1, 10],
        "max_iter": [1000, 2000],
        "tol": [1e-3, 1e-4],
    }

    serial_selector = GridSearchSelector(
        model=model,
        param_grid=param_grid,
        scorer=accuracy,
        val_size=args.val_size,
        random_state=args.random_state,
        metrics=[accuracy],
    )

    parallel_selector = GridSearchSelectorParallel(
        model=model,
        param_grid=param_grid,
        n_jobs=args.n_jobs,
        scorer=accuracy,
        val_size=args.val_size,
        random_state=args.random_state,
        metrics=[accuracy],
    )

    serial_result = run_selector(serial_selector, X_train, y_train, X_test, y_test)
    parallel_result = run_selector(parallel_selector, X_train, y_train, X_test, y_test)

    durations = {
        "serial": serial_result["duration"],
        "parallel": parallel_result["duration"],
    }
    print("\n=== Comparison Summary ===")
    print(f"Dataset: {dataset_label}")
    print(f"Random state: {args.random_state}")
    print(f"Training size fraction: {args.train_size}")
    print(f"Validation size fraction: {args.val_size}")
    print(f"Parallel selector n_jobs: {parallel_selector.n_jobs}")
    print("")

    print("Best parameters (serial):", serial_result["best_params"])
    print("Best parameters (parallel):", parallel_result["best_params"])
    print("")

    for label, metrics in ("serial", serial_result["metrics"]), (
        "parallel",
        parallel_result["metrics"],
    ):
        print(f"Validation metrics ({label}):")
        if not metrics:
            print("  No metrics recorded.")
        for metric_name, value in metrics.items():
            print(f"  {metric_name}: {value:.4f}")
        print("")

    params_match = serial_result["best_params"] == parallel_result["best_params"]
    metrics_match = np.allclose(  # type: ignore[arg-type]
        list(serial_result["metrics"].values()),
        list(parallel_result["metrics"].values()),
        atol=1e-8,
    )

    print(f"Best parameters identical: {params_match}")
    print(f"Validation metrics identical (within tolerance): {metrics_match}")
    print("")
    print("Fit duration (seconds):")
    for label, duration in durations.items():
        print(f"  {label}: {duration:.3f}s")

    if not params_match or not metrics_match:
        raise AssertionError(
            "Serial and parallel grid searches did not converge to the same result."
        )


if __name__ == "__main__":
    main()