astrotaxi/step1_data_import.py
from astrodata.data import AbstractProcessor, DataPipeline, ParquetLoader, RawData


def run_data_import_example(config, tracker):
    # This step demonstrates how to use the DataPipeline with a ParquetLoader and a custom processor.
    # The DataPipeline class orchestrates the loading and processing of data through a series of defined processors.
    # Use it when you want to apply a sequence of transformations that prepare your data for machine learning tasks.
    # We will load a Parquet file, process it with a custom processor, and track the resulting data.

    # Define the loader that will be used to load the data, returning a RawData object.
    loader = ParquetLoader()
    # Define a custom processor to create a target variable and filter the data.
    # The processor must inherit from AbstractProcessor and implement the process method.
    class TargetCreator(AbstractProcessor):
        def process(self, raw: RawData) -> RawData:
            # Compute the trip duration in minutes from the pickup and dropoff timestamps.
            raw.data["duration"] = (
                raw.data.lpep_dropoff_datetime - raw.data.lpep_pickup_datetime
            )
            raw.data["duration"] = raw.data["duration"].apply(
                lambda x: x.total_seconds() / 60
            )
            # Keep only trips with a duration between 1 and 60 minutes and a trip_distance below 50.
            raw.data = raw.data[
                (raw.data["duration"] >= 1) & (raw.data["duration"] <= 60)
            ].reset_index(drop=True)
            raw.data = raw.data[raw.data["trip_distance"] < 50].reset_index(drop=True)
            return raw
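
    # The processors list passed to the pipeline can hold any number of AbstractProcessor
    # implementations, applied in sequence. The DropMissingValues class below is a
    # hypothetical sketch (not part of the original example) of a second processor;
    # it is not added to data_processors and only illustrates the interface.
    class DropMissingValues(AbstractProcessor):
        def process(self, raw: RawData) -> RawData:
            # Drop rows with missing values and reset the index, mirroring the style above.
            raw.data = raw.data.dropna().reset_index(drop=True)
            return raw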

    # Define the list of processors to be used in the pipeline.
    data_processors = [TargetCreator()]

    # Define the data pipeline with the config file, loader, and processors.
    data_pipeline = DataPipeline(
        config_path=config, loader=loader, processors=data_processors
    )

    # Path to the input Parquet file.
    data_path = "./testdata/green_tripdata_2024-01.parquet"

    # Run the data pipeline with the path to the Parquet file.
    processed = data_pipeline.run(data_path)
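
    # Optional sanity check (an assumption, not part of the original script): processed.data
    # is a pandas DataFrame, so standard inspection calls can confirm the engineered
    # "duration" column and the applied filters before the run is tracked.
    print(processed.data[["trip_distance", "duration"]].describe())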

    # The tracker is used to version code and data with Git and DVC.
    # The track method will version everything that is included in the config file, along with astrodata-produced files.
    # Astrodata creates a folder named "astrodata_files" in which it stores generated data and artifacts.
    tracker.track("Data pipeline run, processed data versioned")

    print("Data Pipeline ran successfully!")
    print(f"Processed data shape: {processed.data.shape}")
    return processed