astrodata.tracking package

Submodules

astrodata.tracking.CodeTracker module

class astrodata.tracking.CodeTracker.CodeTracker(repo_path, ssh_key_path=None, token=None, branch='main')

Bases: object

A class to manage and track code changes in a Git repository.

This class provides methods for common Git operations such as branch management, remote handling, committing, pushing, and synchronizing with remotes.

Parameters:
  • repo_path (str) – Path to the local Git repository.

  • ssh_key_path (str, optional) – Path to the SSH private key for authentication.

  • token (str, optional) – Personal access token for HTTPS authentication.

  • branch (str, optional) – Default branch to use. Defaults to “main”.

add_remote(name, url)

Add a remote if it doesn’t already exist, and fetch/pull if the repo is empty.

Parameters:
  • name (str) – The name of the remote.

  • url (str) – The URL of the remote repository.

Returns:

True if the remote was added or already exists.

Return type:

bool

add_to_index(paths)

Add paths to the Git index (staging area).

Parameters:

paths (list) – List of file or directory paths to add.

Returns:

True if files were added to the index.

Return type:

bool

align_with_remote(remote_name='origin')

Prune deleted remote branches and delete local branches whose remote is gone.

Parameters:

remote_name (str, optional) – The name of the remote to align with. Defaults to “origin”.

Returns:

None

checkout(branch_name)

Check out the specified branch.

If the branch does not exist, create it from HEAD.

Parameters:

branch_name (str) – The name of the branch to check out.

Returns:

True if checked out to an existing branch, False if a new branch was created.

Return type:

bool

create_commit(message)

Create a commit with the given message.

Parameters:

message (str) – The commit message.

Returns:

The created commit object, or False if there are no changes to commit.

Return type:

Commit or bool

pull(remote_name, branch)

Fetch and pull from the remote; used when the local repository is empty.

Parameters:
  • remote_name (str) – The name of the remote.

  • branch (str) – The branch to pull.

Returns:

True if the pull was successful, None otherwise.

Return type:

bool or None

push(remote_name, branch, remote_url=None)

Push the current branch to the specified remote.

Parameters:
  • remote_name (str) – The name of the remote.

  • branch (str) – The branch to push.

  • remote_url (str, optional) – The URL of the remote, if it needs to be created.

Returns:

True if the push was successful, False otherwise.

Return type:

bool

remove_deleted_from_index()

Remove deleted files from the Git index.

Returns:

True if deleted files were removed from the index.

Return type:

bool

astrodata.tracking.CodeTracker.git_operation(operation_name)

Decorator to handle Git operation errors gracefully.

Parameters:

operation_name (str) – The name of the Git operation for logging and error messages.

Returns:

A decorator that wraps the function and handles Git errors.

Return type:

Callable
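The decorator-factory pattern described above can be sketched as follows. This is a minimal illustration, not the package's implementation: the GitError class is a hypothetical stand-in for whatever exception the underlying Git library raises, and returning False on failure is an assumption, since the source does not state what the wrapped function returns when an error is caught.

```python
import functools
import logging

logger = logging.getLogger("git_ops")


class GitError(Exception):
    """Hypothetical stand-in for the Git library's error type."""


def git_operation(operation_name):
    """Return a decorator that logs Git failures instead of raising them."""
    def decorator(func):
        @functools.wraps(func)  # keep the wrapped function's name and docstring
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except GitError as exc:
                # Assumption: a failed operation is reported as False.
                logger.error("Git %s failed: %s", operation_name, exc)
                return False
        return wrapper
    return decorator


@git_operation("push")
def push(succeed):
    """Toy operation that fails on demand."""
    if not succeed:
        raise GitError("remote rejected the push")
    return True
```

Because the error is caught inside the wrapper, callers can branch on the return value rather than wrapping every Git call in their own try/except.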

astrodata.tracking.DataTracker module

class astrodata.tracking.DataTracker.DataTracker(repo_path, remote)

Bases: object

A class to manage and track data files using DVC (Data Version Control).

This class provides methods to initialize or open a DVC repository, configure remotes, add files to DVC tracking, and synchronize data with a remote storage.

Parameters:
  • repo_path (str) – Path to the local DVC repository.

  • remote (str) – URL of the DVC remote storage.

add(path)

Add a file or directory to DVC tracking.

Parameters:

path (str) – Relative path to the file or directory to add.

pull()

Pull tracked data from the configured DVC remote storage.

push()

Push tracked data to the configured DVC remote storage.

astrodata.tracking.MLFlowTracker module

class astrodata.tracking.MLFlowTracker.MlflowBaseTracker(run_name=None, experiment_name=None, extra_tags=None, tracking_uri=None, tracking_username=None, tracking_password=None)

Bases: ModelTracker

Base tracker class for MLflow experiment tracking.

Handles MLflow configuration and provides base methods for registering tracked models.

register_best_model(metric, registered_model_name=None, split_name='train', stage='Production')

Register the best model in MLflow Model Registry based on a metric.

Parameters:
  • metric (BaseMetric) – Metric used to select the best run.

  • model_artifact_path (str, optional) – Path to the model artifact in MLflow run.

  • registered_model_name (str, optional) – Name for the registered model. Defaults to experiment name.

  • split_name (str, optional) – Which split’s metric to use (‘train’, ‘val’, or ‘test’). Defaults to ‘train’.

  • stage (str, optional) – Model stage to assign (e.g., ‘Production’, ‘Staging’). Defaults to ‘Production’.

Returns:

The result of the registration.

Return type:

mlflow.entities.model_registry.RegisteredModelVersion

Raises:

ValueError – If the experiment or suitable run is not found.
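The selection step behind register_best_model can be sketched without MLflow. The example below is an assumption-laden illustration: runs are plain dictionaries, the metric key format "<split_name>_<metric_name>" is hypothetical, and the greater_is_better flag is invented here because the source does not say how the metric's direction is determined. The only behaviour taken from the documentation is that runs tagged “is_final” are the candidates and that a ValueError is raised when no suitable run exists.

```python
def select_best_run(runs, metric_name, split_name="train", greater_is_better=True):
    """Pick the candidate run with the best value of the chosen metric."""
    key = f"{split_name}_{metric_name}"  # hypothetical naming scheme
    candidates = [
        run for run in runs
        if run["tags"].get("is_final") and key in run["metrics"]
    ]
    if not candidates:
        raise ValueError(f"No suitable run found for metric {key!r}")
    pick = max if greater_is_better else min
    return pick(candidates, key=lambda run: run["metrics"][key])


runs = [
    {"run_id": "a", "tags": {"is_final": True}, "metrics": {"val_accuracy": 0.81}},
    {"run_id": "b", "tags": {"is_final": True}, "metrics": {"val_accuracy": 0.88}},
    # Higher score, but not tagged as final, so it is not a candidate.
    {"run_id": "c", "tags": {"is_final": False}, "metrics": {"val_accuracy": 0.95}},
]
best = select_best_run(runs, "accuracy", split_name="val")
```

Here run "b" is selected: run "c" has a higher score but is excluded because it is not tagged as final.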

wrap_fit(obj)

Placeholder for tracker-specific model wrapping.

To be implemented in subclass.

Return type:

BaseMlModel

class astrodata.tracking.MLFlowTracker.SklearnMLflowTracker(*args, **kwargs)

Bases: MlflowBaseTracker

Tracker for scikit-learn models with MLflow integration.

Provides run lifecycle, parameter logging, metric logging, and optional model logging.

wrap_fit(model, X_test=None, y_test=None, X_val=None, y_val=None, metrics=None, log_model=False, tags={}, manual_metrics=None)

Wrap a BaseMlModel’s fit method to perform MLflow logging.

Parameters:
  • model (BaseMlModel) – The model to wrap.

  • X_test (array-like, optional) – Test data for metric logging.

  • y_test (array-like, optional) – Test labels for metric logging.

  • X_val (array-like, optional) – Validation data for metric logging.

  • y_val (array-like, optional) – Validation labels for metric logging.

  • metrics (list of BaseMetric, optional) – Metrics to log. If missing, a default loss metric is added.

  • log_model (bool, optional) – If True, log the fitted model as an MLflow artifact.

  • tags (Dict[str, Any], default {}) – Additional tags to add to the model. By default, the tag “is_final” is set equal to log_model, so any logged model is considered a candidate for production (by register_best_model) unless specified otherwise (e.g. by the model selectors for intermediate steps).

Returns:

A new instance of the model with an MLflow-logging fit method.

Return type:

BaseMlModel
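The two behaviours documented above, the “is_final” default and the return of a new instance, can be sketched without MLflow. Everything in this example is a stand-in: TinyModel is a hypothetical model class, the log list replaces MLflow logging, and this simplified wrap_fit is not the tracker's actual method.

```python
import copy


class TinyModel:
    """Hypothetical stand-in for a BaseMlModel."""
    def fit(self, X, y):
        self.mean_ = sum(y) / len(y)
        return self


def wrap_fit(model, log_model=False, tags=None, log=None):
    """Return a copy of `model` whose fit also records its tags in `log`."""
    # "is_final" defaults to log_model; caller-supplied tags take precedence.
    resolved_tags = {"is_final": log_model, **(tags or {})}
    log = log if log is not None else []
    wrapped = copy.copy(model)   # the original model is left untouched
    original_fit = wrapped.fit   # bound to the copy, not the original

    def fit(X, y):
        result = original_fit(X, y)
        log.append(resolved_tags)  # stand-in for MLflow logging
        return result

    wrapped.fit = fit
    return wrapped


events = []
model = TinyModel()
tracked = wrap_fit(model, log_model=True, tags={"stage": "search"}, log=events)
tracked.fit([[1], [2]], [1.0, 3.0])
```

The sketch uses tags=None rather than a literal {} as the default, a common Python idiom that avoids sharing one mutable dictionary across calls.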

astrodata.tracking.ModelTracker module

class astrodata.tracking.ModelTracker.ModelTracker

Bases: ABC

Abstract base class for tracking model fitting processes.

abstractmethod wrap_fit(obj)

Wrap the fit method of an object to add tracking or logging.

Parameters:

obj (Any) – The object whose fit method will be wrapped.

Returns:

The wrapped object.

Return type:

BaseMlModel
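A subclass satisfying this interface might look like the sketch below. The ModelTracker class here is a local stand-in that mirrors the documented abstract method, and TimingTracker and Dummy are hypothetical names introduced only for illustration. Unlike SklearnMLflowTracker, which is documented as returning a new instance, this toy tracker modifies the object in place.

```python
from abc import ABC, abstractmethod
import time


class ModelTracker(ABC):
    """Local stand-in mirroring the documented interface."""
    @abstractmethod
    def wrap_fit(self, obj):
        ...


class TimingTracker(ModelTracker):
    """Hypothetical tracker that records how long each fit call takes."""
    def __init__(self):
        self.durations = []

    def wrap_fit(self, obj):
        original_fit = obj.fit

        def timed_fit(*args, **kwargs):
            start = time.perf_counter()
            result = original_fit(*args, **kwargs)
            self.durations.append(time.perf_counter() - start)
            return result

        obj.fit = timed_fit  # replace fit on this instance only
        return obj


class Dummy:
    """Hypothetical model with a trivial fit method."""
    def fit(self, X, y):
        self.fitted_ = True
        return self


tracker = TimingTracker()
model = tracker.wrap_fit(Dummy())
model.fit([[0]], [0])
```

Because wrap_fit is abstract, instantiating ModelTracker directly raises TypeError; only subclasses that implement it can be created.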

astrodata.tracking.Tracker module

class astrodata.tracking.Tracker.Tracker(config_path)

Bases: object

Orchestrates code and data tracking for a project using Git and DVC.

This class manages both code and data versioning, providing methods to track, commit, and push changes to remote repositories for reproducible research.

Parameters:

config_path (str) – Path to the configuration file.

track(commit_message=None)

Orchestrate the tracking of data and code, pushing data and committing code.

This method aligns the code repository with the remote, tracks data and code changes, pushes data to the DVC remote, and commits and pushes code changes to the Git remote.

Parameters:

commit_message (str, optional) – Commit message for the code changes.
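The sequence described for track can be sketched with recording stand-ins in place of real Git and DVC objects. The method names come from the CodeTracker and DataTracker documentation, but which of them track actually calls, their exact order, the data_paths and code_paths arguments, and the fallback commit message are all assumptions made for illustration.

```python
class Recorder:
    """Hypothetical stand-in that records method calls instead of running Git or DVC."""
    def __init__(self, name, calls):
        self._name = name
        self._calls = calls

    def __getattr__(self, method):
        # Called only for attributes that do not exist, i.e. the tracker methods.
        def call(*args, **kwargs):
            self._calls.append(f"{self._name}.{method}")
            return True
        return call


def track(code, data, data_paths, code_paths, commit_message=None,
          remote_name="origin", branch="main"):
    """Align with the remote, track and push data, then commit and push code."""
    code.align_with_remote(remote_name)
    for path in data_paths:
        data.add(path)
    data.push()
    code.add_to_index(code_paths)
    code.remove_deleted_from_index()
    # Assumption: a generic message is used when none is given.
    code.create_commit(commit_message or "Update tracked files")
    code.push(remote_name, branch)


calls = []
track(Recorder("code", calls), Recorder("data", calls), ["data/raw"], ["src"])
```

After the call, calls lists the steps in execution order, with the data push preceding the code commit, as in the description above.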

Module contents