Source code for beneath.pipeline.pipeline

# allows us to use Client as a type hint without an import cycle
# see: https://www.stefaanlippens.net/circular-imports-type-hints-python.html
# pylint: disable=wrong-import-position,ungrouped-imports
from __future__ import annotations
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    pass

import asyncio
from collections import defaultdict
from collections.abc import Mapping
from datetime import timedelta
from typing import AsyncIterator, Awaitable, Callable, Dict, Iterable, List, Union

from beneath import config
from beneath.pipeline import BasePipeline, Strategy

AsyncGenerateFn = Callable[[BasePipeline], AsyncIterator[Mapping]]
AsyncApplyFn = Callable[[Mapping], Awaitable[Union[Mapping, Iterable[Mapping]]]]
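
# Sentinel that a generate function can yield when it has no new records to emit right
# now; see Generate.process below: batch and delta pipelines stop when they see it,
# while continuous pipelines keep running.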

PIPELINE_IDLE = "<PIPELINE_IDLE>"


class Transform:
    def __init__(self, pipeline: Pipeline):
        self.pipeline = pipeline

    async def process(self, incoming_records: Iterable[Mapping]):
        raise Exception("Must implement process in subclass")


class Generate(Transform):
    def __init__(self, pipeline: Pipeline, fn: AsyncGenerateFn):
        super().__init__(pipeline)
        self.fn = fn

    async def process(self, incoming_records: Iterable[Mapping]):
        assert incoming_records is None
        async for records in self.fn(self.pipeline):
            if records == PIPELINE_IDLE:
                if self.pipeline.strategy != Strategy.continuous:
                    return
                continue
            if not isinstance(records, list):
                records = [records]
            yield records


class ReadTable(Transform):
    def __init__(
        self,
        pipeline: Pipeline,
        table_path: str,
        batch_size: int = config.DEFAULT_READ_BATCH_SIZE,
    ):
        super().__init__(pipeline)
        self.table_path = table_path
        self.batch_size = batch_size
        self.identifier = self.pipeline._add_input_table(table_path)

    async def process(self, incoming_records: Iterable[Mapping]):
        assert incoming_records is None
        consumer = self.pipeline._inputs[self.identifier]
        async for batch in consumer.iterate(
            batches=True,
            replay_only=(self.pipeline.strategy == Strategy.batch),
            stop_when_idle=(self.pipeline.strategy == Strategy.delta),
        ):
            yield batch
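
    # Note (descriptive comment, not in the original source): the flags above map the
    # pipeline strategy onto the consumer: batch pipelines only replay already-written
    # records, delta pipelines stop once the table is idle, and continuous pipelines set
    # neither flag and keep consuming.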


class Apply(Transform):
    def __init__(self, pipeline: Pipeline, fn: AsyncApplyFn, max_concurrency: int = None):
        super().__init__(pipeline)
        self.fn = fn
        self.max_concurrency = (
            max_concurrency if max_concurrency else 1
        )  # TODO: use max_concurrency :(

    async def process(self, incoming_records: Iterable[Mapping]):
        for record in incoming_records:
            fut = self.fn(record)
            if hasattr(fut, "__aiter__"):
                async for record in fut:
                    yield [record]
            else:
                batch = await fut
                if batch is None:
                    continue
                if isinstance(batch, list):
                    yield batch
                else:
                    yield [batch]


class WriteTable(Transform):
    def __init__(
        self,
        pipeline: Pipeline,
        table_path: str,
        schema: str = None,
        schema_kind: str = "GraphQL",
        description: str = None,
        retention: timedelta = None,
    ):
        super().__init__(pipeline)
        self.identifier = self.pipeline._add_output_table(
            table_path,
            schema=schema,
            schema_kind=schema_kind,
            description=description,
            retention=retention,
        )

    async def process(self, incoming_records: Iterable[Mapping]):
        instance = self.pipeline._outputs[self.identifier]
        await self.pipeline.client.write(
            instance=instance,
            records=incoming_records,
        )
        # The bare return followed by an unreachable yield makes this coroutine an
        # async generator that yields nothing, so it matches the Transform.process
        # interface used by the pipeline runner.
        return
        # pylint: disable=unreachable
        yield


class Pipeline(BasePipeline):
    """
    Pipelines are a construct built on top of the Beneath primitives to manage the reading
    of input tables, creation of output tables, data generation and derivation logic, and
    more.

    This simple implementation supports four combinatorial operations: ``generate``,
    ``read_table``, ``apply``, and ``write_table``. It's suitable for generating tables,
    consuming tables, and one-to-N derivation of one table to another. It's not currently
    suitable for advanced aggregation or multi-machine parallel processing.

    It supports (light) stateful transformations via a key-value based `checkpointer`,
    which is useful for tracking generator and consumer progress in between invocations
    (or crashes).

    Args:
        action (Action):
            The action to run when calling ``main`` or ``execute``
        strategy (Strategy):
            The processing strategy to apply when action="test" or action="run"
        version (int):
            The version number for output tables. Incrementing the version number will
            cause the pipeline to create new output table instances and replay input
            tables.
        service_path (str):
            Path for a service to create for the pipeline when action="stage".
            The service will be assigned correct permissions for reading input tables and
            writing to output tables and checkpoints. The service can be used to create a
            secret for deploying the pipeline to production.
        service_read_quota (int):
            Read quota for the service staged for the pipeline (see ``service_path``)
        service_write_quota (int):
            Write quota for the service staged for the pipeline (see ``service_path``)
        service_scan_quota (int):
            Scan quota for the service staged for the pipeline (see ``service_path``)
        client (Client):
            Client to use for the pipeline. If not set, initializes a new client (see its
            init for details on which secret gets used).
        write_delay_ms (int):
            Passed to ``Client`` initializer if ``client`` arg isn't passed
        disable_checkpoints (bool):
            If true, will not create a checkpointer, and consumers will not save state.
            Defaults to false.

    Example::

        # pipeline.py

        # This pipeline generates a table `ticks` with a record for every minute
        # since 1st Jan 2021.

        # To test locally:
        #     python ./pipeline.py test
        # To prepare for running:
        #     python ./pipeline.py stage USERNAME/PROJECT/ticker
        # To run (after stage):
        #     python ./pipeline.py run USERNAME/PROJECT/ticker
        # To teardown:
        #     python ./pipeline.py teardown USERNAME/PROJECT/ticker

        import asyncio
        import beneath
        from datetime import datetime, timedelta, timezone

        start = datetime(year=2021, month=1, day=1, tzinfo=timezone.utc)

        async def ticker(p: beneath.Pipeline):
            last_tick = await p.checkpoints.get("last", default=start)
            while True:
                now = datetime.now(tz=last_tick.tzinfo)
                next_tick = last_tick + timedelta(minutes=1)
                if next_tick >= now:
                    yield beneath.PIPELINE_IDLE
                    wait = next_tick - now
                    await asyncio.sleep(wait.total_seconds())
                yield {"time": next_tick}
                await p.checkpoints.set("last", next_tick)
                last_tick = next_tick

        if __name__ == "__main__":
            p = beneath.Pipeline(parse_args=True)
            p.description = "Pipeline that emits a tick for every minute since 1st Jan 2021"
            ticks = p.generate(ticker)
            p.write_table(
                ticks,
                table_path="ticks",
                schema='''
                    type Tick @schema {
                        time: Timestamp! @key
                    }
                ''',
            )
            p.main()
    """

    _initial: List[Transform]
    _dag: Dict[Transform, List[Transform]]

    # INIT OVERRIDE

    def _init(self):
        self._initial = []
        self._dag = defaultdict(list)

    # DAG
    def generate(self, fn: AsyncGenerateFn) -> Transform:
        """
        Pipeline step for generating records.

        Args:
            fn (async def fn(pipeline)):
                An async iterator that generates records. The pipeline is passed as an
                input arg.
        """
        transform = Generate(self, fn)
        self._initial.append(transform)
        return transform
    def read_table(self, table_path: str) -> Transform:
        """
        Pipeline step for consuming the primary instance of a table.

        Args:
            table_path (str):
                The table to consume
        """
        transform = ReadTable(self, table_path)
        self._initial.append(transform)
        return transform
    def apply(
        self,
        prev_transform: Transform,
        fn: AsyncApplyFn,
        max_concurrency: int = None,
    ) -> Transform:
        """
        Pipeline step that transforms records emitted by a previous step.

        Args:
            prev_transform (Transform):
                The pipeline step to apply on. Can be a `generate`, `read_table`, or
                other `apply` step.
            fn (async def fn(record) -> (None | record | [record])):
                Function applied to each incoming record. Can return ``None``, one record,
                or a list of records, which will propagate to the next pipeline step.
        """
        transform = Apply(self, fn, max_concurrency)
        self._dag[prev_transform].append(transform)
        return transform
    def write_table(
        self,
        prev_transform: Transform,
        table_path: str,
        schema: str = None,
        description: str = None,
        retention: timedelta = None,
        schema_kind: str = "GraphQL",
    ):
        """
        Pipeline step that writes incoming records from the previous step to a table.

        Args:
            prev_transform (Transform):
                The pipeline step to apply on. Can be a `generate`, `read_table`, or
                other `apply` step.
            table_path (str):
                The table to output to. If ``schema`` is provided, the table will be
                created when running the ``stage`` action. If the path doesn't include a
                username and project name, it will attempt to find or create the table in
                the pipeline's service's project (see ``service_path`` in the constructor).
            schema (str):
                A GraphQL schema for creating the output table.
            description (str):
                A description for the table (only applicable if schema is set).
            retention (timedelta):
                The amount of time to retain written data in the table. By default,
                records are saved forever.
            schema_kind (str):
                The parser to use for ``schema``. Currently must be "GraphQL" (default).
        """
        transform = WriteTable(
            self,
            table_path,
            schema,
            description=description,
            retention=retention,
            schema_kind=schema_kind,
        )
        self._dag[prev_transform].append(transform)
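
    # Illustrative sketch (not part of the original module): a one-to-N derivation
    # pipeline chains read_table -> apply -> write_table. The table paths, schema,
    # and field names below are hypothetical.
    #
    #     async def split(record):
    #         # return None to drop a record, a dict to emit one, or a list to emit many
    #         return [{"part": 1, **record}, {"part": 2, **record}]
    #
    #     p = Pipeline(parse_args=True)
    #     source = p.read_table("examples/demo/source-table")
    #     parts = p.apply(source, split)
    #     p.write_table(
    #         parts,
    #         table_path="parts",
    #         schema="type Part @schema { id: String! @key part: Int! }",
    #     )
    #     p.main()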
    # RUNNING

    async def _run(self):
        await asyncio.gather(
            *[self._run_transform(transform, None) for transform in self._initial]
        )

    async def _run_transform(self, transform: Transform, incoming_records: List):
        async for outgoing_records in transform.process(incoming_records):
            tasks = (
                self._run_transform(next_transform, outgoing_records)
                for next_transform in self._dag[transform]
            )
            await asyncio.gather(*tasks)

    async def _before_run(self):
        # make outputs primary if not batch
        if self.strategy == Strategy.batch:
            return
        for identifier, instance in self._outputs.items():
            if not instance.is_primary:
                await instance.update(make_primary=True)
                self.logger.info(
                    "Made version %s of output table '%s' primary",
                    instance.version,
                    identifier,
                )
            else:
                self.logger.info(
                    "Using existing primary version %s for output table '%s'",
                    instance.version,
                    identifier,
                )

    async def _after_run(self):
        # make outputs primary if batch
        if self.strategy != Strategy.batch:
            return
        for identifier, instance in self._outputs.items():
            if not instance.is_primary or not instance.is_final:
                await instance.update(make_primary=True, make_final=True)
                self.logger.info(
                    "Made version %s for output table '%s' final and primary",
                    instance.version,
                    identifier,
                )