Source code for beneath.client

from collections.abc import Mapping
from datetime import timedelta
import json
import logging
import os
import sys
from typing import Dict, Iterable, List, Union

import pandas as pd

from beneath import config
from beneath.admin.client import AdminClient
from beneath.checkpointer import Checkpointer, PrefixedCheckpointer
from beneath.config import DEFAULT_READ_BATCH_SIZE
from beneath.connection import Connection, GraphQLError
from beneath.consumer import Consumer
from beneath.instance import TableInstance
from beneath.job import Job
from beneath.table import Table
from beneath.utils import infer_avro, ProjectIdentifier, TableIdentifier, SubscriptionIdentifier
from beneath.writer import DryWriter, Writer


[docs]class Client: """ The main class for interacting with Beneath. Data-related features (like defining tables and reading/writing data) are implemented directly on `Client`, while control-plane features (like creating projects) are isolated in the `admin` member. Args: secret (str): A beneath secret to use for authentication. If not set, uses the ``BENEATH_SECRET`` environment variable, and if that is not set either, uses the secret authenticated in the CLI (stored in ``~/.beneath``). dry (bool): If true, the client will not perform any mutations or writes, but generally perform reads as usual. It's useful for testing. The exact implication differs for different operations: Some mutations will be mocked, such as creating a table, others will fail with an exception. Write operations log records to the logger instead of transmitting to the server. Reads generally work, but throw an exception when called on mocked resources. write_delay_ms (int): The maximum amount of time to buffer written records before sending a batch write request over the network. Defaults to 1 second (1000 ms). Writing records in batches reduces the number of requests, which leads to lower cost (Beneath charges at least 1kb per request). """ def __init__( self, secret: str = None, dry: bool = False, write_delay_ms: int = config.DEFAULT_WRITE_DELAY_MS, ): self.connection = Connection(secret=self._get_secret(secret=secret)) self.admin = AdminClient(connection=self.connection, dry=dry) self.dry = dry self.logger = self._make_default_logger() if dry: self._writer = DryWriter(client=self, max_delay_ms=write_delay_ms) else: self._writer = Writer(client=self, max_delay_ms=write_delay_ms) self._start_count = 0 self._checkpointers: Dict[ProjectIdentifier, Checkpointer] = {} self._consumers: Dict[SubscriptionIdentifier, Consumer] = {} @classmethod def _get_secret(cls, secret=None): if not secret: secret = os.getenv("BENEATH_SECRET", default=None) if not secret: secret = config.read_secret() if not secret: raise ValueError( "You are not authenticated (either run 'beneath auth' in the CLI, set the " "BENEATH_SECRET environment variable, or pass a secret to the Client constructor)" ) if not isinstance(secret, str): raise TypeError("secret must be a string") return secret.strip() @classmethod def _make_default_logger(cls): h1 = logging.StreamHandler(sys.stdout) h1.setLevel(logging.INFO) h1.addFilter(lambda record: record.levelno <= logging.INFO) h2 = logging.StreamHandler(sys.stderr) h2.setLevel(logging.WARNING) logging.basicConfig(handlers=[h1, h2]) logger = logging.getLogger("beneath") logger.setLevel(logging.INFO) return logger # FINDING AND STAGING TABLES
[docs] async def find_table(self, table_path: str) -> Table: """ Finds an existing table and returns an object that you can use to read and write from/to the table. Args: path (str): The path to the table in the format of "USERNAME/PROJECT/TABLE" """ identifier = TableIdentifier.from_path(table_path) table = await Table._make(client=self, identifier=identifier) return table
[docs] async def create_table( self, table_path: str, schema: str, description: str = None, meta: bool = None, use_index: bool = None, use_warehouse: bool = None, retention: timedelta = None, log_retention: timedelta = None, index_retention: timedelta = None, warehouse_retention: timedelta = None, schema_kind: str = "GraphQL", indexes: str = None, update_if_exists: bool = None, ) -> Table: """ Creates (or optionally updates if ``update_if_exists=True``) a table and returns it. Args: table_path (str): The (desired) path to the table in the format of "USERNAME/PROJECT/TABLE". The project must already exist. If the table doesn't exist yet, it creates it. schema (str): The GraphQL schema for the table. To learn about the schema definition language, see https://about.beneath.dev/docs/reading-writing-data/schema-definition/. description (str): The description shown for the table in the web console. If not set, tries to infer a description from the schema. retention (timedelta): The amount of time to retain records written to the table. If not set, records will be stored forever. schema_kind (str): The parser to use for ``schema``. Currently must be "GraphQL" (default). update_if_exists (bool): If true and the table already exists, the provided info will be used to update the table (only supports non-breaking schema changes) before returning it. """ identifier = TableIdentifier.from_path(table_path) if self.dry: data = await self.admin.tables.compile_schema( schema_kind=schema_kind, schema=schema, ) table = await Table._make_dry( client=self, identifier=identifier, avro_schema=data["canonicalAvroSchema"], ) else: data = await self.admin.tables.create( organization_name=identifier.organization, project_name=identifier.project, table_name=identifier.table, schema_kind=schema_kind, schema=schema, indexes=indexes, description=description, meta=meta, use_index=use_index, use_warehouse=use_warehouse, log_retention_seconds=self._timedelta_to_seconds( log_retention if log_retention else retention ), index_retention_seconds=self._timedelta_to_seconds( index_retention if index_retention else retention ), warehouse_retention_seconds=self._timedelta_to_seconds( warehouse_retention if warehouse_retention else retention ), update_if_exists=update_if_exists, ) table = await Table._make(client=self, identifier=identifier, admin_data=data) return table
@staticmethod def _timedelta_to_seconds(td: timedelta): if not td: return None secs = int(td.total_seconds()) if secs == 0: return None return secs # WAREHOUSE QUERY
[docs] async def query_warehouse( self, query: str, analyze: bool = False, max_bytes_scanned: int = config.DEFAULT_QUERY_WAREHOUSE_MAX_BYTES_SCANNED, timeout_ms: int = config.DEFAULT_QUERY_WAREHOUSE_TIMEOUT_MS, ): """ Starts a warehouse (OLAP) SQL query, and returns a job for tracking its progress Args: query (str): The analytical SQL query to run. To learn about the query language, see https://about.beneath.dev/docs/reading-writing-data/warehouse-queries/. analyze (bool): If true, analyzes the query and returns info about referenced tables and expected bytes scanned, but doesn't actually run the query. max_bytes_scanned (int): Sets a limit on the number of bytes the query can scan. If exceeded, the job will fail with an error. """ resp = await self.connection.query_warehouse( query=query, dry_run=analyze, max_bytes_scanned=max_bytes_scanned, timeout_ms=timeout_ms, ) job_data = resp.job return Job(client=self, job_id=job_data.job_id, job_data=job_data)
# WRITING
[docs] async def start(self): """ Opens the client for writes. Can be called multiple times, but make sure to call ``stop`` correspondingly. """ self._start_count += 1 if self._start_count != 1: return await self.connection.ensure_connected() await self._writer.start() for checkpointer in self._checkpointers.values(): await checkpointer._start()
[docs] async def stop(self): """ Closes the client for writes, ensuring buffered writes are flushed. If ``start`` was called multiple times, only the last corresponding call to ``stop`` triggers a flush. """ if self._start_count == 0: raise Exception("Called stop more times than start") if self._start_count == 1: for checkpointer in self._checkpointers.values(): await checkpointer._stop() await self._writer.stop() self._start_count -= 1
[docs] async def write(self, instance: TableInstance, records: Union[Mapping, Iterable[Mapping]]): """ Writes one or more records to ``instance``. By default, writes are buffered for up to ``write_delay_ms`` milliseconds before being transmitted to the server. See the Client constructor for details. To enabled writes, make sure to call ``start`` on the client (and ``stop`` before terminating). Args: instance (TableInstance): The instance to write to. You can also call ``instance.write`` as a convenience wrapper. records: The records to write. Can be a single record (dict) or a list of records (iterable of dict). """ if self._start_count == 0: raise Exception("Cannot call write because the client is stopped") await self._writer.write(instance, records)
[docs] async def force_flush(self): """Forces the client to flush buffered writes without stopping""" await self._writer.force_flush()
[docs] async def write_full( self, table_path: str, records: Union[Iterable[dict], pd.DataFrame], key: Union[str, List[str]] = None, description: str = None, recreate_on_schema_change=False, ): """ Infers a schema, creates a table, and writes a full dataset to Beneath. Each call will create a new primary version for the table, and delete the old primary version if/when the write completes succesfully. Args: table_path (str): The (desired) path to the table in the format of "USERNAME/PROJECT/TABLE". The project must already exist. records (list(dict) | pandas.DataFrame): The full dataset to write, either as a list of records or as a Pandas DataFrame. This function uses ``beneath.infer_avro`` to infer a schema for the table based on the records. key (str | list(str)): The fields to use as the table's key. If not set, will default to the dataframe index if ``records`` is a Pandas DataFrame, or add a column of incrementing numbers if ``records`` is a list. description (str): A description for the table. recreate_on_schema_change (bool): If true, and there's an existing table at ``table_path`` with a schema that is incompatible with the inferred schema for ``records``, it will delete the existing table and create a new one instead of throwing an error. Defaults to false. """ if self._start_count == 0: raise Exception("Cannot call write_full because the client is stopped") if len(records) == 0: return # hairy defaults for `key` if isinstance(key, str): key = [key] if not key: if isinstance(records, pd.DataFrame): index_named = None for name in records.index.names: if index_named is None: index_named = name is not None elif index_named == (name is None): raise ValueError( "Cannot write DataFrame with mixed null and non-null" f" index names {records.index.names}" ) if index_named: key = records.index.names records = records.reset_index() else: if len(records.index.names) > 1: key = [f"key{idx}" for idx, name in enumerate(records.index.names)] else: key = ["key"] records.index.names = key records = records.reset_index() else: for idx, record in enumerate(records): if "key" not in record: record["key"] = idx key = ["key"] # infer schema and indexes schema = infer_avro(records) indexes = json.dumps([{"key": True, "fields": key}]) # create the table try: table = await self.create_table( table_path=table_path, schema=schema, description=description, schema_kind="Avro", indexes=indexes, update_if_exists=True, ) except GraphQLError as e: if "Schema error:" not in str(e): raise if not recreate_on_schema_change: raise ValueError( "Cannot create table because an existing table *with a different schema* " f"already exists at '{table_path}'. To delete the existing table and all its" " versions and records, pass recreate_on_schema_change=True." ) # recreate table = await self.find_table(table_path) await self.admin.tables.delete(table_id=str(table.table_id)) table = await self.create_table( table_path=table_path, schema=schema, description=description, schema_kind="Avro", indexes=indexes, update_if_exists=True, ) # get instance next_instance = None previous_instance = None if table.primary_instance is None: next_instance = await table.create_instance() elif table.primary_instance.is_final: next_instance = await table.create_instance() previous_instance = table.primary_instance else: next_instance = table.primary_instance # write records if isinstance(records, pd.DataFrame): records = records.to_dict(orient="records") await next_instance.write(records) # make primary and final, and delete previous instance await next_instance.update(make_primary=True, make_final=True) if previous_instance: await previous_instance.delete() return table
# CHECKPOINTERS
[docs] async def checkpointer( self, project_path: str, key_prefix: str = None, metatable_name="checkpoints", metatable_create: bool = True, metatable_description="Stores checkpointed state for consumers, pipelines, and more", ) -> Checkpointer: """ Returns a checkpointer for the given project. Checkpointers store (small) key-value records useful for maintaining consumer and pipeline state. State is stored in a meta-table called "checkpoints" in the given project. Args: project_path (str): Path to the project in which to store the checkpointer's state key_prefix (str): If set, any ``get`` or ``set`` call on the checkpointer will prepend the prefix to the key. metatable_name (str): Name of the meta table in which to save checkpointed data metatable_create (bool): If true, the checkpointer will create the checkpoints meta-table if it does not already exists. If false, the checkpointer will throw an exception if the meta-table does not already exist. Defaults to True. metatable_description (str): An optional description to apply to the checkpoints meta-table. Defaults to a sensible description of checkpointing. """ project_identifier = ProjectIdentifier.from_path(project_path) identifier = TableIdentifier( organization=project_identifier.organization, project=project_identifier.project, table=metatable_name, ) if identifier not in self._checkpointers: checkpointer = Checkpointer( client=self, metatable_identifier=identifier, metatable_create=metatable_create, metatable_description=metatable_description, ) self._checkpointers[identifier] = checkpointer if self._start_count != 0: await checkpointer._start() checkpointer = self._checkpointers[identifier] if key_prefix: checkpointer = PrefixedCheckpointer(checkpointer, key_prefix) return checkpointer
# CONSUMERS
[docs] async def consumer( self, table_path: str, version: int = None, batch_size: int = DEFAULT_READ_BATCH_SIZE, subscription_path: str = None, checkpointer: Checkpointer = None, metatable_create: bool = True, ): """ Creates a consumer for the given table. Consumers make it easy to replay the history of a table and/or subscribe to new changes. Args: table_path (str): Path to the table to subscribe to. The consumer will subscribe to the table's primary version. version (int): The instance version to use for table. If not set, uses the primary instance. batch_size (int): Sets the max number of records to load in each network request. Defaults to 1000. subscription_path (str): Format "ORGANIZATION/PROJECT/NAME". If set, the consumer will use a checkpointer to save cursors. That means processing will not restart from scratch if the process ends or crashes (as long as you use the same subscription name). To reset a subscription, call ``reset`` on the consumer. checkpointer (Checkpointer): Only applies if ``subscription_path`` is set. Provides a specific checkpointer to use for consumer state. If not set, will create one in the subscription's project. metatable_create (bool): Only applies if ``subscription_path`` is set and ``checkpointer`` is not set. Passed through to ``client.checkpointer``. """ table_identifier = TableIdentifier.from_path(table_path) if not subscription_path: consumer = Consumer( client=self, table_identifier=table_identifier, batch_size=batch_size, ) await consumer._init() return consumer sub_identifier = SubscriptionIdentifier.from_path(subscription_path) if sub_identifier not in self._consumers: if checkpointer is None: checkpointer = await self.checkpointer( f"{sub_identifier.organization}/{sub_identifier.project}", metatable_create=metatable_create, ) consumer = Consumer( client=self, table_identifier=table_identifier, version=version, batch_size=batch_size, checkpointer=checkpointer, subscription_name=sub_identifier.subscription, ) await consumer._init() self._consumers[sub_identifier] = consumer return self._consumers[sub_identifier]