From 89055db5d66c0167922a67c09770239cb863835e Mon Sep 17 00:00:00 2001 From: arlobelshee Date: Wed, 1 Feb 2023 11:07:22 -0500 Subject: [PATCH] F - sketch in how I want to handle defining error handlers. The builders currently don't do anything with those yet. --- src/datapipeline/clientapi.py | 26 +++++++++++++- src/datapipeline/pipeline.py | 65 +++++++++++++++++++++++++++++++---- src/examplecode/product.py | 40 ++++++++++++++++++--- 3 files changed, 119 insertions(+), 12 deletions(-) diff --git a/src/datapipeline/clientapi.py b/src/datapipeline/clientapi.py index 463602b..30159e7 100644 --- a/src/datapipeline/clientapi.py +++ b/src/datapipeline/clientapi.py @@ -1,10 +1,12 @@ from __future__ import annotations +import enum from abc import abstractmethod -from typing import runtime_checkable, Protocol, TypeVar +from typing import runtime_checkable, Protocol, TypeVar, Callable TIn = TypeVar('TIn') TOut = TypeVar('TOut') +SomeException = TypeVar('Exc', bound=Exception) @runtime_checkable @@ -62,3 +64,25 @@ class StoreImpl(NamedStep, Protocol[TIn]): @abstractmethod async def __call__(self, data: TIn) -> None: pass + + +AfterError = enum.Enum('AfterError', ['Abort', 'Retry', 'Skip']) + + +@runtime_checkable +class ErrorResponseSync(Protocol[TIn, SomeException]): + @abstractmethod + def __call__(self, segment: '_PipeSegment', original_data: TIn, modified_data: TIn, + exception: SomeException) -> AfterError: + pass + + +@runtime_checkable +class ErrorResponseAsync(Protocol[TIn, SomeException]): + @abstractmethod + async def __call__(self, segment: '_PipeSegment', original_data: TIn, modified_data: TIn, + exception: SomeException) -> AfterError: + pass + + +ErrorResponse = ErrorResponseSync[TIn, SomeException] | ErrorResponseAsync[TIn, SomeException] diff --git a/src/datapipeline/pipeline.py b/src/datapipeline/pipeline.py index 7d064f5..f9ed654 100644 --- a/src/datapipeline/pipeline.py +++ b/src/datapipeline/pipeline.py @@ -1,12 +1,13 @@ from __future__ import annotations 
import asyncio +import datetime from typing import Callable, TypeVar, Awaitable, Generic, List, Iterable -import assertpy -from assertpy import soft_assertions, assert_that, soft_fail +from assertpy import soft_assertions from datapipeline import DataProcessingSegment, RestructuringSegment, PipeHeadSegment, clientapi +from datapipeline.clientapi import AfterError from datapipeline.segmentimpl import _PipeSegment, SourceSegment, TransformSegment, SinkSegment T = TypeVar("T") @@ -45,21 +46,45 @@ def build(next_segment: _PipeSegment[T, TNext] | None) -> PipeHeadSegment[T]: return IncompletePipeline[T, T]([build]) -def source(load: Callable[[T], Awaitable[TRaw]], parse: Callable[[T, TRaw], None]) -> SegmentBuilder[T]: +SomeException = TypeVar('Exc', bound=Exception) + + +class ErrorHandler(Generic[T, SomeException]): + exception: type[SomeException] + handler: clientapi.ErrorResponse[T, SomeException] + + def __init__(self, exception: type[SomeException], handler: clientapi.ErrorResponse[T, SomeException]) -> None: + self.exception = exception + self.handler = handler + + def __call__(self, next_segment: _PipeSegment[T, TNext] | None) -> SourceSegment[T, TRaw]: + return next_segment + + +def on_error(exception: type[SomeException], + handler: clientapi.ErrorResponse[T, SomeException]) -> ErrorHandler[T, SomeException]: + return ErrorHandler(exception, handler) + + +def source(load: Callable[[T], Awaitable[TRaw]], + parse: Callable[[T, TRaw], None], + *error_handlers: ErrorHandler) -> SegmentBuilder[T]: def build(next_segment: _PipeSegment[T, TNext] | None) -> SourceSegment[T, TRaw]: return SourceSegment(load, parse, next_segment) return build -def transform(process: Callable[[T], None]) -> SegmentBuilder[T]: +def transform(process: Callable[[T], None], *error_handlers: ErrorHandler) -> SegmentBuilder[T]: def build(next_segment: _PipeSegment[T, TNext] | None) -> TransformSegment[T]: return TransformSegment(process, next_segment) return build -def sink(extract: 
Callable[[TSrc], TDest], store: Callable[[TDest], Awaitable[None]]) -> SegmentBuilder[TSrc]: +def sink(extract: Callable[[TSrc], TDest], + store: Callable[[TDest], Awaitable[None]], + *error_handlers: ErrorHandler) -> SegmentBuilder[TSrc]: def build(next_segment: _PipeSegment[T, TNext] | None) -> SinkSegment[T, TRaw]: return SinkSegment(extract, store, next_segment) @@ -70,7 +95,8 @@ class IncompletePipeline(Generic[TSrc, T]): def __init__(self, prior_steps: List[AnyBuilder]): self._prior_steps = prior_steps - def then(self, *steps: SegmentBuilder[T]) -> PotentiallyCompletePipeline[TSrc, T]: + def then(self, *steps: SegmentBuilder[T] | ErrorHandler) -> PotentiallyCompletePipeline[TSrc, T]: + # Handle error handlers. Require them to only be at the end; they apply to the whole chain. return PotentiallyCompletePipeline(self._prior_steps + list(steps)) @@ -111,10 +137,35 @@ def run(self): asyncio.run(self._first_segment.process(None)) -def pipeline(builder: PotentiallyCompletePipeline[TSrc, TDest]) -> Pipeline: +def pipeline(builder: PotentiallyCompletePipeline[TSrc, TDest], *error_handlers: ErrorHandler) -> Pipeline: return builder.build() +def retry(max_retries: int, delay: datetime.timedelta) -> clientapi.ErrorResponse: + retries_left = max_retries + + async def do_retry( + segment: _PipeSegment[T, TNext], + original_data: T, + modified_data: T, + exception: Exception) -> clientapi.AfterError: + nonlocal retries_left + retries_left -= 1 + if retries_left > 0: + await asyncio.sleep(delay.seconds) + return clientapi.AfterError.Retry + + return do_retry + + +def skip_this_step( + segment: _PipeSegment[T, TNext], + original_data: T, + modified_data: T, + exception: Exception) -> AfterError: + return AfterError.Skip + + def is_valid_pipeline(self): if not isinstance(self.val, Pipeline): raise TypeError('val must be a pipeline.') diff --git a/src/examplecode/product.py b/src/examplecode/product.py index f63a914..0efdca4 100644 --- a/src/examplecode/product.py +++ 
b/src/examplecode/product.py @@ -1,8 +1,12 @@ from __future__ import annotations +from datetime import timedelta from typing import Any -from datapipeline.pipeline import needs, gives, source, transform, sink, pipeline, start_with +from datapipeline.clientapi import AfterError +from datapipeline.pipeline import needs, gives, source, transform, sink, pipeline, start_with, on_error, retry, \ + skip_this_step +from datapipeline.segmentimpl import _PipeSegment class RawCustomerData: @@ -124,18 +128,44 @@ async def put_projections_into_quickbooks(data: DestStructureTwo) -> None: pass +class CustomError(Exception): + pass + + +def log_assertions_and_continue_with_next_step(segment: _PipeSegment[RawCustomerData, RawCustomerData], original_data: RawCustomerData, modified_data: RawCustomerData, exception: Exception) -> AfterError: + return AfterError.Skip + + +def something_nifty(segment: _PipeSegment[RawCustomerData, RawCustomerData], original_data: RawCustomerData, modified_data: RawCustomerData, exception: Exception) -> AfterError: + # Do some kind of nifty error recovery and stuff. + return AfterError.Retry + + +def laugh_at_math(segment: _PipeSegment[RawCustomerData, RawCustomerData], original_data: RawCustomerData, modified_data: RawCustomerData, exception: Exception) -> AfterError: + return AfterError.Abort + + +def more_general_handler(segment: _PipeSegment[RawCustomerData, RawCustomerData], original_data: RawCustomerData, modified_data: RawCustomerData, exception: Exception) -> AfterError: + # Give helpful messages to people. Then we'll let the pipeline die. 
+ return AfterError.Abort + + def create_pipeline(): return pipeline( start_with(RawCustomerData) .then( source(load_customer_csv, parse_customers_from_csv), - source(load_customer_crm_api, parse_customers_from_json), + source(load_customer_crm_api, parse_customers_from_json, + on_error(ConnectionError, retry(3, timedelta(seconds=10))), + on_error(CustomError, skip_this_step)), transform(remove_invalid_emails), transform(remove_test_customers), source(load_customer_orders, parse_orders_from_json), transform(remove_empty_orders), transform(group_orders_into_customer_cohorts), - transform(compute_cohort_relative_date_per_order)) + transform(compute_cohort_relative_date_per_order), + on_error(AssertionError, log_assertions_and_continue_with_next_step), + on_error(CustomError, something_nifty)) .restructure_to(CustomerGraph, create_customer_object_graph) .then( transform(understand_something), @@ -143,5 +173,7 @@ def create_pipeline(): transform(keep_understanding), sink(extract_cohort_analysis, email_analysis_to_sales_team), sink(extract_revenue_projections, put_projections_into_quickbooks), - ) + ), + on_error(ArithmeticError, laugh_at_math), + on_error(CustomError, more_general_handler) )