1"""HTTP middleware that manages the per-request database connection lifecycle."""
 2
 3from __future__ import annotations
 4
 5from functools import partial
 6
 7from plain.http import HttpMiddleware, Response
 8from plain.http.request import Request
 9
10from .db import _db_conn, return_database_connection
11
12
class DatabaseConnectionMiddleware(HttpMiddleware):
    """Hand the per-request DB connection back to the pool when the request ends.

    Non-streaming responses release the connection immediately, inside
    `after_response`. Streaming responses defer the release until the body
    has been fully drained: a closer is appended to
    `response._resource_closers` instead. Otherwise, generators that query
    the database lazily (e.g. `Model.query.iterator()` inside a
    `StreamingResponse`) would find their cursor invalidated when the pool
    rolls back the returned connection.

    The deferred closer is bound to the wrapper captured at this point,
    rather than looking it up later, because `response.close()` runs after
    `handle()` has returned — outside the per-request
    `contextvars.Context` — where `_db_conn.get()` would no longer find it.
    """

    def after_response(self, request: Request, response: Response) -> Response:
        if not response.streaming:
            # Non-streaming: release the connection now rather than deferring.
            return_database_connection()
            return response

        # Streaming: snapshot the wrapper while still inside the request
        # context, since the closer will run outside of it.
        captured = _db_conn.get()
        if captured is not None:
            closer = partial(return_database_connection, captured)
            response._resource_closers.append(closer)
        # If `_db_conn` holds no wrapper, no closer is registered.
        return response