| import datetime |
| import typing as _t |
|
|
| from cachelib.base import BaseCache |
| from cachelib.serializers import DynamoDbSerializer |
|
|
| CREATED_AT_FIELD = "created_at" |
| RESPONSE_FIELD = "response" |
|
|
|
|
| class DynamoDbCache(BaseCache): |
| """ |
| Implementation of cachelib.BaseCache that uses an AWS DynamoDb table |
| as the backend. |
| |
| Your server process will require dynamodb:GetItem and dynamodb:PutItem |
| IAM permissions on the cache table. |
| |
| Limitations: DynamoDB table items are limited to 400 KB in size. Since |
| this class stores cached items in a table, the max size of a cache entry |
| will be slightly less than 400 KB, since the cache key and expiration |
| time fields are also part of the item. |
| |
| :param table_name: The name of the DynamoDB table to use |
| :param default_timeout: Set the timeout in seconds after which cache entries |
| expire |
| :param key_field: The name of the hash_key attribute in the DynamoDb |
| table. This must be a string attribute. |
| :param expiration_time_field: The name of the table attribute to store the |
| expiration time in. This will be an int |
| attribute. The timestamp will be stored as |
| seconds past the epoch. If you configure |
| this as the TTL field, then DynamoDB will |
| automatically delete expired entries. |
| :param key_prefix: A prefix that should be added to all keys. |
| |
| """ |
|
|
| serializer = DynamoDbSerializer() |
|
|
| def __init__( |
| self, |
| table_name: _t.Optional[str] = "python-cache", |
| default_timeout: int = 300, |
| key_field: _t.Optional[str] = "cache_key", |
| expiration_time_field: _t.Optional[str] = "expiration_time", |
| key_prefix: _t.Optional[str] = None, |
| **kwargs: _t.Any |
| ): |
| super().__init__(default_timeout) |
|
|
| try: |
| import boto3 |
| except ImportError as err: |
| raise RuntimeError("no boto3 module found") from err |
|
|
| self._table_name = table_name |
| self._key_field = key_field |
| self._expiration_time_field = expiration_time_field |
| self.key_prefix = key_prefix or "" |
| self._dynamo = boto3.resource("dynamodb", **kwargs) |
| self._attr = boto3.dynamodb.conditions.Attr |
|
|
| try: |
| self._table = self._dynamo.Table(table_name) |
| self._table.load() |
| |
| except Exception: |
| table = self._dynamo.create_table( |
| AttributeDefinitions=[ |
| {"AttributeName": key_field, "AttributeType": "S"} |
| ], |
| TableName=table_name, |
| KeySchema=[ |
| {"AttributeName": key_field, "KeyType": "HASH"}, |
| ], |
| BillingMode="PAY_PER_REQUEST", |
| ) |
| table.wait_until_exists() |
| dynamo = boto3.client("dynamodb", **kwargs) |
| dynamo.update_time_to_live( |
| TableName=table_name, |
| TimeToLiveSpecification={ |
| "Enabled": True, |
| "AttributeName": expiration_time_field, |
| }, |
| ) |
| self._table = self._dynamo.Table(table_name) |
| self._table.load() |
|
|
| def _utcnow(self) -> _t.Any: |
| """Return a tz-aware UTC datetime representing the current time""" |
| return datetime.datetime.utcnow().replace(tzinfo=datetime.timezone.utc) |
|
|
| def _get_item(self, key: str, attributes: _t.Optional[list] = None) -> _t.Any: |
| """ |
| Get an item from the cache table, optionally limiting the returned |
| attributes. |
| |
| :param key: The cache key of the item to fetch |
| |
| :param attributes: An optional list of attributes to fetch. If not |
| given, all attributes are fetched. The |
| expiration_time field will always be added to the |
| list of fetched attributes. |
| :return: The table item for key if it exists and is not expired, else |
| None |
| """ |
| kwargs = {} |
| if attributes: |
| if self._expiration_time_field not in attributes: |
| attributes = list(attributes) + [self._expiration_time_field] |
| kwargs = dict(ProjectionExpression=",".join(attributes)) |
|
|
| response = self._table.get_item(Key={self._key_field: key}, **kwargs) |
| cache_item = response.get("Item") |
|
|
| if cache_item: |
| now = int(self._utcnow().timestamp()) |
| if cache_item.get(self._expiration_time_field, now + 100) > now: |
| return cache_item |
|
|
| return None |
|
|
| def get(self, key: str) -> _t.Any: |
| """ |
| Get a cache item |
| |
| :param key: The cache key of the item to fetch |
| :return: cache value if not expired, else None |
| """ |
| cache_item = self._get_item(self.key_prefix + key) |
| if cache_item: |
| response = cache_item[RESPONSE_FIELD] |
| value = self.serializer.loads(response) |
| return value |
| return None |
|
|
| def delete(self, key: str) -> bool: |
| """ |
| Deletes an item from the cache. This is a no-op if the item doesn't |
| exist |
| |
| :param key: Key of the item to delete. |
| :return: True if the key existed and was deleted |
| """ |
| try: |
|
|
| self._table.delete_item( |
| Key={self._key_field: self.key_prefix + key}, |
| ConditionExpression=self._attr(self._key_field).exists(), |
| ) |
| return True |
| except self._dynamo.meta.client.exceptions.ConditionalCheckFailedException: |
| return False |
|
|
| def _set( |
| self, |
| key: str, |
| value: _t.Any, |
| timeout: _t.Optional[int] = None, |
| overwrite: _t.Optional[bool] = True, |
| ) -> _t.Any: |
| """ |
| Store a cache item, with the option to not overwrite existing items |
| |
| :param key: Cache key to use |
| :param value: a serializable object |
| :param timeout: The timeout in seconds for the cached item, to override |
| the default |
| :param overwrite: If true, overwrite any existing cache item with key. |
| If false, the new value will only be stored if no |
| non-expired cache item exists with key. |
| :return: True if the new item was stored. |
| """ |
| timeout = self._normalize_timeout(timeout) |
| now = self._utcnow() |
|
|
| kwargs = {} |
| if not overwrite: |
| |
| |
|
|
| cond = self._attr(self._key_field).not_exists() | self._attr( |
| self._expiration_time_field |
| ).lte(int(now.timestamp())) |
| kwargs = dict(ConditionExpression=cond) |
|
|
| try: |
| dump = self.serializer.dumps(value) |
| item = { |
| self._key_field: key, |
| CREATED_AT_FIELD: now.isoformat(), |
| RESPONSE_FIELD: dump, |
| } |
| if timeout > 0: |
| expiration_time = now + datetime.timedelta(seconds=timeout) |
| item[self._expiration_time_field] = int(expiration_time.timestamp()) |
| self._table.put_item(Item=item, **kwargs) |
| return True |
| except Exception: |
| return False |
|
|
| def set(self, key: str, value: _t.Any, timeout: _t.Optional[int] = None) -> _t.Any: |
| return self._set(self.key_prefix + key, value, timeout=timeout, overwrite=True) |
|
|
| def add(self, key: str, value: _t.Any, timeout: _t.Optional[int] = None) -> _t.Any: |
| return self._set(self.key_prefix + key, value, timeout=timeout, overwrite=False) |
|
|
| def has(self, key: str) -> bool: |
| return ( |
| self._get_item(self.key_prefix + key, [self._expiration_time_field]) |
| is not None |
| ) |
|
|
| def clear(self) -> bool: |
| paginator = self._dynamo.meta.client.get_paginator("scan") |
| pages = paginator.paginate( |
| TableName=self._table_name, ProjectionExpression=self._key_field |
| ) |
|
|
| with self._table.batch_writer() as batch: |
| for page in pages: |
| for item in page["Items"]: |
| batch.delete_item(Key=item) |
|
|
| return True |
|
|