Plain is headed towards 1.0! Subscribe for development updates →

  1import logging
  2
  3from plain.exceptions import SuspiciousOperation
  4from plain.models import DatabaseError, IntegrityError, router, transaction
  5from plain.sessions.backends.base import CreateError, SessionBase, UpdateError
  6from plain.utils import timezone
  7from plain.utils.functional import cached_property
  8
  9
 10class SessionStore(SessionBase):
 11    """
 12    Implement database session store.
 13    """
 14
 15    def __init__(self, session_key=None):
 16        super().__init__(session_key)
 17
 18    @classmethod
 19    def get_model_class(cls):
 20        # Avoids a circular import and allows importing SessionStore when
 21        # plain.sessions is not in INSTALLED_PACKAGES.
 22        from plain.sessions.models import Session
 23
 24        return Session
 25
 26    @cached_property
 27    def model(self):
 28        return self.get_model_class()
 29
 30    def _get_session_from_db(self):
 31        try:
 32            return self.model.objects.get(
 33                session_key=self.session_key, expire_date__gt=timezone.now()
 34            )
 35        except (self.model.DoesNotExist, SuspiciousOperation) as e:
 36            if isinstance(e, SuspiciousOperation):
 37                logger = logging.getLogger("plain.security.%s" % e.__class__.__name__)
 38                logger.warning(str(e))
 39            self._session_key = None
 40
 41    def load(self):
 42        s = self._get_session_from_db()
 43        return self.decode(s.session_data) if s else {}
 44
 45    def exists(self, session_key):
 46        return self.model.objects.filter(session_key=session_key).exists()
 47
 48    def create(self):
 49        while True:
 50            self._session_key = self._get_new_session_key()
 51            try:
 52                # Save immediately to ensure we have a unique entry in the
 53                # database.
 54                self.save(must_create=True)
 55            except CreateError:
 56                # Key wasn't unique. Try again.
 57                continue
 58            self.modified = True
 59            return
 60
 61    def create_model_instance(self, data):
 62        """
 63        Return a new instance of the session model object, which represents the
 64        current session state. Intended to be used for saving the session data
 65        to the database.
 66        """
 67        return self.model(
 68            session_key=self._get_or_create_session_key(),
 69            session_data=self.encode(data),
 70            expire_date=self.get_expiry_date(),
 71        )
 72
 73    def save(self, must_create=False):
 74        """
 75        Save the current session data to the database. If 'must_create' is
 76        True, raise a database error if the saving operation doesn't create a
 77        new entry (as opposed to possibly updating an existing entry).
 78        """
 79        if self.session_key is None:
 80            return self.create()
 81        data = self._get_session(no_load=must_create)
 82        obj = self.create_model_instance(data)
 83        using = router.db_for_write(self.model, instance=obj)
 84        try:
 85            with transaction.atomic(using=using):
 86                obj.save(
 87                    clean_and_validate=False,
 88                    force_insert=must_create,
 89                    force_update=not must_create,
 90                    using=using,
 91                )
 92        except IntegrityError:
 93            if must_create:
 94                raise CreateError
 95            raise
 96        except DatabaseError:
 97            if not must_create:
 98                raise UpdateError
 99            raise
100
101    def delete(self, session_key=None):
102        if session_key is None:
103            if self.session_key is None:
104                return
105            session_key = self.session_key
106        try:
107            self.model.objects.get(session_key=session_key).delete()
108        except self.model.DoesNotExist:
109            pass
110
111    @classmethod
112    def clear_expired(cls):
113        cls.get_model_class().objects.filter(expire_date__lt=timezone.now()).delete()