Source code for beneath.pipeline.base_pipeline

# allows us to use Client as a type hint without an import cycle
# see: https://www.stefaanlippens.net/circular-imports-type-hints-python.html
# pylint: disable=wrong-import-position,ungrouped-imports
from __future__ import annotations

import asyncio
from enum import Enum
import os
import signal
from typing import Dict, Set
import uuid

from beneath import config
from beneath.connection import GraphQLError
from beneath.consumer import Consumer
from beneath.client import Client
from beneath.checkpointer import Checkpointer
from beneath.instance import TableInstance
from beneath.pipeline.parse_args import parse_pipeline_args
from beneath.table import Table
from beneath.utils import ServiceIdentifier, TableIdentifier


class Action(Enum):
    """Actions that a pipeline can execute"""

    test = "test"
    """
    Does a dry run of the pipeline, where input tables are read, but no output tables are
    created. Records output by the pipeline will be logged, but not written.
    """

    stage = "stage"
    """
    Creates all the resources the pipeline needs. These include output tables, a checkpoints
    meta table, and a service with the correct read and write permissions. The service can be
    used to create a secret for deploying the pipeline to production.
    """

    run = "run"
    """
    Runs the pipeline, reading inputs, applying transformations, and writing to output tables.
    You must execute the "stage" action before running.
    """

    teardown = "teardown"
    """
    The reverse action of "stage". Deletes all resources created for the pipeline.
    """
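
# A hedged usage sketch (not part of this module): the action is typically
# supplied per run, either as a constructor argument or via the environment
# variables read in BasePipeline.__init__ below. "my_pipeline.py" is a
# hypothetical script that constructs a pipeline with parse_args=True:
#
#   BENEATH_PIPELINE_ACTION=stage \
#   BENEATH_PIPELINE_VERSION=1 \
#   BENEATH_PIPELINE_SERVICE_PATH=my-org/my-project/my-service \
#   python my_pipeline.py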

class Strategy(Enum):
    """Strategies for running a pipeline (only apply when ``action="run"``)"""

    continuous = "continuous"
    """
    Attempts to keep the pipeline running forever. Will replay inputs and stay subscribed for
    changes, and continuously flush outputs. May terminate if a generator returns (see
    ``Pipeline.generate`` for more).
    """

    delta = "delta"
    """
    Stops the pipeline when there are no more changes to consume. On the first run, it replays
    inputs, and on subsequent runs, it will consume and process all changes since the last run,
    and then stop.
    """

    batch = "batch"
    """
    Replays all inputs on every run, and never consumes changes. It will finalize table
    instances once it is done, so you must increase the ``version`` number on every run.
    """
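
# Sketch (assumed values): the strategy only matters for action="run". Per the
# Strategy.batch docstring, a batch pipeline must bump its version on each run
# so a fresh table instance is created and finalized, e.g.:
#
#   BENEATH_PIPELINE_ACTION=run \
#   BENEATH_PIPELINE_STRATEGY=batch \
#   BENEATH_PIPELINE_VERSION=2 \
#   python my_pipeline.py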

class BasePipeline:
    """
    The ``BasePipeline`` implements core functionality for managing tables and checkpointing,
    while subclasses implement data transformation logic. See the subclass ``Pipeline`` for
    more.
    """

    def __init__(
        self,
        action: Action = None,
        strategy: Strategy = None,
        version: int = None,
        service_path: str = None,
        service_read_quota: int = None,
        service_write_quota: int = None,
        service_scan_quota: int = None,
        parse_args: bool = False,
        client: Client = None,
        write_delay_ms: int = config.DEFAULT_WRITE_DELAY_MS,
        disable_checkpoints: bool = False,
    ):
        if parse_args:
            parse_pipeline_args()
        self.action = self._arg_or_env("action", action, "BENEATH_PIPELINE_ACTION", fn=Action)
        self.strategy = self._arg_or_env(
            "strategy",
            strategy,
            "BENEATH_PIPELINE_STRATEGY",
            fn=Strategy,
        )
        self.version = self._arg_or_env("version", version, "BENEATH_PIPELINE_VERSION", fn=int)
        self.service_identifier = self._arg_or_env(
            "service_path",
            service_path,
            "BENEATH_PIPELINE_SERVICE_PATH",
            fn=ServiceIdentifier.from_path,
        )
        self.service_read_quota = self._arg_or_env(
            "service_read_quota",
            service_read_quota,
            "BENEATH_PIPELINE_SERVICE_READ_QUOTA",
            fn=int,
            default=lambda: None,
        )
        self.service_write_quota = self._arg_or_env(
            "service_write_quota",
            service_write_quota,
            "BENEATH_PIPELINE_SERVICE_WRITE_QUOTA",
            fn=int,
            default=lambda: None,
        )
        self.service_scan_quota = self._arg_or_env(
            "service_scan_quota",
            service_scan_quota,
            "BENEATH_PIPELINE_SERVICE_SCAN_QUOTA",
            fn=int,
            default=lambda: None,
        )
        dry = self.action == Action.test
        self.client = client if client else Client(dry=dry, write_delay_ms=write_delay_ms)
        self.logger = self.client.logger
        self.checkpoints: Checkpointer = None
        self.service_id: uuid.UUID = None
        self.description: str = None
        self.disable_checkpoints = disable_checkpoints
        self._inputs: Dict[TableIdentifier, Consumer] = {}
        self._input_identifiers: Set[TableIdentifier] = set()
        self._outputs: Dict[TableIdentifier, TableInstance] = {}
        self._output_identifiers: Set[TableIdentifier] = set()
        self._output_kwargs: Dict[TableIdentifier, dict] = {}
        self._execute_task: asyncio.Task = None
        self._init()

    @staticmethod
    def _arg_or_env(name, arg, env, fn=None, default=None):
        if arg is not None:
            if fn:
                return fn(arg)
            return arg
        arg = os.getenv(env)
        if arg is not None:
            if fn:
                return fn(arg)
            return arg
        if default:
            if callable(default):
                return default()
            return default
        raise ValueError(f"Must provide {name} arg or set {env} environment variable")

    # OVERRIDES

    def _init(self):
        # pylint: disable=no-self-use
        raise Exception("_init should be implemented in a subclass")

    async def _before_run(self):
        raise Exception("_before_run should be implemented in a subclass")

    async def _after_run(self):
        raise Exception("_after_run should be implemented in a subclass")

    async def _run(self):
        raise Exception("_run should be implemented in a subclass")

    # RUNNING

    def main(self):
        """
        Executes the pipeline in an asyncio event loop and gracefully handles SIGINT and SIGTERM
        """
        loop = asyncio.get_event_loop()

        async def _main():
            try:
                self._execute_task = asyncio.create_task(self.execute())
                await self._execute_task
            except asyncio.CancelledError:
                self.logger.info("Pipeline gracefully cancelled")

        async def _kill(sig):
            self.logger.info("Received %s, attempting graceful shutdown...", sig.name)
            if self._execute_task:
                self._execute_task.cancel()

        signals = (signal.SIGTERM, signal.SIGINT)
        for sig in signals:
            loop.add_signal_handler(sig, lambda sig=sig: asyncio.create_task(_kill(sig)))

        loop.run_until_complete(_main())
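
    # Minimal entrypoint sketch: subclasses (e.g. a hypothetical MyPipeline) are
    # normally launched via main(), which wires up the SIGTERM/SIGINT handlers
    # above and drives execute() on the event loop:
    #
    #   if __name__ == "__main__":
    #       MyPipeline(parse_args=True).main()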

    async def execute(self):
        """
        Executes the pipeline in accordance with the action and strategy set on init (called by
        ``main``).
        """
        await self.client.start()
        try:
            if self.action == Action.test:
                await self._execute_test()
            elif self.action == Action.stage:
                await self._execute_stage()
            elif self.action == Action.run:
                await self._execute_run()
            elif self.action == Action.teardown:
                await self._execute_teardown()
        except asyncio.CancelledError:
            await self.client.stop()
            self.logger.info("Pipeline cancelled")
            raise

    async def _execute_test(self):
        await self._stage_checkpointer(create=True)
        await asyncio.gather(
            *[self._stage_input_table(identifier) for identifier in self._input_identifiers],
            *[
                self._stage_output_table(identifier=identifier, create=True)
                for identifier in self._output_identifiers
            ],
        )
        self.logger.info("Running in TEST mode: Records will be printed, but not saved")
        await self._before_run()
        await self._run()
        await self._after_run()
        await self.client.stop()
        self.logger.info("Finished running pipeline")

    async def _execute_stage(self):
        async def _stage_input(identifier: TableIdentifier):
            await self._stage_input_table(identifier)
            consumer = self._inputs[identifier]
            await self._update_service_permissions(consumer.instance.table, read=True)

        async def _stage_output(identifier: TableIdentifier):
            await self._stage_output_table(identifier=identifier, create=True)
            instance = self._outputs[identifier]
            await self._update_service_permissions(instance.table, write=True)

        await self._stage_service(create=True)
        await self._stage_checkpointer(create=True)
        if not self.disable_checkpoints:
            await self._update_service_permissions(
                self.checkpoints.instance.table, read=True, write=True
            )
        await asyncio.gather(*[_stage_input(identifier) for identifier in self._input_identifiers])
        await asyncio.gather(*[_stage_output(identifier) for identifier in self._output_identifiers])
        await self.client.stop()
        self.logger.info("Finished staging pipeline for service '%s'", self.service_identifier)

    async def _execute_run(self):
        await self._stage_checkpointer(create=False)
        await asyncio.gather(
            *[self._stage_input_table(identifier) for identifier in self._input_identifiers],
            *[
                self._stage_output_table(identifier=identifier, create=False)
                for identifier in self._output_identifiers
            ],
        )
        self.logger.info("Running in PRODUCTION mode: Records will be saved to Beneath")
        await self._before_run()
        await self._run()
        await self.client.stop()
        await self._after_run()
        self.logger.info("Finished running pipeline")

    async def _execute_teardown(self):
        await asyncio.gather(
            *[self._teardown_output_table(identifier) for identifier in self._output_identifiers]
        )
        await self._teardown_checkpointer()
        await self._teardown_service()
        await self.client.stop()
        self.logger.info("Finished tearing down pipeline for service '%s'", self.service_identifier)
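
    # Note on the dispatch above: "test" stages everything with create=True but
    # against a dry Client (see __init__), so records are logged rather than
    # written; "run" opens resources with create=False and therefore assumes
    # the "stage" action has already been executed.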

    # STAGING

    async def _stage_service(self, create: bool):
        if create:
            description = (
                self.description
                if self.description
                else f"Service for running '{self.service_identifier.service}' pipeline"
            )
            admin_data = await self.client.admin.services.create(
                organization_name=self.service_identifier.organization,
                project_name=self.service_identifier.project,
                service_name=self.service_identifier.service,
                description=description,
                read_quota_bytes=self.service_read_quota,
                write_quota_bytes=self.service_write_quota,
                scan_quota_bytes=self.service_scan_quota,
                update_if_exists=True,
            )
            self.logger.info("Staged service '%s'", self.service_identifier)
        else:
            admin_data = await self.client.admin.services.find_by_organization_project_and_name(
                organization_name=self.service_identifier.organization,
                project_name=self.service_identifier.project,
                service_name=self.service_identifier.service,
            )
            self.logger.info("Found service '%s'", self.service_identifier)
        self.service_id = uuid.UUID(hex=admin_data["serviceID"])

    async def _teardown_service(self):
        try:
            await self._stage_service(create=False)
            await self.client.admin.services.delete(self.service_id)
            self.logger.info("Deleted service '%s'", self.service_identifier)
        except GraphQLError as e:
            if " not found" in str(e):
                return
            raise

    async def _stage_checkpointer(self, create: bool):
        if self.disable_checkpoints:
            return
        metatable_name = self.service_identifier.service + "-checkpoints"
        self.checkpoints = await self.client.checkpointer(
            project_path=f"{self.service_identifier.organization}/{self.service_identifier.project}",
            metatable_name=metatable_name,
            metatable_create=create,
            metatable_description=(
                f"Stores checkpointed state for the '{self.service_identifier.service}' pipeline"
            ),
            key_prefix=f"{self.version}:",
        )
        if create:
            self.logger.info(
                "Staged checkpointer '%s'", self.checkpoints.instance.table._identifier
            )
        else:
            self.logger.info("Found checkpointer '%s'", self.checkpoints.instance.table._identifier)

    async def _teardown_checkpointer(self):
        if self.disable_checkpoints:
            return
        try:
            await self._stage_checkpointer(create=False)
            await self.checkpoints.instance.table.delete()
            self.logger.info(
                "Deleted checkpointer '%s'", self.checkpoints.instance.table._identifier
            )
        except GraphQLError as e:
            if " not found" in str(e):
                return
            raise

    def _add_input_table(self, table_path: str):
        identifier = TableIdentifier.from_path(table_path)
        self._input_identifiers.add(identifier)
        return identifier

    async def _stage_input_table(self, identifier: TableIdentifier):
        if identifier in self._output_identifiers:
            raise ValueError(f"Cannot use table '{identifier}' as both input and output")
        subscription_path = None
        if not self.disable_checkpoints:
            subscription_path = f"{self.service_identifier.organization}/"
            subscription_path += f"{self.service_identifier.project}/"
            subscription_path += self.service_identifier.service
        consumer = await self.client.consumer(
            table_path=str(identifier),
            subscription_path=subscription_path,
            checkpointer=self.checkpoints,
        )
        self._inputs[identifier] = consumer
        self.logger.info(
            "Using input table '%s' (version: %d)", identifier, consumer.instance.version
        )

    def _add_output_table(
        self,
        table_path: str,
        **kwargs,
    ) -> TableIdentifier:
        # use service organization and project if "table_path" is just a name
        if "/" not in table_path:
            table_path = (
                f"{self.service_identifier.organization}/"
                + f"{self.service_identifier.project}/"
                + table_path
            )
        identifier = TableIdentifier.from_path(table_path)
        self._output_identifiers.add(identifier)
        self._output_kwargs[identifier] = kwargs
        return identifier

    async def _stage_output_table(self, identifier: TableIdentifier, create: bool):
        kwargs = self._output_kwargs[identifier]
        create = create and "schema" in kwargs
        if create:
            table = await self.client.create_table(
                table_path=str(identifier),
                update_if_exists=True,
                **kwargs,
            )
        else:
            table = await self.client.find_table(table_path=str(identifier))
        instance = await table.create_instance(version=self.version, update_if_exists=True)
        self._outputs[identifier] = instance
        if create:
            self.logger.info(
                "Staged output table '%s' (using version %s)", identifier, self.version
            )
        else:
            self.logger.info("Found output table '%s' (using version %s)", identifier, self.version)
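
    # Path resolution sketch: per _add_output_table above, a bare table name is
    # expanded with the service's organization and project. For example, with a
    # hypothetical service path "my-org/my-project/my-service":
    #
    #   self._add_output_table("predictions", schema=...)
    #   # -> resolves to "my-org/my-project/predictions"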

    async def _teardown_output_table(self, identifier: TableIdentifier):
        try:
            kwargs = self._output_kwargs[identifier]
            if "schema" not in kwargs:
                return
            await self._stage_output_table(identifier=identifier, create=False)
            instance = self._outputs[identifier]
            await instance.table.delete()
            self.logger.info("Deleted output '%s'", identifier)
        except GraphQLError as e:
            if " not found" in str(e):
                return
            raise

    async def _update_service_permissions(
        self,
        table: Table,
        read: bool = None,
        write: bool = None,
    ):
        await self.client.admin.services.update_permissions_for_table(
            service_id=str(self.service_id),
            table_id=str(table.table_id),
            read=read,
            write=write,
        )
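
# A minimal subclass sketch (hypothetical; the real subclass is
# beneath.pipeline.Pipeline): BasePipeline leaves the four OVERRIDES hooks to
# subclasses, which register inputs/outputs in _init and move records in _run:
#
#   class MyPipeline(BasePipeline):
#       def _init(self):
#           self.input = self._add_input_table("my-org/my-project/raw-events")
#           self.output = self._add_output_table("processed-events", schema="...")
#
#       async def _before_run(self):
#           pass
#
#       async def _run(self):
#           pass  # read via self._inputs, write via self._outputs
#
#       async def _after_run(self):
#           pass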