Plain is headed towards 1.0! Subscribe for development updates →

  1from __future__ import annotations
  2
  3from datetime import datetime, timedelta
  4from functools import cached_property
  5from typing import Any
  6
  7from opentelemetry import trace
  8from opentelemetry.semconv.attributes.db_attributes import (
  9    DB_NAMESPACE,
 10    DB_OPERATION_NAME,
 11    DB_SYSTEM_NAME,
 12)
 13from opentelemetry.trace import SpanKind
 14
 15from plain.models import IntegrityError
 16from plain.utils import timezone
 17
# OpenTelemetry tracer; each Cached operation opens its CLIENT span through it.
tracer = trace.get_tracer("plain.cache")
 19
 20
 21class Cached:
 22    """Store and retrieve cached items."""
 23
 24    def __init__(self, key: str) -> None:
 25        self.key = key
 26
 27        # So we can import Cached in __init__.py
 28        # without getting the packages not ready error...
 29        from .models import CachedItem
 30
 31        self._model_class = CachedItem
 32
 33    @cached_property
 34    def _model_instance(self) -> Any:
 35        try:
 36            return self._model_class.query.get(key=self.key)
 37        except self._model_class.DoesNotExist:
 38            return None
 39
 40    def reload(self) -> None:
 41        if hasattr(self, "_model_instance"):
 42            del self._model_instance
 43
 44    def _is_expired(self) -> bool:
 45        if not self._model_instance:
 46            return True
 47
 48        if not self._model_instance.expires_at:
 49            return False
 50
 51        return self._model_instance.expires_at < timezone.now()
 52
 53    def exists(self) -> bool:
 54        with tracer.start_as_current_span(
 55            "cache.exists",
 56            kind=SpanKind.CLIENT,
 57            attributes={
 58                DB_SYSTEM_NAME: "plain.cache",
 59                DB_OPERATION_NAME: "get",
 60                DB_NAMESPACE: "cache",
 61                "cache.key": self.key,
 62            },
 63        ) as span:
 64            span.set_status(trace.StatusCode.OK)
 65
 66            if self._model_instance is None:
 67                return False
 68
 69            return not self._is_expired()
 70
 71    @property
 72    def value(self) -> Any:
 73        with tracer.start_as_current_span(
 74            "cache.get",
 75            kind=SpanKind.CLIENT,
 76            attributes={
 77                DB_SYSTEM_NAME: "plain.cache",
 78                DB_OPERATION_NAME: "get",
 79                DB_NAMESPACE: "cache",
 80                "cache.key": self.key,
 81            },
 82        ) as span:
 83            if self._model_instance and self._model_instance.expires_at:
 84                span.set_attribute(
 85                    "cache.item.expires_at", self._model_instance.expires_at.isoformat()
 86                )
 87
 88            exists = self.exists()
 89
 90            span.set_attribute("cache.hit", exists)
 91            span.set_status(trace.StatusCode.OK if exists else trace.StatusCode.UNSET)
 92
 93            if not exists:
 94                return None
 95
 96            return self._model_instance.value
 97
 98    def set(
 99        self, value: Any, expiration: datetime | timedelta | int | float | None = None
100    ) -> Any:
101        defaults = {
102            "value": value,
103        }
104
105        if isinstance(expiration, int | float):
106            defaults["expires_at"] = timezone.now() + timedelta(seconds=expiration)  # type: ignore[arg-type]
107        elif isinstance(expiration, timedelta):
108            defaults["expires_at"] = timezone.now() + expiration
109        elif isinstance(expiration, datetime):
110            defaults["expires_at"] = expiration
111        else:
112            # Keep existing expires_at value or None
113            pass
114
115        # Make sure expires_at is timezone aware
116        if (
117            "expires_at" in defaults
118            and defaults["expires_at"]
119            and not timezone.is_aware(defaults["expires_at"])
120        ):
121            defaults["expires_at"] = timezone.make_aware(defaults["expires_at"])
122
123        with tracer.start_as_current_span(
124            "cache.set",
125            kind=SpanKind.CLIENT,
126            attributes={
127                DB_SYSTEM_NAME: "plain.cache",
128                DB_OPERATION_NAME: "set",
129                DB_NAMESPACE: "cache",
130                "cache.key": self.key,
131            },
132        ) as span:
133            if expires_at := defaults.get("expires_at"):
134                span.set_attribute("cache.item.expires_at", expires_at.isoformat())
135
136            try:
137                item, _ = self._model_class.query.update_or_create(
138                    key=self.key, defaults=defaults
139                )
140            except IntegrityError:
141                # Most likely a race condition in creating the item,
142                # so trying again should do an update
143                item, _ = self._model_class.query.update_or_create(
144                    key=self.key, defaults=defaults
145                )
146
147            self.reload()
148            span.set_status(trace.StatusCode.OK)
149            return item.value
150
151    def delete(self) -> bool:
152        with tracer.start_as_current_span(
153            "cache.delete",
154            kind=SpanKind.CLIENT,
155            attributes={
156                DB_SYSTEM_NAME: "plain.cache",
157                DB_OPERATION_NAME: "delete",
158                DB_NAMESPACE: "cache",
159                "cache.key": self.key,
160            },
161        ) as span:
162            span.set_status(trace.StatusCode.OK)
163            if not self._model_instance:
164                # A no-op, but a return value you can use to know whether it did anything
165                return False
166
167            self._model_instance.delete()
168            self.reload()
169            return True