Source code for beneath.connection

from typing import AsyncIterator
import uuid
import warnings

import aiohttp
import grpc
from datetime import datetime, timedelta

from beneath import __version__
from beneath import config
from beneath.proto import gateway_pb2
from beneath.proto import gateway_pb2_grpc

# Mirrors server/data/grpc/server.go
MAX_RECV_MSG_SIZE = 1024 * 1024 * 50
MAX_SEND_MSG_SIZE = 1024 * 1024 * 10


[docs]class GraphQLError(Exception): """ Error returned for control-plane (GraphQL) errors """ def __init__(self, message, errors): super().__init__(message) self.errors = errors
[docs]class AuthenticationError(Exception): """ Error returned for failed authentication """
class Connection: """ Encapsulates network connectivity to Beneath """ def __init__(self, secret: str = None): self.secret = secret self.connected = False self.request_metadata = None self.channel: grpc.aio.Channel = None self.stub: gateway_pb2_grpc.GatewayStub = None self.pong: gateway_pb2.PingResponse = None def __getstate__(self): return {"secret": self.secret} def __setstate__(self, obj): self.secret = obj["secret"] # GRPC CONNECTIVITY async def ensure_connected(self, check_authenticated=True): """ Called before each network call (write, query, etc.). May also be called directly. On first run, it sets up grpc, sends a ping, checks library version and secret validity. On subsequent runs, it does nothing. """ if not self.connected: self._create_grpc_connection() pong = await self._ping() self._check_pong_status(pong) self.pong = pong self.connected = True if check_authenticated: if not self.pong.authenticated: raise AuthenticationError( "You must authenticate with 'beneath auth' or by setting BENEATH_SECRET" ) def _create_grpc_connection(self): self.request_metadata = [] if self.secret: self.request_metadata.append(("authorization", "Bearer {}".format(self.secret))) insecure = config.DEV options = [ ("grpc.max_receive_message_length", MAX_RECV_MSG_SIZE), ("grpc.max_send_message_length", MAX_SEND_MSG_SIZE), ] if insecure: self.channel = grpc.aio.insecure_channel( target=config.BENEATH_GATEWAY_HOST_GRPC, compression=grpc.Compression.Gzip, options=options, ) else: self.channel = grpc.aio.secure_channel( target=config.BENEATH_GATEWAY_HOST_GRPC, credentials=grpc.ssl_channel_credentials(), compression=grpc.Compression.Gzip, options=options, ) self.stub = gateway_pb2_grpc.GatewayStub(self.channel) @classmethod def _check_pong_status(cls, pong: gateway_pb2.PingResponse): if config.DEV: return if pong.version_status == "warning": warnings.warn( f"This version ({__version__}) of the Beneath python library will soon be " f"deprecated (recommended: {pong.recommended_version}). " "Update with 'pip install --upgrade beneath'." ) elif pong.version_status == "deprecated": raise Exception( f"This version ({__version__}) of the Beneath python library is out-of-date " f"(recommended: {pong.recommended_version}). " "Update with 'pip install --upgrade beneath' to continue." ) async def _ping(self) -> gateway_pb2.PingResponse: return await self.stub.Ping( gateway_pb2.PingRequest( client_id=config.PYTHON_CLIENT_ID, client_version=__version__, ), metadata=self.request_metadata, ) # CONTROL-PLANE async def query_control(self, query, variables, check_authenticated=True): """ Sends a GraphQL query to the control server """ await self.ensure_connected(check_authenticated=check_authenticated) for k, v in variables.items(): if isinstance(v, uuid.UUID): variables[k] = v.hex url = f"{config.BENEATH_CONTROL_HOST}/graphql" headers = {} if self.secret: headers["Authorization"] = f"Bearer {self.secret}" body = {"query": query, "variables": variables} async with aiohttp.ClientSession() as session: async with session.post(url=url, headers=headers, json=body) as response: # handles malformed queries if 400 <= response.status < 500: raise ValueError(f"{response.status} Client Error: {await response.text()}") response.raise_for_status() obj = await response.json() # handles resolver errors if "errors" in obj: first_err = obj["errors"][0] msg = f"{first_err['message']} (path: {first_err['path']})" raise GraphQLError(msg, obj["errors"]) # successful result return obj["data"] # DATA-PLANE async def write( self, instance_records: gateway_pb2.InstanceRecords, ) -> gateway_pb2.WriteResponse: await self.ensure_connected() return await self.stub.Write( gateway_pb2.WriteRequest(instance_records=instance_records), metadata=self.request_metadata, ) async def query_log(self, instance_id: uuid.UUID, peek: bool) -> gateway_pb2.QueryLogResponse: await self.ensure_connected() return await self.stub.QueryLog( gateway_pb2.QueryLogRequest( instance_id=instance_id.bytes, partitions=1, peek=peek, ), metadata=self.request_metadata, ) # pylint: disable=redefined-builtin async def query_index( self, instance_id: uuid.UUID, filter: str ) -> gateway_pb2.QueryIndexResponse: await self.ensure_connected() return await self.stub.QueryIndex( gateway_pb2.QueryIndexRequest( instance_id=instance_id.bytes, partitions=1, filter=filter, ), metadata=self.request_metadata, ) async def query_warehouse( self, query: str, dry_run: bool, max_bytes_scanned: int, timeout_ms: int ) -> gateway_pb2.QueryWarehouseResponse: await self.ensure_connected() return await self.stub.QueryWarehouse( gateway_pb2.QueryWarehouseRequest( query=query, dry_run=dry_run, max_bytes_scanned=max_bytes_scanned, timeout_ms=timeout_ms, ), metadata=self.request_metadata, ) async def poll_warehouse_job(self, job_id: bytes) -> gateway_pb2.PollWarehouseJobResponse: await self.ensure_connected() return await self.stub.PollWarehouseJob( gateway_pb2.PollWarehouseJobRequest(job_id=job_id), metadata=self.request_metadata, ) async def read(self, cursor: bytes, limit: int) -> gateway_pb2.ReadResponse: await self.ensure_connected() return await self.stub.Read( gateway_pb2.ReadRequest( cursor=cursor, limit=limit, ), metadata=self.request_metadata, ) async def subscribe(self, cursor: bytes) -> AsyncIterator[gateway_pb2.SubscribeResponse]: retry = True while retry: retry = False started = datetime.now() try: await self.ensure_connected() subscription = self.stub.Subscribe( gateway_pb2.SubscribeRequest(cursor=cursor), metadata=self.request_metadata, ) async for msg in subscription: yield msg except grpc.RpcError as e: is_cancel = e.code() in [ grpc.StatusCode.CANCELLED, grpc.StatusCode.UNAVAILABLE, ] is_not_immediate = datetime.now() - started >= timedelta(seconds=15) if is_cancel and is_not_immediate: retry = True else: raise e