1from __future__ import annotations
2
3from collections.abc import Callable, Generator
4from contextlib import ContextDecorator, contextmanager
5from typing import Any, TypeVar
6
7from plain.models.db import DatabaseError, Error, ProgrammingError, db_connection
8
# Type of a callable decorated with the bare ``@atomic`` form; lets ``atomic``
# be annotated as returning the same callable type it was given.
F = TypeVar("F", bound=Callable[..., Any])
10
11
class TransactionManagementError(ProgrammingError):
    """Raised when the transaction-management API is used incorrectly."""
16
17
@contextmanager
def mark_for_rollback_on_error() -> Generator[None, None, None]:
    """
    Flag the enclosing transaction for rollback if the wrapped block raises.

    Internal low-level helper used by Model.save() and friends. Unlike an
    atomic block, it never forces the wrapped code into a transaction, so a
    single query executed in autocommit mode does not pay for starting one.
    Conceptually it behaves like::

        if db_connection.get_autocommit():
            yield
        else:
            with transaction.atomic(savepoint=False):
                yield

    but it only touches low-level connection flags to avoid that overhead.

    Any ``Exception`` raised inside the block is always re-raised. When the
    connection is currently inside an atomic block, ``needs_rollback`` is set
    and the exception is stored on ``rollback_exc`` before re-raising.
    """
    try:
        yield
    except Exception as error:
        if not db_connection.in_atomic_block:
            # No atomic block to poison; just propagate.
            raise
        # Record that the surrounding transaction must be rolled back, and why.
        db_connection.needs_rollback = True
        db_connection.rollback_exc = error
        raise
43
44
def on_commit(func: Callable[[], Any], robust: bool = False) -> None:
    """
    Defer ``func`` until the current transaction is committed.

    If the current transaction is rolled back instead, ``func`` is never
    called. Registration is delegated to ``db_connection.on_commit``.

    Args:
        func: Zero-argument callable to run after a successful commit.
        robust: Forwarded unchanged to ``db_connection.on_commit``
            (presumably controls whether errors raised by ``func`` are
            tolerated — confirm against the connection implementation).
    """
    register_callback = db_connection.on_commit
    register_callback(func, robust)
51
52
53#################################
54# Decorators / context managers #
55#################################
56
57
class Atomic(ContextDecorator):
    """
    Guarantee the atomic execution of a given block.

    An instance can be used either as a decorator or as a context manager.

    When it's used as a decorator, __call__ (inherited from
    contextlib.ContextDecorator) wraps the execution of the decorated
    function in the instance itself, used as a context manager.

    When it's used as a context manager, __enter__ creates a transaction or a
    savepoint, depending on whether a transaction is already in progress, and
    __exit__ commits the transaction or releases the savepoint on normal exit,
    and rolls back the transaction or to the savepoint on exceptions.

    It's possible to disable the creation of savepoints if the goal is to
    ensure that some code runs within a transaction without creating overhead.

    A stack of savepoint identifiers is maintained as an attribute of the
    db_connection (``savepoint_ids``). None denotes the absence of a savepoint.
    The instances themselves are tracked on ``db_connection.atomic_blocks``.

    Because all per-block state lives on the connection rather than on the
    instance, the same Atomic instance can be reused reentrantly. For
    example, it's possible to define `oa = atomic(savepoint=False)` and use
    `@oa` or `with oa:` multiple times.

    NOTE(review): thread-safety relies on db_connection being per-thread
    (database connections are assumed to be thread-local); that is not
    established in this module — confirm.

    An atomic block can be tagged as durable. In this case, raise a
    RuntimeError if it's nested within another atomic block, unless that
    enclosing block has ``_from_testcase`` set. This guarantees that database
    changes in a durable block are committed to the database when the block
    exits without error.

    This is a private API.
    """

    def __init__(self, savepoint: bool, durable: bool) -> None:
        # Whether a nested entry should create a savepoint (see __enter__).
        self.savepoint = savepoint
        # Whether nesting this block inside another atomic block is forbidden.
        self.durable = durable
        # When True, a durable block may be opened directly inside this one
        # (see the check at the top of __enter__). Presumably set by test-case
        # machinery that wraps each test in an atomic block — TODO confirm.
        self._from_testcase = False

    def __enter__(self) -> None:
        """
        Enter the atomic block.

        Outside any atomic block, reset the per-transaction flags. Then either
        turn autocommit off, which starts a transaction, or, if autocommit was
        already off, treat the existing transaction as the outer block and
        remember not to commit it on exit. Inside an atomic block, push a
        new savepoint id, or None when savepoints are disabled or a rollback
        is already pending. Finally push ``self`` onto
        ``db_connection.atomic_blocks``.

        Raises:
            RuntimeError: if this block is durable and the innermost
                enclosing atomic block does not have ``_from_testcase`` set.
        """
        if (
            self.durable
            and db_connection.atomic_blocks
            and not db_connection.atomic_blocks[-1]._from_testcase
        ):
            raise RuntimeError(
                "A durable atomic block cannot be nested within another atomic block."
            )
        if not db_connection.in_atomic_block:
            # Reset state when entering an outermost atomic block.
            db_connection.commit_on_exit = True
            db_connection.needs_rollback = False
            if not db_connection.get_autocommit():
                # Pretend we're already in an atomic block to bypass the code
                # that disables autocommit to enter a transaction, and make a
                # note to deal with this case in __exit__.
                db_connection.in_atomic_block = True
                db_connection.commit_on_exit = False

        if db_connection.in_atomic_block:
            # We're already in a transaction; create a savepoint, unless we
            # were told not to or we're already waiting for a rollback. The
            # second condition avoids creating useless savepoints and prevents
            # overwriting needs_rollback until the rollback is performed.
            if self.savepoint and not db_connection.needs_rollback:
                sid = db_connection.savepoint()
                db_connection.savepoint_ids.append(sid)
            else:
                db_connection.savepoint_ids.append(None)
        else:
            db_connection.set_autocommit(
                False, force_begin_transaction_with_broken_autocommit=True
            )
            db_connection.in_atomic_block = True

        if db_connection.in_atomic_block:
            db_connection.atomic_blocks.append(self)

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: Any,
    ) -> None:
        """
        Leave the atomic block, committing or rolling back.

        On a clean exit (no exception and no pending rollback), release this
        level's savepoint, or commit the transaction for an outermost block.
        Otherwise roll back to the savepoint, or roll back the whole
        transaction. When there is no savepoint at this level, set
        ``needs_rollback`` so that an outer block performs the rollback.
        If the connection was closed inside the transaction, do nothing until
        the outermost block exits.

        On the outermost exit, restore autocommit, or drop
        ``db_connection.connection`` if it was closed in the transaction.
        When autocommit was already off on entry, clear ``in_atomic_block``
        instead.

        Returns None, so any exception raised inside the block propagates.
        A DatabaseError from releasing a savepoint or committing is
        re-raised after a rollback is attempted.
        """
        if db_connection.in_atomic_block:
            db_connection.atomic_blocks.pop()

        # ``sid`` is bound only when a savepoint id is popped here. When the
        # stack is empty, in_atomic_block is cleared instead, and every read
        # of ``sid`` below sits behind an ``in_atomic_block`` check.
        if db_connection.savepoint_ids:
            sid = db_connection.savepoint_ids.pop()
        else:
            # Prematurely unset this flag to allow using commit or rollback.
            db_connection.in_atomic_block = False

        try:
            if db_connection.closed_in_transaction:
                # The database will perform a rollback by itself.
                # Wait until we exit the outermost block.
                pass

            elif exc_type is None and not db_connection.needs_rollback:
                if db_connection.in_atomic_block:
                    # Release savepoint if there is one
                    if sid is not None:
                        try:
                            db_connection.savepoint_commit(sid)
                        except DatabaseError:
                            try:
                                db_connection.savepoint_rollback(sid)
                                # The savepoint won't be reused. Release it to
                                # minimize overhead for the database server.
                                db_connection.savepoint_commit(sid)
                            except Error:
                                # If rolling back to a savepoint fails, mark for
                                # rollback at a higher level and avoid shadowing
                                # the original exception.
                                db_connection.needs_rollback = True
                            # Re-raise the DatabaseError from savepoint_commit.
                            raise
                else:
                    # Commit transaction
                    try:
                        db_connection.commit()
                    except DatabaseError:
                        try:
                            db_connection.rollback()
                        except Error:
                            # An error during rollback means that something
                            # went wrong with the db_connection. Drop it.
                            db_connection.close()
                        # Re-raise the DatabaseError from commit.
                        raise
            else:
                # This flag will be set to True again if there isn't a savepoint
                # allowing to perform the rollback at this level.
                db_connection.needs_rollback = False
                if db_connection.in_atomic_block:
                    # Roll back to savepoint if there is one, mark for rollback
                    # otherwise.
                    if sid is None:
                        db_connection.needs_rollback = True
                    else:
                        try:
                            db_connection.savepoint_rollback(sid)
                            # The savepoint won't be reused. Release it to
                            # minimize overhead for the database server.
                            db_connection.savepoint_commit(sid)
                        except Error:
                            # If rolling back to a savepoint fails, mark for
                            # rollback at a higher level and avoid shadowing
                            # the original exception.
                            db_connection.needs_rollback = True
                else:
                    # Roll back transaction
                    try:
                        db_connection.rollback()
                    except Error:
                        # An error during rollback means that something
                        # went wrong with the db_connection. Drop it.
                        db_connection.close()

        finally:
            # Outermost block exit when autocommit was enabled.
            if not db_connection.in_atomic_block:
                if db_connection.closed_in_transaction:
                    db_connection.connection = None
                else:
                    db_connection.set_autocommit(True)
            # Outermost block exit when autocommit was disabled.
            elif not db_connection.savepoint_ids and not db_connection.commit_on_exit:
                if db_connection.closed_in_transaction:
                    db_connection.connection = None
                else:
                    db_connection.in_atomic_block = False
229
230
def atomic(
    func: F | None = None, *, savepoint: bool = True, durable: bool = False
) -> F | Atomic:
    """
    Create an atomic transaction context or decorator.

    Supports three forms::

        @atomic
        def f(): ...

        @atomic(savepoint=False, durable=True)
        def g(): ...

        with atomic():
            ...

    Args:
        func: The function being decorated when used as a bare ``@atomic``.
            Leave as None for the ``@atomic(...)`` and ``with atomic(...)``
            forms.
        savepoint: Create a savepoint when entered inside an existing
            transaction.
        durable: Forbid nesting inside another atomic block; entering such a
            nested block raises RuntimeError.

    Returns:
        The wrapped function for the bare decorator form. Otherwise, an
        ``Atomic`` instance usable as a decorator or context manager.

    Raises:
        TypeError: If ``func`` is given but is not callable. One example is a
            database alias string carried over from Django's
            ``atomic(using)`` signature, which was previously ignored
            silently.
    """
    if func is None:
        # Decorator with arguments, or context manager.
        return Atomic(savepoint, durable)
    if not callable(func):
        # Options are keyword-only; a non-callable positional argument is
        # almost certainly a mistake (e.g. a DB alias), so fail loudly.
        raise TypeError(
            f"atomic() expected a callable or None, got {type(func).__name__!r}; "
            "pass options as keywords, e.g. atomic(savepoint=False)."
        )
    # Bare decorator: @atomic
    return Atomic(savepoint, durable)(func)  # type: ignore[return-value]