1from __future__ import annotations
2
3from collections.abc import Callable, Generator
4from contextlib import ContextDecorator, contextmanager
5from typing import Any
6
7from plain.postgres.db import DatabaseError, Error, ProgrammingError, get_connection
8
9
class TransactionManagementError(ProgrammingError):
    """Raised when the transaction-management API is used incorrectly."""
14
15
@contextmanager
def mark_for_rollback_on_error() -> Generator[None]:
    """
    Flag the surrounding transaction for rollback if the wrapped code raises.

    Unlike an atomic block, this never opens a transaction of its own, so a
    single query run in autocommit mode (e.g. from Model.save()) does not pay
    for a BEGIN/COMMIT round trip.

    Behaviorally it matches:

        if get_connection().get_autocommit():
            yield
        else:
            with transaction.atomic(savepoint=False):
                yield

    but works directly on the connection flags to keep overhead low.

    Any exception is always re-raised. When an atomic block is active, the
    connection is additionally marked ``needs_rollback`` and the exception is
    stored on ``rollback_exc``.
    """
    try:
        yield
    except Exception as error:
        connection = get_connection()
        if not connection.in_atomic_block:
            # Outside any atomic block there is no transaction to poison.
            raise
        connection.needs_rollback = True
        connection.rollback_exc = error
        raise
42
43
def on_commit(func: Callable[[], Any], robust: bool = False) -> None:
    """
    Schedule ``func`` to run once the current transaction commits.

    Should the transaction be rolled back instead, ``func`` is never called.
    ``robust`` is forwarded unchanged to the connection's ``on_commit``.
    """
    connection = get_connection()
    connection.on_commit(func, robust)
50
51
52#################################
53# Decorators / context managers #
54#################################
55
56
class Atomic(ContextDecorator):
    """
    Guarantee the atomic execution of a given block.

    An instance can be used either as a decorator or as a context manager.

    When it's used as a decorator, __call__ wraps the execution of the
    decorated function in the instance itself, used as a context manager.

    When it's used as a context manager, __enter__ creates a transaction or a
    savepoint, depending on whether a transaction is already in progress, and
    __exit__ commits the transaction or releases the savepoint on normal exit,
    and rolls back the transaction or to the savepoint on exceptions.

    It's possible to disable the creation of savepoints if the goal is to
    ensure that some code runs within a transaction without creating overhead.

    A stack of savepoint identifiers is maintained as an attribute of the
    connection (``savepoint_ids``). None denotes the absence of a savepoint.

    Because all per-block state lives on the connection rather than on the
    instance, the same Atomic instance can be reused and re-entered. For
    example, it's possible to define `oa = atomic(savepoint=False)` and use
    `@oa` or `with oa:` multiple times.

    The connection is looked up with get_connection() on every enter/exit.
    Assuming connections are stored per-context (ContextVar), as this
    design intends, this is thread-safe.

    An atomic block can be tagged as durable. In this case, raise a
    RuntimeError if it's nested within another atomic block (unless the
    enclosing block is flagged ``_from_testcase``). This guarantees that
    database changes in a durable block are committed to the database when
    the block exits without error.

    This is a private API.
    """

    def __init__(self, savepoint: bool, durable: bool) -> None:
        # savepoint: when nested inside another atomic block, create a
        #   savepoint so this block can roll back independently.
        # durable: refuse to run nested inside another atomic block.
        self.savepoint = savepoint
        self.durable = durable
        # Checked by nested durable blocks in __enter__: an enclosing block
        # with this flag set does not trigger the durability error. Nothing in
        # this class sets it to True; presumably a test-case helper does.
        # Verify at the call site.
        self._from_testcase = False

    def __enter__(self) -> None:
        """
        Open a transaction (outermost block) or push a savepoint id (nested).

        Raises:
            RuntimeError: if this block is durable and is nested inside
                another atomic block not flagged ``_from_testcase``.
        """
        conn = get_connection()
        if (
            self.durable
            and conn.atomic_blocks
            and not conn.atomic_blocks[-1]._from_testcase
        ):
            raise RuntimeError(
                "A durable atomic block cannot be nested within another atomic block."
            )
        if not conn.in_atomic_block:
            # Reset state when entering an outermost atomic block.
            conn.needs_rollback = False
        if conn.in_atomic_block:
            # We're already in a transaction; create a savepoint, unless we
            # were told not to or we're already waiting for a rollback. The
            # second condition avoids creating useless savepoints and prevents
            # overwriting needs_rollback until the rollback is performed.
            if self.savepoint and not conn.needs_rollback:
                sid = conn.savepoint()
                conn.savepoint_ids.append(sid)
            else:
                # A None entry keeps the stack depth in step with the nesting
                # depth; __exit__ pops exactly one entry per nested block.
                conn.savepoint_ids.append(None)
        else:
            # Outermost block: turn off autocommit to begin the transaction.
            conn.set_autocommit(False)
            conn.in_atomic_block = True

        # Both branches above leave in_atomic_block True, so this always
        # records the block; durable checks inspect atomic_blocks[-1].
        if conn.in_atomic_block:
            conn.atomic_blocks.append(self)

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: Any,
    ) -> None:
        """
        Commit/release on success, roll back on error or pending rollback.

        Nested blocks release or roll back to their savepoint. Without a
        savepoint, a failing nested block only sets ``needs_rollback`` so an
        outer block performs the rollback. The outermost block commits or
        rolls back the whole transaction and then re-enables autocommit.

        Returns None, so any exception raised inside the block propagates.
        A DatabaseError raised by the commit or savepoint release is
        re-raised after the corresponding rollback is attempted.
        """
        conn = get_connection()
        if conn.in_atomic_block:
            conn.atomic_blocks.pop()

        if conn.savepoint_ids:
            # Nested block: sid is bound only on this path. In the else-branch
            # in_atomic_block becomes False, and every read of sid below is
            # guarded by `if conn.in_atomic_block:`, so it is never unbound.
            sid = conn.savepoint_ids.pop()
        else:
            # Prematurely unset this flag to allow using commit or rollback.
            conn.in_atomic_block = False

        try:
            if conn.closed_in_transaction:
                # The database will perform a rollback by itself.
                # Wait until we exit the outermost block.
                pass

            elif exc_type is None and not conn.needs_rollback:
                if conn.in_atomic_block:
                    # Release savepoint if there is one
                    if sid is not None:
                        try:
                            conn.savepoint_commit(sid)
                        except DatabaseError:
                            try:
                                conn.savepoint_rollback(sid)
                                # The savepoint won't be reused. Release it to
                                # minimize overhead for the database server.
                                conn.savepoint_commit(sid)
                            except Error:
                                # If rolling back to a savepoint fails, mark for
                                # rollback at a higher level and avoid shadowing
                                # the original exception.
                                conn.needs_rollback = True
                            # Re-raise the DatabaseError from savepoint_commit.
                            raise
                else:
                    # Commit transaction
                    try:
                        conn.commit()
                    except DatabaseError:
                        try:
                            conn.rollback()
                        except Error:
                            # An error during rollback means that something
                            # went wrong with the connection. Drop it.
                            conn.close()
                        # Re-raise the DatabaseError from commit().
                        raise
            else:
                # This flag will be set to True again if there isn't a savepoint
                # allowing to perform the rollback at this level.
                conn.needs_rollback = False
                if conn.in_atomic_block:
                    # Roll back to savepoint if there is one, mark for rollback
                    # otherwise.
                    if sid is None:
                        conn.needs_rollback = True
                    else:
                        try:
                            conn.savepoint_rollback(sid)
                            # The savepoint won't be reused. Release it to
                            # minimize overhead for the database server.
                            conn.savepoint_commit(sid)
                        except Error:
                            # If rolling back to a savepoint fails, mark for
                            # rollback at a higher level and avoid shadowing
                            # the original exception.
                            conn.needs_rollback = True
                else:
                    # Roll back transaction
                    try:
                        conn.rollback()
                    except Error:
                        # An error during rollback means that something
                        # went wrong with the connection. Drop it.
                        conn.close()

        finally:
            # Outermost block exit when autocommit was enabled.
            if not conn.in_atomic_block:
                if conn.closed_in_transaction:
                    # The connection was closed mid-transaction: discard the
                    # handle instead of calling set_autocommit on it.
                    conn.connection = None
                else:
                    conn.set_autocommit(True)
214
215
216def atomic[F: Callable[..., Any]](
217 func: F | None = None, *, savepoint: bool = True, durable: bool = False
218) -> F | Atomic:
219 """Create an atomic transaction context or decorator."""
220 if callable(func):
221 return Atomic(savepoint, durable)(func)
222 return Atomic(savepoint, durable)