1from __future__ import annotations
  2
from collections.abc import Callable, Generator
from contextlib import ContextDecorator, contextmanager
from typing import Any, overload

from plain.postgres.db import DatabaseError, Error, ProgrammingError, get_connection
  8
  9
 10class TransactionManagementError(ProgrammingError):
 11    """Transaction management is used improperly."""
 12
 13    pass
 14
 15
 16@contextmanager
 17def mark_for_rollback_on_error() -> Generator[None]:
 18    """
 19    Internal low-level utility to mark a transaction as "needs rollback" when
 20    an exception is raised while not enforcing the enclosed block to be in a
 21    transaction. This is needed by Model.save() and friends to avoid starting a
 22    transaction when in autocommit mode and a single query is executed.
 23
 24    It's equivalent to:
 25
 26        if get_connection().get_autocommit():
 27            yield
 28        else:
 29            with transaction.atomic(savepoint=False):
 30                yield
 31
 32    but it uses low-level utilities to avoid performance overhead.
 33    """
 34    try:
 35        yield
 36    except Exception as exc:
 37        conn = get_connection()
 38        if conn.in_atomic_block:
 39            conn.needs_rollback = True
 40            conn.rollback_exc = exc
 41        raise
 42
 43
 44def on_commit(func: Callable[[], Any], robust: bool = False) -> None:
 45    """
 46    Register `func` to be called when the current transaction is committed.
 47    If the current transaction is rolled back, `func` will not be called.
 48    """
 49    get_connection().on_commit(func, robust)
 50
 51
 52#################################
 53# Decorators / context managers #
 54#################################
 55
 56
 57class Atomic(ContextDecorator):
 58    """
 59    Guarantee the atomic execution of a given block.
 60
 61    An instance can be used either as a decorator or as a context manager.
 62
 63    When it's used as a decorator, __call__ wraps the execution of the
 64    decorated function in the instance itself, used as a context manager.
 65
 66    When it's used as a context manager, __enter__ creates a transaction or a
 67    savepoint, depending on whether a transaction is already in progress, and
 68    __exit__ commits the transaction or releases the savepoint on normal exit,
 69    and rolls back the transaction or to the savepoint on exceptions.
 70
 71    It's possible to disable the creation of savepoints if the goal is to
 72    ensure that some code runs within a transaction without creating overhead.
 73
 74    A stack of savepoints identifiers is maintained as an attribute of the
 75    connection. None denotes the absence of a savepoint.
 76
 77    This allows reentrancy even if the same AtomicWrapper is reused. For
 78    example, it's possible to define `oa = atomic('other')` and use `@oa` or
 79    `with oa:` multiple times.
 80
 81    Since database connections are stored per-context (ContextVar), this is thread-safe.
 82
 83    An atomic block can be tagged as durable. In this case, raise a
 84    RuntimeError if it's nested within another atomic block. This guarantees
 85    that database changes in a durable block are committed to the database when
 86    the block exists without error.
 87
 88    This is a private API.
 89    """
 90
 91    def __init__(self, savepoint: bool, durable: bool) -> None:
 92        self.savepoint = savepoint
 93        self.durable = durable
 94        self._from_testcase = False
 95
 96    def __enter__(self) -> None:
 97        conn = get_connection()
 98        if (
 99            self.durable
100            and conn.atomic_blocks
101            and not conn.atomic_blocks[-1]._from_testcase
102        ):
103            raise RuntimeError(
104                "A durable atomic block cannot be nested within another atomic block."
105            )
106        if not conn.in_atomic_block:
107            # Reset state when entering an outermost atomic block.
108            conn.needs_rollback = False
109        if conn.in_atomic_block:
110            # We're already in a transaction; create a savepoint, unless we
111            # were told not to or we're already waiting for a rollback. The
112            # second condition avoids creating useless savepoints and prevents
113            # overwriting needs_rollback until the rollback is performed.
114            if self.savepoint and not conn.needs_rollback:
115                sid = conn.savepoint()
116                conn.savepoint_ids.append(sid)
117            else:
118                conn.savepoint_ids.append(None)
119        else:
120            conn.set_autocommit(False)
121            conn.in_atomic_block = True
122
123        if conn.in_atomic_block:
124            conn.atomic_blocks.append(self)
125
126    def __exit__(
127        self,
128        exc_type: type[BaseException] | None,
129        exc_value: BaseException | None,
130        traceback: Any,
131    ) -> None:
132        conn = get_connection()
133        if conn.in_atomic_block:
134            conn.atomic_blocks.pop()
135
136        if conn.savepoint_ids:
137            sid = conn.savepoint_ids.pop()
138        else:
139            # Prematurely unset this flag to allow using commit or rollback.
140            conn.in_atomic_block = False
141
142        try:
143            if conn.closed_in_transaction:
144                # The database will perform a rollback by itself.
145                # Wait until we exit the outermost block.
146                pass
147
148            elif exc_type is None and not conn.needs_rollback:
149                if conn.in_atomic_block:
150                    # Release savepoint if there is one
151                    if sid is not None:
152                        try:
153                            conn.savepoint_commit(sid)
154                        except DatabaseError:
155                            try:
156                                conn.savepoint_rollback(sid)
157                                # The savepoint won't be reused. Release it to
158                                # minimize overhead for the database server.
159                                conn.savepoint_commit(sid)
160                            except Error:
161                                # If rolling back to a savepoint fails, mark for
162                                # rollback at a higher level and avoid shadowing
163                                # the original exception.
164                                conn.needs_rollback = True
165                            raise
166                else:
167                    # Commit transaction
168                    try:
169                        conn.commit()
170                    except DatabaseError:
171                        try:
172                            conn.rollback()
173                        except Error:
174                            # An error during rollback means that something
175                            # went wrong with the connection. Drop it.
176                            conn.close()
177                        raise
178            else:
179                # This flag will be set to True again if there isn't a savepoint
180                # allowing to perform the rollback at this level.
181                conn.needs_rollback = False
182                if conn.in_atomic_block:
183                    # Roll back to savepoint if there is one, mark for rollback
184                    # otherwise.
185                    if sid is None:
186                        conn.needs_rollback = True
187                    else:
188                        try:
189                            conn.savepoint_rollback(sid)
190                            # The savepoint won't be reused. Release it to
191                            # minimize overhead for the database server.
192                            conn.savepoint_commit(sid)
193                        except Error:
194                            # If rolling back to a savepoint fails, mark for
195                            # rollback at a higher level and avoid shadowing
196                            # the original exception.
197                            conn.needs_rollback = True
198                else:
199                    # Roll back transaction
200                    try:
201                        conn.rollback()
202                    except Error:
203                        # An error during rollback means that something
204                        # went wrong with the connection. Drop it.
205                        conn.close()
206
207        finally:
208            # Outermost block exit when autocommit was enabled.
209            if not conn.in_atomic_block:
210                if conn.closed_in_transaction:
211                    conn.connection = None
212                else:
213                    conn.set_autocommit(True)
214
215
216def atomic[F: Callable[..., Any]](
217    func: F | None = None, *, savepoint: bool = True, durable: bool = False
218) -> F | Atomic:
219    """Create an atomic transaction context or decorator."""
220    if callable(func):
221        return Atomic(savepoint, durable)(func)
222    return Atomic(savepoint, durable)