Source code for beneath.consumer

# allows us to use Client as a type hint without an import cycle
# see: https://www.stefaanlippens.net/circular-imports-type-hints-python.html
# pylint: disable=wrong-import-position,ungrouped-imports
from __future__ import annotations
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from beneath.client import Client

import asyncio
from collections.abc import Mapping
import inspect
from typing import Awaitable, Callable, Iterable

from beneath.checkpointer import Checkpointer
from beneath.config import DEFAULT_READ_BATCH_SIZE
from beneath.cursor import Cursor
from beneath.instance import TableInstance
from beneath.utils import TableIdentifier

# Type of the per-record callback accepted by Consumer.replay and Consumer.subscribe:
# it is called with a single record (a Mapping) and returns an awaitable.
ConsumerCallback = Callable[[Mapping], Awaitable]


class Consumer:
    """
    Consumers are used to replay/subscribe to a table. If the consumer is initialized with a
    project and subscription name, it will checkpoint its progress to avoid reprocessing the
    same data every time the process starts.
    """

    instance: TableInstance
    """
    The table instance the consumer is subscribed to
    """

    cursor: Cursor
    """
    The cursor used to replay and subscribe the table. You can use it to get the current
    state of the underlying replay and changes cursors.
    """

    def __init__(
        self,
        client: Client,
        table_identifier: TableIdentifier,
        batch_size: int = DEFAULT_READ_BATCH_SIZE,
        version: int = None,
        checkpointer: Checkpointer = None,
        subscription_name: str = None,
    ):
        # Only stores configuration; ``instance`` and ``cursor`` are populated by ``_init``.
        self._client = client
        self._table_identifier = table_identifier
        self._version = version
        self._batch_size = batch_size
        self._checkpointer = checkpointer
        self._subscription_name = subscription_name

    async def _init(self):
        """
        Resolves the table instance to consume (the requested version if one was given,
        otherwise the table's primary instance) and initializes the cursor.

        Raises:
            ValueError: If no version was requested and the table has no primary instance.
        """
        table = await self._client.find_table(table_path=str(self._table_identifier))
        if self._version is not None:
            self.instance = await table.find_instance(version=self._version)
        else:
            self.instance = table.primary_instance
            if not self.instance:
                raise ValueError(
                    f"Cannot consume table {self._table_identifier}"
                    " because it doesn't have a primary instance"
                )
        await self._init_cursor()

    async def reset(self):
        """
        Resets the consumer's replay and changes cursor.
        """
        await self._init_cursor(reset=True)

    async def replay(self, cb: ConsumerCallback, max_concurrency: int = 1):
        """
        Calls the callback with every historical record in the table in the order
        they were written. Returns when all historical records have been processed.

        Args:
            cb (async def fn(record)):
                Async function for processing a record.
            max_concurrency (int):
                The maximum number of callbacks to call concurrently. Defaults to 1.
        """
        await self.subscribe(cb=cb, max_concurrency=max_concurrency, replay_only=True)

    async def subscribe(
        self,
        cb: ConsumerCallback,
        max_concurrency: int = 1,
        replay_only: bool = False,
        changes_only: bool = False,
        stop_when_idle: bool = False,
    ):
        """
        Replays the table and subscribes for new changes (runs forever unless
        stop_when_idle=True or the instance is finalized). Calls the callback for
        every record.

        Args:
            cb (async def fn(record)):
                Async function for processing a record.
            max_concurrency (int):
                The maximum number of callbacks to call concurrently. Defaults to 1.
            replay_only (bool):
                If true, will not read changes, but only replay historical records.
                Defaults to False.
            changes_only (bool):
                If true, will not replay historical records, but only subscribe to new
                changes. Defaults to False.
            stop_when_idle (bool):
                If true, will return when "caught up" and no new changes are available.
                Defaults to False.
        """
        async for batch in self.iterate(
            batches=True,
            replay_only=replay_only,
            changes_only=changes_only,
            stop_when_idle=stop_when_idle,
        ):
            # The batch is fully processed before the underlying generator is resumed,
            # which is when it writes its checkpoint.
            await self._callback_batch(batch, cb, max_concurrency)

    async def iterate(
        self,
        batches: bool = False,
        replay_only: bool = False,
        changes_only: bool = False,
        stop_when_idle: bool = False,
    ):
        """
        Replays the table and subscribes for new changes (runs forever unless
        stop_when_idle=True or the instance is finalized). Yields every record
        (or batch if batches=True).

        Args:
            batches (bool):
                If true, yields batches of records as they're loaded (instead of
                individual records)
            replay_only (bool):
                If true, will not read changes, but only replay historical records.
                Defaults to False.
            changes_only (bool):
                If true, will not replay historical records, but only subscribe to new
                changes. Defaults to False.
            stop_when_idle (bool):
                If true, will return when "caught up" and no new changes are available.
                Defaults to False.

        Raises:
            ValueError: If both replay_only and changes_only are true.
        """
        if replay_only and changes_only:
            raise ValueError("cannot set replay_only=True and changes_only=True for iterate")

        # Phase 1: replay historical records (skipped when changes_only=True)
        if not changes_only:
            if self.cursor.replay_cursor:
                self._client.logger.info(
                    "Replaying table '%s' (version %i)",
                    self._table_identifier,
                    self.instance.version,
                )
            async for batch in self._run_replay():
                if batches:
                    yield batch
                else:
                    for record in batch:
                        yield record

        if replay_only:
            return

        # Phase 2: changes. Poll until caught up when asked to stop on idle or when the
        # instance is finalized; otherwise keep a live subscription open.
        if stop_when_idle or self.instance.is_final:
            self._client.logger.info(
                "Consuming changes for table '%s' (version %i)",
                self._table_identifier,
                self.instance.version,
            )
            it = self._run_delta()
        else:
            self._client.logger.info(
                "Subscribed to changes for table '%s' (version %i)",
                self._table_identifier,
                self.instance.version,
            )
            it = self._run_subscribe()

        async for batch in it:
            if batches:
                yield batch
            else:
                for record in batch:
                    yield record

        if self.instance.is_final:
            self._client.logger.info(
                "Stopped consuming changes for table '%s' (version %i) because it has been"
                " finalized",
                self._table_identifier,
                self.instance.version,
            )

    # CURSORS / CHECKPOINTS

    @property
    def _subscription_cursor_key(self):
        """
        Checkpointer key under which the cursor state for this subscription and
        instance is stored.
        """
        return self._subscription_name + ":" + str(self.instance.instance_id) + ":cursor"

    async def _init_cursor(self, reset=False):
        """
        Sets ``self.cursor``. Unless ``reset`` is true, a cursor state saved in the
        checkpointer is restored if one exists; otherwise a fresh cursor is obtained
        from ``instance.query_log()``. On reset, the fresh cursor is checkpointed
        immediately so it replaces any previously saved state.
        """
        if not reset and self._checkpointer:
            state = await self._checkpointer.get(self._subscription_cursor_key)
            if state:
                self.cursor = self.instance.table.restore_cursor(
                    replay_cursor=state.get("replay"),
                    changes_cursor=state.get("changes"),
                )
                return

        self.cursor = await self.instance.query_log()
        if reset:
            await self._checkpoint()

    async def _checkpoint(self):
        """
        Saves the current replay/changes cursors to the checkpointer (no-op when the
        consumer has no checkpointer). Only cursors that are set are included.
        """
        if not self._checkpointer:
            return

        state = {}
        if self.cursor.replay_cursor:
            state["replay"] = self.cursor.replay_cursor
        if self.cursor.changes_cursor:
            state["changes"] = self.cursor.changes_cursor

        await self._checkpointer.set(self._subscription_cursor_key, state)

    # RUNNING / CALLBACKS

    async def _run_replay(self):
        """
        Yields batches from the replay cursor until it returns an empty result.
        Checkpoints after each batch once the consumer of the generator resumes it.
        """
        if not self.cursor.replay_cursor:
            return

        while True:
            batch = await self.cursor.read_next(limit=self._batch_size)
            if not batch:
                return
            yield batch
            await self._checkpoint()

    async def _run_delta(self):
        """
        Yields batches from the changes cursor until it is caught up, i.e. until a read
        returns nothing or fewer records than the batch size. Checkpoints after each
        batch once the consumer of the generator resumes it.
        """
        if not self.cursor.changes_cursor:
            return

        while True:
            batch = await self.cursor.read_next_changes(limit=self._batch_size)
            if not batch:
                return
            # Materialize so the batch can be measured with len() below
            batch = list(batch)
            if not batch:
                return
            yield batch
            await self._checkpoint()
            # A short batch means there is nothing more to read right now
            if len(batch) < self._batch_size:
                return

    async def _run_subscribe(self):
        """
        Yields batches from the cursor's change subscription, checkpointing after each
        batch once the consumer of the generator resumes it.
        """
        if not self.cursor.changes_cursor:
            return

        async for batch in self.cursor.subscribe_changes(batch_size=self._batch_size):
            yield batch
            await self._checkpoint()

    @staticmethod
    async def _invoke_callback(cb: ConsumerCallback, record: Mapping):
        """
        Calls ``cb(record)`` and awaits the result if it is awaitable.

        Checking the returned value (rather than ``inspect.iscoroutinefunction(cb)``)
        also covers callables that return an awaitable without being coroutine
        functions themselves, e.g. objects with an ``async def __call__``.
        """
        result = cb(record)
        if inspect.isawaitable(result):
            await result

    async def _callback_batch(
        self,
        batch: Iterable[Mapping],
        cb: ConsumerCallback,
        max_concurrency: int,
    ):
        """
        Calls ``cb`` for every record in ``batch`` and returns when all calls are done.

        Args:
            batch (Iterable[Mapping]):
                The records to process.
            cb (async def fn(record)):
                Callback for processing a record. Plain (non-async) callables are
                also accepted and are called synchronously.
            max_concurrency (int):
                The maximum number of callbacks in flight at once. Values of 1 or
                less (or None) process the records strictly one at a time, in order.
                With higher values, completion order is not guaranteed.

        Raises:
            Any exception raised by ``cb``. When running concurrently, the remaining
            callbacks of the batch are cancelled before the exception is re-raised.
        """
        if not max_concurrency or max_concurrency <= 1:
            for record in batch:
                await self._invoke_callback(cb, record)
            return

        semaphore = asyncio.Semaphore(max_concurrency)

        async def _bounded(record):
            async with semaphore:
                await self._invoke_callback(cb, record)

        tasks = [asyncio.ensure_future(_bounded(record)) for record in batch]
        if not tasks:
            return

        try:
            await asyncio.gather(*tasks)
        except BaseException:
            # Don't leave callbacks running in the background after a failure or
            # cancellation: cancel what's left and wait for it to wind down.
            for task in tasks:
                task.cancel()
            await asyncio.gather(*tasks, return_exceptions=True)
            raise