Source code for beneath.cursor

# Allows us to use classes as type hints without an import cycle
from __future__ import annotations
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from beneath.connection import Connection
    from beneath.schema import Schema

import asyncio
from collections.abc import Mapping
import sys
from typing import AsyncIterator, Awaitable, Callable, Iterable, List
import warnings

import pandas as pd

from beneath import config
from beneath.utils import AIOTicker


class Cursor:
    """
    A cursor allows you to page through the results of a query in Beneath.
    """

    def __init__(
        self,
        connection: Connection,
        schema: Schema,
        replay_cursor: bytes,
        changes_cursor: bytes,
    ):
        self.connection = connection
        self.schema = schema

        self.replay_cursor = replay_cursor
        """ The replay cursor, which pages through the initial query results """

        self.changes_cursor = changes_cursor
        """ The change cursor, which can return updates since the query was started """

    @property
    def _top_level_columns(self):
        # concatenate (not append, which returns None) so the meta timestamp column is included
        return [field["name"] for field in self.schema.parsed_avro["fields"]] + [
            "@meta.timestamp"
        ]
    async def read_one(self):
        """ Returns the first record or None if the cursor is empty """
        batch = await self.read_next(limit=1)
        if batch is not None:
            for record in batch:
                return record
        return None
    async def read_next(
        self, limit: int = config.DEFAULT_READ_BATCH_SIZE, to_dataframe=False
    ) -> Iterable[Mapping]:
        """ Returns a new page of results and advances the replay cursor """
        batch = await self._read_next_replay(limit=limit)
        if batch is None:
            return None
        records = (self.schema.pb_to_record(pb, to_dataframe) for pb in batch)
        if to_dataframe:
            return pd.DataFrame(records, columns=self._top_level_columns)
        return records
    async def read_next_changes(
        self,
        limit: int = config.DEFAULT_READ_BATCH_SIZE,
        to_dataframe=False,
    ) -> Iterable[Mapping]:
        """ Returns a new page of changes and advances the change cursor """
        batch = await self._read_next_changes(limit=limit)
        if batch is None:
            return None
        records = (self.schema.pb_to_record(pb, to_dataframe) for pb in batch)
        if to_dataframe:
            return pd.DataFrame(records, columns=self._top_level_columns)
        return records
    async def read_all(
        self,
        max_records=None,
        max_bytes=config.DEFAULT_READ_ALL_MAX_BYTES,
        batch_size=config.DEFAULT_READ_BATCH_SIZE,
        warn_max=True,
        to_dataframe=False,
    ) -> Iterable[Mapping]:
        """
        Returns all records in the cursor (up to the limits of max_records and max_bytes)
        """
        # compute limits
        max_records = max_records if max_records else sys.maxsize
        max_bytes = max_bytes if max_bytes else sys.maxsize

        # loop state
        records = []
        complete = False
        bytes_loaded = 0

        # loop until all records fetched or limits reached
        while len(records) < max_records and bytes_loaded < max_bytes:
            limit = min(max_records - len(records), batch_size)
            assert limit >= 0

            batch = await self._read_next_replay(limit=limit)
            if batch is None:
                complete = True
                break

            batch_len = 0
            for pb in batch:
                record = self.schema.pb_to_record(pb=pb, to_dataframe=False)
                records.append(record)
                batch_len += 1
                bytes_loaded += len(pb.avro_data)

        if not complete and warn_max:
            # Jupyter doesn't always display warnings, so also print
            if len(records) >= max_records:
                err = f"Stopped loading because result exceeded max_records={max_records}"
                print(err)
                warnings.warn(err)
            elif bytes_loaded >= max_bytes:
                err = f"Stopped loading because download size exceeded max_bytes={max_bytes}"
                print(err)
                warnings.warn(err)

        if to_dataframe:
            return pd.DataFrame(records, columns=self._top_level_columns)

        return records
    async def _read_next_replay(self, limit: int):
        if not self.replay_cursor:
            return None
        resp = await self.connection.read(
            cursor=self.replay_cursor,
            limit=limit,
        )
        self.replay_cursor = resp.next_cursor
        return resp.records

    async def _read_next_changes(self, limit: int):
        if not self.changes_cursor:
            return None
        resp = await self.connection.read(
            cursor=self.changes_cursor,
            limit=limit,
        )
        self.changes_cursor = resp.next_cursor
        return resp.records
    async def subscribe_changes(
        self,
        batch_size=config.DEFAULT_READ_BATCH_SIZE,
        poll_at_most_every_ms=config.DEFAULT_SUBSCRIBE_POLL_AT_MOST_EVERY_MS,
    ) -> AsyncIterator[List[Mapping]]:
        """
        Similar to subscribe_changes_with_callback, but as an async iterator.
        Note that yielded values are batches, not individual records.
        """
        queue = asyncio.Queue()
        done = Exception("DONE")

        async def callback(records, _):
            await queue.put(records)

        def done_callback(task: asyncio.Task):
            if task.exception():
                queue.put_nowait(task.exception())
            else:
                queue.put_nowait(done)

        coro = self.subscribe_changes_with_callback(
            callback,
            batch_size=batch_size,
            poll_at_most_every_ms=poll_at_most_every_ms,
        )
        task = asyncio.create_task(coro)
        task.add_done_callback(done_callback)

        while True:
            item = await queue.get()
            if isinstance(item, Exception):
                if item == done:
                    return
                raise item
            yield item
            queue.task_done()
    async def subscribe_changes_with_callback(
        self,
        callback: Callable[[List[Mapping], Cursor], Awaitable[None]],
        batch_size=config.DEFAULT_READ_BATCH_SIZE,
        poll_at_most_every_ms=config.DEFAULT_SUBSCRIBE_POLL_AT_MOST_EVERY_MS,
    ):
        """
        Subscribes to new changes and calls ``callback`` with new batches of records.
        """
        ticker = AIOTicker(
            at_least_every_ms=config.DEFAULT_SUBSCRIBE_POLL_AT_LEAST_EVERY_MS,
            at_most_every_ms=poll_at_most_every_ms,
        )

        async def _poll():
            while True:
                batch = await self.read_next_changes(limit=batch_size)
                batch = list(batch)
                if len(batch) != 0:
                    await callback(batch, self)
                if len(batch) < batch_size:
                    break

        async def _tick():
            # poll on startup
            await _poll()
            async for _ in ticker:
                await _poll()

        async def _subscribe():
            subscription = self.connection.subscribe(
                cursor=self.changes_cursor,
            )
            async for _ in subscription:
                ticker.trigger()

        await asyncio.gather(_subscribe(), _tick())
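
The sketch below illustrates how a Cursor might be consumed once obtained from the Beneath client; the call that produces the cursor lives outside this module, so the `get_cursor_somehow()` helper is a placeholder, not part of the library API.

    # Illustrative usage sketch (not part of beneath.cursor).
    async def example():
        cursor: Cursor = await get_cursor_somehow()  # hypothetical: returned by a query elsewhere in the client

        # Page through the replay results manually; read_next returns None once the
        # replay cursor is exhausted.
        while True:
            batch = await cursor.read_next(limit=500)
            if batch is None:
                break
            for record in batch:
                print(record)

        # Or load everything (subject to max_records / max_bytes) into a pandas DataFrame.
        df = await cursor.read_all(to_dataframe=True)

        # Consume subsequent changes as they arrive; each yielded item is a batch of records.
        async for change_batch in cursor.subscribe_changes():
            for record in change_batch:
                print(record)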