astrodata.tracking package
Submodules
astrodata.tracking.CodeTracker module
- class astrodata.tracking.CodeTracker.CodeTracker(repo_path, ssh_key_path=None, token=None, branch='main')
Bases:
object
A class to manage and track code changes in a Git repository.
This class provides methods for common Git operations such as branch management, remote handling, committing, pushing, and synchronizing with remotes.
- Parameters:
repo_path (str) – Path to the local Git repository.
ssh_key_path (str, optional) – Path to the SSH private key for authentication.
token (str, optional) – Personal access token for HTTPS authentication.
branch (str, optional) – Default branch to use. Defaults to “main”.
- add_remote(name, url)
Add a remote if it does not already exist, then fetch and pull from it if the local repository is empty.
- Parameters:
name (str) – The name of the remote.
url (str) – The URL of the remote repository.
- Returns:
True if the remote was added or already exists.
- Return type:
bool
- add_to_index(paths)
Add paths to the Git index (staging area).
- Parameters:
paths (list) – List of file or directory paths to add.
- Returns:
True if files were added to the index.
- Return type:
bool
- align_with_remote(remote_name='origin')
Prune deleted remote branches and delete local branches whose remote is gone.
- Parameters:
remote_name (str, optional) – The name of the remote to align with. Defaults to “origin”.
- Returns:
None
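One way to find local branches whose remote is gone is to read the upstream tracking state that Git reports after a pruning fetch. The sketch below is not CodeTracker's implementation. It only parses text in the shape printed by `git for-each-ref --format='%(refname:short) %(upstream:track)' refs/heads`, where Git marks a missing upstream as `[gone]`. The function name `branches_with_gone_upstream` and the sample output are hypothetical.

```python
from typing import List


def branches_with_gone_upstream(for_each_ref_output: str) -> List[str]:
    """Return local branch names whose upstream is reported as [gone].

    Expects one branch per line in the form "<branch> <track-state>".
    """
    gone = []
    for line in for_each_ref_output.splitlines():
        # Split "feature/old [gone]" into the branch name and its track state.
        name, _, track = line.strip().partition(" ")
        if name and track.strip() == "[gone]":
            gone.append(name)
    return gone


# Hypothetical output captured after pruning the remote's deleted branches.
sample = "main \nfeature/old [gone]\nfeature/new [ahead 1]\n"
stale = branches_with_gone_upstream(sample)
```

Each name in `stale` would then be a candidate for local deletion. Whether to force-delete unmerged branches is a policy choice this sketch leaves open.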
- checkout(branch_name)
Check out the specified branch.
If the branch does not exist, create it from HEAD.
- Parameters:
branch_name (str) – The name of the branch to check out.
- Returns:
True if checked out to an existing branch, False if a new branch was created.
- Return type:
bool
- create_commit(message)
Create a commit with the given message.
- Parameters:
message (str) – The commit message.
- Returns:
The created commit object, or False if there are no changes to commit.
- Return type:
Commit or bool
- pull(remote_name, branch)
Fetch and pull from the remote into an empty repository.
- Parameters:
remote_name (str) – The name of the remote.
branch (str) – The branch to pull.
- Returns:
True if the pull was successful, None otherwise.
- Return type:
bool or None
- push(remote_name, branch, remote_url=None)
Push the current branch to the specified remote.
- Parameters:
remote_name (str) – The name of the remote.
branch (str) – The branch to push.
remote_url (str, optional) – The URL of the remote, if it needs to be created.
- Returns:
True if push was successful, False otherwise.
- Return type:
bool
- remove_deleted_from_index()
Remove deleted files from the Git index.
- Returns:
True if deleted files were removed from the index.
- Return type:
bool
- astrodata.tracking.CodeTracker.git_operation(operation_name)
Decorator to handle Git operation errors gracefully.
- Parameters:
operation_name (str) – The name of the Git operation for logging and error messages.
- Returns:
A decorator that wraps the function and handles Git errors.
- Return type:
Callable
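A decorator factory of this kind can be sketched as follows. This is an illustration of the general pattern, not the library's code. The names `git_operation_sketch` and `GitError` are hypothetical, and returning False on failure is an assumption about how errors are surfaced.

```python
import functools
import logging

logger = logging.getLogger(__name__)


class GitError(Exception):
    """Hypothetical stand-in for the error type raised by a Git library."""


def git_operation_sketch(operation_name):
    """Build a decorator that logs Git errors instead of propagating them."""

    def decorator(func):
        @functools.wraps(func)  # keep the wrapped function's name and docstring
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except GitError as exc:
                logger.error("Git operation '%s' failed: %s", operation_name, exc)
                return False  # assumed failure value; the real decorator may differ

        return wrapper

    return decorator


@git_operation_sketch("push")
def push(succeed):
    """Toy operation that fails on demand."""
    if not succeed:
        raise GitError("remote rejected the push")
    return True
```

With this wrapper, `push(False)` logs the failure under the name "push" and returns False rather than raising, so callers can branch on the result.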
astrodata.tracking.DataTracker module
- class astrodata.tracking.DataTracker.DataTracker(repo_path, remote)
Bases:
object
A class to manage and track data files using DVC (Data Version Control).
This class provides methods to initialize or open a DVC repository, configure remotes, add files to DVC tracking, and synchronize data with a remote storage.
- Parameters:
repo_path (str) – Path to the local DVC repository.
remote (str) – URL of the DVC remote storage.
- add(path)
Add a file or directory to DVC tracking.
- Parameters:
path (str) – Relative path to the file or directory to add.
- pull()
Pull tracked data from the configured DVC remote storage.
- push()
Push tracked data to the configured DVC remote storage.
astrodata.tracking.MLFlowTracker module
- class astrodata.tracking.MLFlowTracker.MlflowBaseTracker(run_name=None, experiment_name=None, extra_tags=None, tracking_uri=None, tracking_username=None, tracking_password=None)
Bases:
ModelTracker
Base tracker class for MLflow experiment tracking.
Handles MLflow configuration and provides base methods for registering tracked models.
- register_best_model(metric, registered_model_name=None, split_name='train', stage='Production')
Register the best model in MLflow Model Registry based on a metric.
- Parameters:
metric (BaseMetric) – Metric used to select the best run.
model_artifact_path (str, optional) – Path to the model artifact in MLflow run.
registered_model_name (str, optional) – Name for the registered model. Defaults to experiment name.
split_name (str, optional) – Which split’s metric to use (‘train’, ‘val’, or ‘test’).
stage (str, optional) – Model stage to assign (e.g., ‘Production’, ‘Staging’).
- Returns:
The result of the registration.
- Return type:
mlflow.entities.model_registry.RegisteredModelVersion
- Raises:
ValueError – If the experiment or suitable run is not found.
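The selection step can be pictured with plain dictionaries in place of MLflow runs. The sketch below is an assumption-laden illustration, not the tracker's implementation: the `<split>_<metric>` key format, the `greater_is_better` flag, and the function name `select_best_run` are all hypothetical, and filtering on the “is_final” tag is inferred from the description of `wrap_fit`.

```python
def select_best_run(runs, metric_name, split_name="train", greater_is_better=True):
    """Pick the best candidate run for one metric on one split.

    `runs` is a list of dicts with "run_id", "tags", and "metrics" keys
    (a hypothetical stand-in for MLflow run records).
    """
    key = f"{split_name}_{metric_name}"  # assumed metric naming scheme
    candidates = [
        run for run in runs
        if run["tags"].get("is_final") and key in run["metrics"]
    ]
    if not candidates:
        raise ValueError(f"No suitable run found for metric '{key}'.")
    chooser = max if greater_is_better else min
    return chooser(candidates, key=lambda run: run["metrics"][key])


runs = [
    {"run_id": "a", "tags": {"is_final": True}, "metrics": {"val_accuracy": 0.81}},
    {"run_id": "b", "tags": {"is_final": True}, "metrics": {"val_accuracy": 0.86}},
    {"run_id": "c", "tags": {"is_final": False}, "metrics": {"val_accuracy": 0.93}},
]
best = select_best_run(runs, "accuracy", split_name="val")
```

Here run "c" has the highest score but is skipped because it is not tagged as final, so run "b" is selected. Asking for a split that no run reports raises ValueError.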
- wrap_fit(obj)
Placeholder for tracker-specific model wrapping.
To be implemented in subclass.
- Return type:
- class astrodata.tracking.MLFlowTracker.SklearnMLflowTracker(*args, **kwargs)
Bases:
MlflowBaseTracker
Tracker for scikit-learn models with MLflow integration.
Provides run lifecycle, parameter logging, metric logging, and optional model logging.
- wrap_fit(model, X_test=None, y_test=None, X_val=None, y_val=None, metrics=None, log_model=False, tags={}, manual_metrics=None)
Wrap a BaseMlModel’s fit method to perform MLflow logging.
- Parameters:
model (BaseMlModel) – The model to wrap.
X_test (array-like, optional) – Test data for metric logging.
y_test (array-like, optional) – Test labels for metric logging.
X_val (array-like, optional) – Validation data for metric logging.
y_val (array-like, optional) – Validation labels for metric logging.
metrics (list of BaseMetric, optional) – Metrics to log. If omitted, a default loss metric is added.
log_model (bool, optional) – If True, log the fitted model as an MLflow artifact.
tags (Dict[str, Any], optional) – Additional tags to attach to the model. Defaults to {}. By default, the tag “is_final” is set equal to log_model, so any logged model is considered a candidate for production (by register_best_model) unless specified otherwise (e.g., in the model selectors for intermediate steps).
- Returns:
A new instance of the model with an MLflow-logging fit method.
- Return type:
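The idea of returning a new model instance whose fit method also logs can be sketched without MLflow. In the example below a plain list stands in for the MLflow run, and `MeanModel`, `mean_absolute_error`, and `wrap_fit_with_logging` are hypothetical names introduced for illustration. The real tracker's logging calls and metric names are not shown here.

```python
import copy
import types


class MeanModel:
    """Hypothetical toy model that predicts the mean of the training targets."""

    def fit(self, X, y):
        self.mean_ = sum(y) / len(y)
        return self

    def predict(self, X):
        return [self.mean_ for _ in X]


def mean_absolute_error(y_true, y_pred):
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


def wrap_fit_with_logging(model, log, X_test=None, y_test=None):
    """Return a copy of `model` whose fit() appends metric records to `log`."""
    wrapped = copy.copy(model)   # leave the caller's model untouched
    original_fit = wrapped.fit   # bound method of the copy

    def fit(self, X, y):
        result = original_fit(X, y)
        log.append(("train_mae", mean_absolute_error(y, self.predict(X))))
        if X_test is not None and y_test is not None:
            log.append(("test_mae", mean_absolute_error(y_test, self.predict(X_test))))
        return result

    wrapped.fit = types.MethodType(fit, wrapped)
    return wrapped


log = []
model = MeanModel()
tracked = wrap_fit_with_logging(model, log, X_test=[[0], [1]], y_test=[2.0, 4.0])
tracked.fit([[0], [1], [2]], [1.0, 2.0, 3.0])
```

After the call, `log` holds one record for the training split and one for the test split, while the original `model` remains unfitted because only the copy was trained.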
astrodata.tracking.ModelTracker module
- class astrodata.tracking.ModelTracker.ModelTracker
Bases:
ABC
Abstract base class for tracking model fitting processes.
- abstractmethod wrap_fit(obj)
Wrap the fit method of an object to add tracking or logging.
- Parameters:
obj (Any) – The object whose fit method will be wrapped.
- Returns:
The wrapped object.
- Return type:
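The contract this abstract class imposes can be illustrated with a self-contained stand-in. The classes below (`ModelTrackerSketch`, `TimingTracker`, `Dummy`) are hypothetical and do not come from astrodata. They only show that a subclass has to provide `wrap_fit` before it can be instantiated, and that the method is expected to return the wrapped object.

```python
import time
from abc import ABC, abstractmethod


class ModelTrackerSketch(ABC):
    """Stand-in for an abstract tracker with a single required method."""

    @abstractmethod
    def wrap_fit(self, obj):
        """Return obj with tracking added around its fit method."""


class TimingTracker(ModelTrackerSketch):
    """Example subclass that records how long each fit call takes."""

    def __init__(self):
        self.durations = []

    def wrap_fit(self, obj):
        original_fit = obj.fit

        def timed_fit(*args, **kwargs):
            start = time.perf_counter()
            try:
                return original_fit(*args, **kwargs)
            finally:
                self.durations.append(time.perf_counter() - start)

        obj.fit = timed_fit  # this sketch wraps in place rather than copying
        return obj


class Dummy:
    """Hypothetical object exposing a fit method."""

    def fit(self, X, y=None):
        self.fitted_ = True
        return self


tracker = TimingTracker()
model = tracker.wrap_fit(Dummy())
model.fit([1, 2, 3])
```

Trying to instantiate `ModelTrackerSketch` directly raises TypeError, which is how the abstract base forces every concrete tracker to define its own wrapping behavior.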
astrodata.tracking.Tracker module
- class astrodata.tracking.Tracker.Tracker(config_path)
Bases:
object
Orchestrates code and data tracking for a project using Git and DVC.
This class manages both code and data versioning, providing methods to track, commit, and push changes to remote repositories for reproducible research.
- Parameters:
config_path (str) – Path to the configuration file.
- track(commit_message=None)
Orchestrate the tracking of data and code, pushing data and committing code.
This method aligns the code repository with the remote, tracks data and code changes, pushes data to the DVC remote, and commits and pushes code changes to the Git remote.
- Parameters:
commit_message (str, optional) – Commit message for the code changes.
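The order of operations described for `track` can be sketched with recording stubs in place of real Git and DVC objects. Everything here is illustrative: `RecordingStub` and `track_sketch` are hypothetical, the exact sequence of calls inside the real method is an assumption based on the description, and the remote name, branch, paths, and fallback commit message are placeholders.

```python
class RecordingStub:
    """Hypothetical stand-in that records method calls instead of running Git or DVC."""

    def __init__(self, label, calls):
        self._label = label
        self._calls = calls

    def __getattr__(self, method_name):
        # Called only for attributes that are not set in __init__.
        def record(*args, **kwargs):
            self._calls.append(f"{self._label}.{method_name}")
            return True

        return record


def track_sketch(code, data, data_paths, code_paths, commit_message=None):
    """Run the tracking steps in the order the description lists them."""
    code.align_with_remote("origin")      # 1. align code repo with the remote
    for path in data_paths:               # 2. track data changes
        data.add(path)
    code.add_to_index(code_paths)         #    and code changes
    code.remove_deleted_from_index()
    data.push()                           # 3. push data to the DVC remote
    message = commit_message or "Update tracked code and data"  # placeholder default
    code.create_commit(message)           # 4. commit and push code
    code.push("origin", "main")


calls = []
track_sketch(
    RecordingStub("code", calls),
    RecordingStub("data", calls),
    data_paths=["data/raw"],
    code_paths=["src"],
)
```

Inspecting `calls` afterwards shows the sequence of code and data operations, which makes the ordering easy to check without touching a real repository.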