# allows us to use Client as a type hint without an import cycle
# see: https://www.stefaanlippens.net/circular-imports-type-hints-python.html
# pylint: disable=wrong-import-position,ungrouped-imports
from __future__ import annotations
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from beneath.client import Client
from datetime import timedelta
import json
import sys
from typing import Any, Dict, Tuple
import msgpack
from beneath.config import DEFAULT_CHECKPOINT_COMMIT_DELAY_MS
from beneath.instance import TableInstance
from beneath.utils import AIODelayBuffer, TableIdentifier
# Passed as `log_retention` when the checkpoint meta table is created (six hours).
SERVICE_CHECKPOINT_LOG_RETENTION = timedelta(hours=6)
# Schema used when creating the checkpoint meta table: one record per checkpoint, keyed by
# a required string `key`, with a nullable bytes `value` (the checkpoint writer stores the
# msgpack-encoded checkpoint value there).
SERVICE_CHECKPOINT_SCHEMA = """
type Checkpoint @schema {
key: String! @key
value: Bytes
}
"""
class Checkpointer:
    """
    Checkpointers store (small) key-value records in a meta table (in the specified project).
    They are useful for maintaining consumer and pipeline state, such as the cursor for a
    subscription or the last time a scraper ran.

    Checkpoint keys are strings and values are serialized with msgpack (supports most normal Python
    values, but not custom classes).

    New checkpointed values are flushed at regular intervals (every 30 seconds by default).

    Checkpointers should not be used for storing large amounts of data.
    Checkpointers are not currently suitable for synchronizing parallel processes.
    """

    instance: TableInstance
    """
    The meta-table instance that checkpoints are written to
    (``None`` until the table has been staged by ``_start``)
    """

    def __init__(
        self,
        client: Client,
        metatable_identifier: TableIdentifier,
        metatable_create: bool,
        metatable_description: str,
    ):
        """
        Args:
            client: The client used to find/create the meta table, write records and log.
            metatable_identifier: Identifier of the meta table; its string form is used
                as the table path.
            metatable_create: If true, the meta table is created (or updated if it exists)
                on start; otherwise an existing table is looked up.
            metatable_description: Description passed along when the meta table is created.
        """
        self.instance = None
        self._client = client
        self._metatable_identifier = metatable_identifier
        self._metatable_description = metatable_description
        self._create = metatable_create
        self._writer = self._Writer(self)
        # local view of checkpoints that have been read or set through this object
        self._cache: Dict[str, Any] = {}

    async def get(self, key: str, default: Any = None) -> Any:
        """
        Gets a checkpointed value

        Cached values (from an earlier ``get`` or ``set``) are returned without querying the
        meta table. If the client is in dry mode, ``default`` is returned without querying
        or caching. Otherwise the value is looked up in the meta table and the result
        (``default`` if nothing was found) is cached under ``key``.

        Args:
            key: The checkpoint key to look up.
            default: The value to return if no value is stored for ``key``.

        Raises:
            Exception: If the value has to be fetched from the meta table, but the table
                has not been staged yet (i.e. the checkpointer was never started).
        """
        if key in self._cache:
            return self._cache[key]

        if self._client.dry:
            return default

        # mirror the explicit guard in 'set' instead of failing with an opaque AttributeError
        if self.instance is None:
            raise Exception("Cannot call 'get' on checkpointer before the client is started")

        filt = json.dumps({"key": key})
        cursor = await self.instance.query_index(filter=filt)

        value = default
        record = await cursor.read_one()
        if record is not None:
            packed = record["value"]
            # 'value' is nullable in the schema; a null value is treated like a missing record
            if packed is not None:
                # timestamp=3 decodes msgpack timestamps into datetime objects, which is the
                # counterpart of packing with datetime=True in the writer
                value = msgpack.unpackb(packed, timestamp=3)

        # checking self._cache again because of awaits (and we'd rather serve a recent local set)
        if key not in self._cache:
            self._cache[key] = value

        return self._cache[key]

    async def set(self, key: str, value: Any):
        """
        Sets a checkpoint value. Value will be encoded with msgpack.

        The value is cached locally right away and handed to the buffered writer, which
        persists it on a later flush.

        Args:
            key: The checkpoint key to set.
            value: The value to store under ``key``.

        Raises:
            Exception: If the writer is not running.
        """
        if not self._writer.running:
            raise Exception("Cannot call 'set' on checkpointer because the client is stopped")

        self._cache[key] = value
        await self._writer.write(key, value)

    # START/STOP (called by client)

    async def _start(self):
        """ Stages the meta table (unless an instance is already set) and starts the writer """
        if not self.instance:
            await self._stage_table()
        await self._writer.start()

    async def _stop(self):
        """ Stops the writer """
        await self._writer.stop()

    # CHECKPOINT TABLE

    async def _stage_table(self):
        """
        Resolves the meta table and stores its primary instance in ``self.instance``.

        The table is created (or updated if it exists) when creation was requested or the
        client is in dry mode; otherwise it is looked up by path.

        Raises:
            Exception: If the resolved table has no primary instance.
        """
        if self._create or self._client.dry:
            table = await self._client.create_table(
                table_path=str(self._metatable_identifier),
                schema=SERVICE_CHECKPOINT_SCHEMA,
                description=self._metatable_description,
                meta=True,
                log_retention=SERVICE_CHECKPOINT_LOG_RETENTION,
                use_warehouse=False,
                update_if_exists=True,
            )
        else:
            table = await self._client.find_table(table_path=str(self._metatable_identifier))

        if not table.primary_instance:
            raise Exception("Expected checkpoints table to have a primary instance")
        self.instance = table.primary_instance

        self._client.logger.info(
            "Using '%s' (version %i) for checkpointing",
            self._metatable_identifier,
            self.instance.version,
        )

    # CHECKPOINT WRITER

    class _Writer(AIODelayBuffer[Tuple[str, Any]]):
        """
        Delay buffer that collects ``(key, value)`` checkpoints and writes them in batches.

        NOTE(review): ``_reset``, ``_merge`` and ``_flush`` look like the hooks that
        ``AIODelayBuffer`` invokes to manage the buffer; confirm against its definition.
        """

        checkpoint: Dict[str, Any]
        """ Buffered checkpoints by key (only the latest value per key is kept) """

        def __init__(self, checkpointer: Checkpointer):
            # only the delay should trigger flushes, so all size/count limits are
            # set to the largest possible value
            super().__init__(
                max_delay_ms=DEFAULT_CHECKPOINT_COMMIT_DELAY_MS,
                max_record_size=sys.maxsize,
                max_buffer_size=sys.maxsize,
                max_buffer_count=sys.maxsize,
            )
            self.checkpointer = checkpointer

        def _reset(self):
            """ Empties the buffer of pending checkpoints """
            self.checkpoint = {}

        def _merge(self, value: Tuple[str, Any]):
            """ Adds a ``(key, value)`` pair to the buffer, replacing any pending value for key """
            (key, checkpoint) = value
            self.checkpoint[key] = checkpoint

        async def _flush(self):
            """
            Writes every buffered checkpoint to the meta table as a msgpack-encoded record.
            Nothing is written if the checkpointer has no table instance.
            """
            records = (
                {
                    "key": key,
                    # datetime=True lets msgpack encode datetime values as timestamps
                    "value": msgpack.packb(checkpoint, datetime=True),
                }
                for (key, checkpoint) in self.checkpoint.items()
            )
            if self.checkpointer.instance:
                await self.checkpointer._client.write(
                    instance=self.checkpointer.instance,
                    records=records,
                )

        # pylint: disable=arguments-differ
        async def write(self, key: str, checkpoint: Any):
            """ Buffers a checkpoint; it is passed on with size 0, so sizes never limit the buffer """
            await super().write(value=(key, checkpoint), size=0)
class PrefixedCheckpointer:
    """
    Thin adapter around a Checkpointer that namespaces its keys.

    Every key given to ``get`` or ``set`` has the configured prefix prepended before the
    call is forwarded to the wrapped checkpointer.
    """

    def __init__(self, checkpointer: Checkpointer, prefix: str):
        self._prefix = prefix
        self._checkpointer = checkpointer

    def _qualify(self, key: str) -> str:
        """ Returns ``key`` with this wrapper's prefix prepended """
        return self._prefix + key

    @property
    def instance(self):
        """ The ``instance`` attribute of the wrapped checkpointer """
        return self._checkpointer.instance

    async def get(self, key: str, default: Any = None) -> Any:
        """ Forwards to the wrapped checkpointer's ``get`` using the prefixed key """
        full_key = self._qualify(key)
        return await self._checkpointer.get(full_key, default=default)

    async def set(self, key: str, value: Any):
        """ Forwards to the wrapped checkpointer's ``set`` using the prefixed key """
        full_key = self._qualify(key)
        return await self._checkpointer.set(full_key, value)