metrics
The astrodata.ml.metrics module provides a standardized interface for defining and using evaluation metrics for machine learning models. It features an abstract base class, BaseMetric, and an adapter for scikit-learn metrics, SklearnMetric.
Abstract Class
BaseMetric is the abstract base class that all custom metrics should inherit from. A metric must implement:
__init__(): Initializes the metric.
__call__(y_true, y_pred, **kwargs): Computes the metric value given true and predicted labels.
get_name(): Returns the name of the metric.
greater_is_better: A property indicating whether a higher value of the metric is desirable.
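The contract listed above can be sketched with Python's abc module. This is only an illustration of the interface as described on this page, not the library's actual source: the names MetricInterfaceSketch and MeanAbsoluteError are hypothetical and were chosen for this example.

```python
from abc import ABC, abstractmethod
from typing import Any


class MetricInterfaceSketch(ABC):
    """Hypothetical stand-in mirroring the four members listed above."""

    @abstractmethod
    def __init__(self) -> None:
        """Initialize the metric."""

    @abstractmethod
    def __call__(self, y_true: Any, y_pred: Any, **kwargs) -> float:
        """Compute the metric value from true and predicted labels."""

    @abstractmethod
    def get_name(self) -> str:
        """Return the name of the metric."""

    @property
    @abstractmethod
    def greater_is_better(self) -> bool:
        """Whether a higher value of the metric is desirable."""


class MeanAbsoluteError(MetricInterfaceSketch):
    """Example subclass: an error metric, so lower values are better."""

    def __init__(self) -> None:
        pass

    def __call__(self, y_true: Any, y_pred: Any, **kwargs) -> float:
        return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)

    def get_name(self) -> str:
        return "mean_absolute_error"

    @property
    def greater_is_better(self) -> bool:
        return False


mae = MeanAbsoluteError()
mae_value = mae([1.0, 2.0, 3.0], [1.0, 2.5, 2.0])
```

Because every member is abstract, a subclass that omits one of them cannot be instantiated, which catches incomplete metrics early.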
How to Use
Calling a metric
A metric that implements the BaseMetric interface can always be called directly:
from astrodata.ml.metrics.SklearnMetric import SklearnMetric
from sklearn.metrics import accuracy_score
# Example labels; replace with your own data
y_true = [0, 1, 0, 1]
y_pred = [0, 0, 0, 1]
accuracy_metric = SklearnMetric(accuracy_score)
accuracy_computed = accuracy_metric(y_true, y_pred)
The result of this operation is the computed metric on the two provided arrays.
Attention
y_true and y_pred must have the same length, since each prediction is compared with the true label at the same position.
Tip
Some scikit-learn metrics, especially for classification tasks, require the predicted probability of a class (for example roc_auc_score and log_loss) rather than the predicted label. Be sure to read the related documentation before using them!
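To illustrate the tip, here is a hand-written binary log loss in plain Python, standing in for probability-based scikit-learn metrics. The function name binary_log_loss and the sample values are hypothetical. The sketch computes the same loss twice, once from predicted probabilities and once from hard labels derived from them.

```python
import math


def binary_log_loss(y_true, y_prob, eps=1e-15):
    """Mean negative log-likelihood; y_prob holds P(class == 1) per sample."""
    total = 0.0
    for t, p in zip(y_true, y_prob):
        p = min(max(p, eps), 1 - eps)  # clip to avoid log(0)
        total += -(t * math.log(p) + (1 - t) * math.log(1 - p))
    return total / len(y_true)


y_true = [0, 1, 1, 0]
y_prob = [0.1, 0.8, 0.4, 0.3]                     # predicted probabilities of class 1
y_label = [1 if p >= 0.5 else 0 for p in y_prob]  # hard labels derived from them

loss_from_probabilities = binary_log_loss(y_true, y_prob)
loss_from_labels = binary_log_loss(y_true, y_label)
```

With hard labels, the single misclassified sample is treated as a fully confident mistake, so the loss is dominated by the clipping constant rather than by the model's actual confidence. Passing probabilities keeps that information.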
Creating a Custom Metric (inheriting from BaseMetric)
A custom metric would typically look like this:
from astrodata.ml.metrics.BaseMetric import BaseMetric
from typing import Any

class MyCustomAccuracy(BaseMetric):
    def __init__(self):
        pass

    def __call__(self, y_true: Any, y_pred: Any, **kwargs) -> float:
        # Implement your custom accuracy calculation here
        correct_predictions = sum(1 for true, pred in zip(y_true, y_pred) if true == pred)
        return correct_predictions / len(y_true)

    def get_name(self) -> str:
        return "MyCustomAccuracy"

    @property
    def greater_is_better(self) -> bool:
        return True
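One way the greater_is_better property can be consumed is when choosing between candidate scores. The helper below is a hypothetical sketch (pick_best is not part of astrodata), and it uses small stand-in objects instead of real BaseMetric subclasses so that it runs without the library installed.

```python
from types import SimpleNamespace


def pick_best(metric, scores):
    """Return the (label, score) pair the metric considers best.

    `metric` only needs a boolean `greater_is_better` attribute.
    """
    chooser = max if metric.greater_is_better else min
    return chooser(scores.items(), key=lambda item: item[1])


# Stand-ins exposing just the attribute this helper reads
accuracy_like = SimpleNamespace(greater_is_better=True)
error_like = SimpleNamespace(greater_is_better=False)

accuracy_scores = {"model_a": 0.81, "model_b": 0.88}
error_scores = {"model_a": 0.42, "model_b": 0.57}

best_by_accuracy = pick_best(accuracy_like, accuracy_scores)
best_by_error = pick_best(error_like, error_scores)
```

Declaring the direction on the metric itself means selection code does not need to special-case accuracy-like versus error-like metrics.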
SklearnMetric
This class allows you to easily wrap existing scikit-learn metric functions:
from sklearn.metrics import accuracy_score, f1_score
from astrodata.ml.metrics.SklearnMetric import SklearnMetric
# Wrap accuracy_score
accuracy = SklearnMetric(metric=accuracy_score, name="accuracy", greater_is_better=True)
# Wrap f1_score with specific parameters
f1_macro = SklearnMetric(metric=f1_score, name="f1_macro", greater_is_better=True, average='macro')
# Example usage
y_true = [0, 1, 0, 1]
y_pred = [0, 0, 0, 1]
print(f"Accuracy: {accuracy(y_true, y_pred)}")
print(f"F1 Macro: {f1_macro(y_true, y_pred)}")
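The wrapping pattern can be sketched in a few lines. This is not SklearnMetric's actual implementation: FunctionMetricSketch is a hypothetical adapter that stores a callable together with fixed keyword arguments and forwards them on each call, which is one plausible reading of how average='macro' is handled above. A plain Python function, error_rate, stands in for a scikit-learn metric so the example has no external dependencies.

```python
from typing import Any, Callable, Optional


class FunctionMetricSketch:
    """Hypothetical adapter: binds a metric function to fixed keyword arguments."""

    def __init__(self, metric: Callable[..., float], name: Optional[str] = None,
                 greater_is_better: bool = True, **metric_params: Any) -> None:
        self._metric = metric
        self._name = name or metric.__name__
        self._greater_is_better = greater_is_better
        self._metric_params = metric_params

    def __call__(self, y_true: Any, y_pred: Any, **kwargs: Any) -> float:
        # Call-time kwargs override the ones fixed at construction
        return self._metric(y_true, y_pred, **{**self._metric_params, **kwargs})

    def get_name(self) -> str:
        return self._name

    @property
    def greater_is_better(self) -> bool:
        return self._greater_is_better


def error_rate(y_true, y_pred, normalize=True):
    """Stand-in for a library metric: fraction (or count) of mismatches."""
    wrong = sum(1 for t, p in zip(y_true, y_pred) if t != p)
    return wrong / len(y_true) if normalize else wrong


error_count = FunctionMetricSketch(error_rate, name="error_count",
                                   greater_is_better=False, normalize=False)
result = error_count([0, 1, 0, 1], [0, 0, 0, 1])
```

Fixing parameters at construction time lets the same underlying function be registered several times under different names, each with its own configuration.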