1"""HTTP middleware that manages the per-request database connection lifecycle."""
2
3from __future__ import annotations
4
5from functools import partial
6
7from plain.http import HttpMiddleware, Response
8from plain.http.request import Request
9
10from .db import _db_conn, return_database_connection
11
12
class DatabaseConnectionMiddleware(HttpMiddleware):
    """Hand the per-request DB connection back to the pool once the request ends.

    Non-streaming responses are complete by the time ``after_response`` runs,
    so the connection is released right away.

    Streaming responses are different. Their body is produced lazily,
    possibly by generators that keep querying the database (for example
    ``Model.query.iterator()`` inside a ``StreamingResponse``). Releasing the
    connection when the view returns would let the pool roll it back and
    invalidate those generators' cursors. Release is therefore deferred until
    the body has been fully drained, by registering a closer on
    ``response._resource_closers``.

    The connection wrapper is looked up *here* and bound into the closer
    explicitly. ``response.close()`` runs after ``handle()`` returns, outside
    the per-request ``contextvars.Context``. A ``_db_conn.get()`` performed at
    close time would therefore not see the wrapper.
    """

    def after_response(self, request: Request, response: Response) -> Response:
        """Release the request's DB connection, now or when the stream closes.

        Args:
            request: The current request (unused here).
            response: The outgoing response. For streaming responses a
                release callback may be appended to its
                ``_resource_closers`` list.

        Returns:
            The same ``response`` object, unmodified apart from any
            registered closer.
        """
        if not response.streaming:
            # Body is already materialized; nothing else will touch the DB.
            return_database_connection()
            return response

        # Capture the wrapper while still inside the request's context.
        # The closer runs later, outside it.
        wrapper = _db_conn.get()
        if wrapper is None:
            # No connection was checked out for this request.
            return response

        release_on_close = partial(return_database_connection, wrapper)
        response._resource_closers.append(release_on_close)
        return response