Plain is headed towards 1.0! Subscribe for development updates →

  1from __future__ import annotations
  2
  3from collections.abc import Callable, Generator
  4from contextlib import ContextDecorator, contextmanager
  5from typing import Any, TypeVar
  6
  7from plain.models.db import DatabaseError, Error, ProgrammingError, db_connection
  8
# Type variable for the callable handed to `atomic()` when it is used as a
# decorator, so the signature can express "same callable type in, same out".
F = TypeVar("F", bound=Callable[..., Any])
 10
 11
 12class TransactionManagementError(ProgrammingError):
 13    """Transaction management is used improperly."""
 14
 15    pass
 16
 17
 18@contextmanager
 19def mark_for_rollback_on_error() -> Generator[None, None, None]:
 20    """
 21    Internal low-level utility to mark a transaction as "needs rollback" when
 22    an exception is raised while not enforcing the enclosed block to be in a
 23    transaction. This is needed by Model.save() and friends to avoid starting a
 24    transaction when in autocommit mode and a single query is executed.
 25
 26    It's equivalent to:
 27
 28        if db_connection.get_autocommit():
 29            yield
 30        else:
 31            with transaction.atomic(savepoint=False):
 32                yield
 33
 34    but it uses low-level utilities to avoid performance overhead.
 35    """
 36    try:
 37        yield
 38    except Exception as exc:
 39        if db_connection.in_atomic_block:
 40            db_connection.needs_rollback = True
 41            db_connection.rollback_exc = exc
 42        raise
 43
 44
def on_commit(func: Callable[[], Any], robust: bool = False) -> None:
    """
    Register `func` to be called when the current transaction is committed.
    If the current transaction is rolled back, `func` will not be called.

    This is a thin module-level wrapper: both arguments are forwarded
    unchanged to `db_connection.on_commit()`, which implements the behavior.

    `func` is annotated as a zero-argument callable.

    NOTE(review): the exact meaning of `robust` is defined by
    `db_connection.on_commit()` and is not visible here; presumably it
    controls whether errors raised by `func` are tolerated. Confirm there.
    """
    db_connection.on_commit(func, robust)
 51
 52
 53#################################
 54# Decorators / context managers #
 55#################################
 56
 57
class Atomic(ContextDecorator):
    """
    Guarantee the atomic execution of a given block.

    An instance can be used either as a decorator or as a context manager.

    When it's used as a decorator, __call__ wraps the execution of the
    decorated function in the instance itself, used as a context manager.

    When it's used as a context manager, __enter__ creates a transaction or a
    savepoint, depending on whether a transaction is already in progress, and
    __exit__ commits the transaction or releases the savepoint on normal exit,
    and rolls back the transaction or to the savepoint on exceptions.

    It's possible to disable the creation of savepoints if the goal is to
    ensure that some code runs within a transaction without creating overhead.

    A stack of savepoint identifiers is maintained as an attribute of the
    db_connection. None denotes the absence of a savepoint.

    Because that state lives on the connection rather than on the instance,
    the same Atomic instance can be reused and nested. For example, it's
    possible to define `oa = atomic()` and use `@oa` or `with oa:` multiple
    times.

    Since database connections are thread-local, this is thread-safe.

    An atomic block can be tagged as durable. In this case, raise a
    RuntimeError if it's nested within another atomic block. This guarantees
    that database changes in a durable block are committed to the database when
    the block exits without error.

    This is a private API.
    """

    def __init__(self, savepoint: bool, durable: bool) -> None:
        """
        Store the block's configuration; no database work happens here.

        `savepoint` controls whether a nested block creates a savepoint and
        `durable` enables the "must not be nested" check in __enter__.
        """
        self.savepoint = savepoint
        self.durable = durable
        # Read by the durable check in __enter__ to exempt an enclosing block.
        # Nothing in this module sets it to True.
        # NOTE(review): presumably toggled by test tooling that wraps each
        # test in an atomic block; confirm against the test runner.
        self._from_testcase = False

    def __enter__(self) -> None:
        """
        Start a transaction, or a savepoint if a transaction is in progress.

        Raises RuntimeError when this block is durable and the innermost
        enclosing atomic block is not flagged `_from_testcase`.

        Returns None, so `with atomic() as x:` binds `x` to None.
        """
        if (
            self.durable
            and db_connection.atomic_blocks
            and not db_connection.atomic_blocks[-1]._from_testcase
        ):
            raise RuntimeError(
                "A durable atomic block cannot be nested within another atomic block."
            )
        if not db_connection.in_atomic_block:
            # Reset state when entering an outermost atomic block.
            db_connection.commit_on_exit = True
            db_connection.needs_rollback = False
            if not db_connection.get_autocommit():
                # Pretend we're already in an atomic block to bypass the code
                # that disables autocommit to enter a transaction, and make a
                # note to deal with this case in __exit__.
                db_connection.in_atomic_block = True
                db_connection.commit_on_exit = False

        if db_connection.in_atomic_block:
            # We're already in a transaction; create a savepoint, unless we
            # were told not to or we're already waiting for a rollback. The
            # second condition avoids creating useless savepoints and prevents
            # overwriting needs_rollback until the rollback is performed.
            if self.savepoint and not db_connection.needs_rollback:
                sid = db_connection.savepoint()
                db_connection.savepoint_ids.append(sid)
            else:
                # Keep the stack aligned with the nesting depth even when no
                # savepoint was created; __exit__ pops one entry per block.
                db_connection.savepoint_ids.append(None)
        else:
            # Outermost block with autocommit on: leave autocommit mode to
            # begin the transaction. No savepoint_ids entry is pushed here.
            db_connection.set_autocommit(
                False, force_begin_transaction_with_broken_autocommit=True
            )
            db_connection.in_atomic_block = True

        # Both branches above leave in_atomic_block True on success, so the
        # block is recorded for later durable checks.
        if db_connection.in_atomic_block:
            db_connection.atomic_blocks.append(self)

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: Any,
    ) -> None:
        """
        Commit or release on success; roll back on error.

        A block counts as failed when it raised (`exc_type` is not None) or
        when `db_connection.needs_rollback` is set. Nested blocks act on their
        savepoint, while the outermost block acts on the whole transaction.

        Returns None, so an exception raised inside the block is never
        suppressed. A DatabaseError raised by the commit or savepoint release
        itself is re-raised after a rollback attempt.
        """
        if db_connection.in_atomic_block:
            db_connection.atomic_blocks.pop()

        # `sid` is only bound when an entry is popped. When the stack is
        # empty, in_atomic_block is cleared instead, and every read of `sid`
        # below is guarded by `if db_connection.in_atomic_block`.
        if db_connection.savepoint_ids:
            sid = db_connection.savepoint_ids.pop()
        else:
            # Prematurely unset this flag to allow using commit or rollback.
            db_connection.in_atomic_block = False

        try:
            if db_connection.closed_in_transaction:
                # The database will perform a rollback by itself.
                # Wait until we exit the outermost block.
                pass

            elif exc_type is None and not db_connection.needs_rollback:
                if db_connection.in_atomic_block:
                    # Release savepoint if there is one
                    if sid is not None:
                        try:
                            db_connection.savepoint_commit(sid)
                        except DatabaseError:
                            try:
                                db_connection.savepoint_rollback(sid)
                                # The savepoint won't be reused. Release it to
                                # minimize overhead for the database server.
                                db_connection.savepoint_commit(sid)
                            except Error:
                                # If rolling back to a savepoint fails, mark for
                                # rollback at a higher level and avoid shadowing
                                # the original exception.
                                db_connection.needs_rollback = True
                            raise
                else:
                    # Commit transaction
                    try:
                        db_connection.commit()
                    except DatabaseError:
                        try:
                            db_connection.rollback()
                        except Error:
                            # An error during rollback means that something
                            # went wrong with the db_connection. Drop it.
                            db_connection.close()
                        raise
            else:
                # This flag will be set to True again if there isn't a savepoint
                # allowing to perform the rollback at this level.
                db_connection.needs_rollback = False
                if db_connection.in_atomic_block:
                    # Roll back to savepoint if there is one, mark for rollback
                    # otherwise.
                    if sid is None:
                        db_connection.needs_rollback = True
                    else:
                        try:
                            db_connection.savepoint_rollback(sid)
                            # The savepoint won't be reused. Release it to
                            # minimize overhead for the database server.
                            db_connection.savepoint_commit(sid)
                        except Error:
                            # If rolling back to a savepoint fails, mark for
                            # rollback at a higher level and avoid shadowing
                            # the original exception.
                            db_connection.needs_rollback = True
                else:
                    # Roll back transaction
                    try:
                        db_connection.rollback()
                    except Error:
                        # An error during rollback means that something
                        # went wrong with the db_connection. Drop it.
                        db_connection.close()

        finally:
            # Runs on every path above, including the re-raises, so the
            # connection flags are restored even when commit or release fails.
            # Outermost block exit when autocommit was enabled.
            if not db_connection.in_atomic_block:
                if db_connection.closed_in_transaction:
                    db_connection.connection = None
                else:
                    db_connection.set_autocommit(True)
            # Outermost block exit when autocommit was disabled.
            elif not db_connection.savepoint_ids and not db_connection.commit_on_exit:
                if db_connection.closed_in_transaction:
                    db_connection.connection = None
                else:
                    # Undo the "pretend" flag that __enter__ set for the
                    # autocommit-disabled case.
                    db_connection.in_atomic_block = False
229
230
def atomic(
    func: F | None = None, *, savepoint: bool = True, durable: bool = False
) -> F | Atomic:
    """
    Build an atomic transaction block, usable bare or with options.

    Used as `@atomic` (a callable is passed as `func`), return that callable
    wrapped so it runs inside an Atomic block. Used as `atomic(...)` with no
    callable, return the Atomic instance itself, which works both as a
    context manager and as a decorator.

    `savepoint` and `durable` are keyword-only and passed straight to Atomic.
    """
    block = Atomic(savepoint, durable)
    if not callable(func):
        return block
    return block(func)  # type: ignore[return-value]