1import logging
2
3from plain.exceptions import SuspiciousOperation
4from plain.models import DatabaseError, IntegrityError, router, transaction
5from plain.sessions.backends.base import CreateError, SessionBase, UpdateError
6from plain.utils import timezone
7from plain.utils.functional import cached_property
8
9
class SessionStore(SessionBase):
    """
    Implement database session store.

    Session rows are read and written through the model returned by
    ``get_model_class()``. Each row is addressed by ``session_key`` and
    carries the encoded ``session_data`` plus an ``expire_date``.
    """

    def __init__(self, session_key=None):
        super().__init__(session_key)

    @classmethod
    def get_model_class(cls):
        """Return the session model class, importing it on first use."""
        # Avoids a circular import and allows importing SessionStore when
        # plain.sessions is not in INSTALLED_PACKAGES.
        from plain.sessions.models import Session

        return Session

    @cached_property
    def model(self):
        """Session model class for this store (see ``get_model_class``)."""
        return self.get_model_class()

    def _get_session_from_db(self):
        """
        Return the unexpired session row for the current key, or None.

        If no matching row exists, or the lookup raises
        SuspiciousOperation, the stored session key is reset to None and
        None is returned. A SuspiciousOperation is also logged as a warning
        on the ``plain.security.<ExceptionClassName>`` logger.
        """
        try:
            return self.model.objects.get(
                session_key=self.session_key, expire_date__gt=timezone.now()
            )
        except (self.model.DoesNotExist, SuspiciousOperation) as e:
            if isinstance(e, SuspiciousOperation):
                logger = logging.getLogger(f"plain.security.{e.__class__.__name__}")
                logger.warning(str(e))
            # Forget the key so it is not reused for a missing/invalid row.
            self._session_key = None
            return None

    def load(self):
        """Return the decoded session data, or an empty dict if no row is found."""
        s = self._get_session_from_db()
        return self.decode(s.session_data) if s else {}

    def exists(self, session_key):
        """
        Return True if a row with ``session_key`` exists.

        The expiry date is not taken into account by this check.
        """
        return self.model.objects.filter(session_key=session_key).exists()

    def create(self):
        """
        Create a new session row under a freshly generated key.

        New keys are generated until one can be inserted; the session is
        then flagged as modified.
        """
        while True:
            self._session_key = self._get_new_session_key()
            try:
                # Save immediately to ensure we have a unique entry in the
                # database.
                self.save(must_create=True)
            except CreateError:
                # Key wasn't unique. Try again.
                continue
            self.modified = True
            return

    def create_model_instance(self, data):
        """
        Return a new instance of the session model object, which represents the
        current session state. Intended to be used for saving the session data
        to the database.
        """
        return self.model(
            session_key=self._get_or_create_session_key(),
            session_data=self.encode(data),
            expire_date=self.get_expiry_date(),
        )

    def save(self, must_create=False):
        """
        Save the current session data to the database.

        If there is no session key yet, delegate to ``create()``.
        Otherwise the row is written inside a transaction: as a forced
        INSERT when ``must_create`` is True, as a forced UPDATE when it is
        False.

        Raises:
            CreateError: ``must_create`` is True and the insert raised
                IntegrityError (e.g. the key already exists).
            UpdateError: ``must_create`` is False and the update raised
                DatabaseError.

        Any other combination re-raises the original database exception.
        The translated errors are chained to the underlying database
        exception via ``__cause__``.
        """
        if self.session_key is None:
            return self.create()
        # NOTE(review): presumably no_load=True skips reading existing data
        # for a brand-new session -- confirm against SessionBase._get_session.
        data = self._get_session(no_load=must_create)
        obj = self.create_model_instance(data)
        using = router.db_for_write(self.model, instance=obj)
        try:
            with transaction.atomic(using=using):
                obj.save(
                    clean_and_validate=False,
                    force_insert=must_create,
                    force_update=not must_create,
                    using=using,
                )
        except IntegrityError as exc:
            # Must be handled before DatabaseError so that a failed insert
            # is reported as CreateError rather than falling through.
            if must_create:
                raise CreateError from exc
            raise
        except DatabaseError as exc:
            if not must_create:
                raise UpdateError from exc
            raise

    def delete(self, session_key=None):
        """
        Delete the row for ``session_key``.

        When ``session_key`` is None the store's own key is used; if that
        is None as well, nothing happens. A missing row is ignored.
        """
        if session_key is None:
            if self.session_key is None:
                return
            session_key = self.session_key
        try:
            self.model.objects.get(session_key=session_key).delete()
        except self.model.DoesNotExist:
            # Already gone; deleting is best-effort.
            pass

    @classmethod
    def clear_expired(cls):
        """Delete every session row whose ``expire_date`` is before now."""
        cls.get_model_class().objects.filter(expire_date__lt=timezone.now()).delete()