1from __future__ import annotations
2
3import asyncio
4import concurrent.futures
5import dataclasses
6import inspect
7from typing import TYPE_CHECKING, Any
8
9from opentelemetry import baggage, context, trace
10from opentelemetry.semconv.attributes import http_attributes, url_attributes
11
12from plain import signals
13from plain.runtime import settings
14from plain.urls import get_resolver
15from plain.utils.module_loading import import_string
16
17from .exception import response_for_exception
18
19if TYPE_CHECKING:
20 from plain.http import Request, ResponseBase
21 from plain.http.middleware import HttpMiddleware
22 from plain.urls import ResolverMatch
23
24
# Framework middleware placed ahead of the user's settings.MIDDLEWARE.
# Their before_request hooks run first (in list order); since after_response
# unwinds in reverse, their after_response hooks run last (outermost layer).
BUILTIN_BEFORE_MIDDLEWARE = [
    "plain.internal.middleware.headers.DefaultHeadersMiddleware",
    "plain.internal.middleware.healthcheck.HealthcheckMiddleware",
    "plain.internal.middleware.hosts.HostValidationMiddleware",
    "plain.internal.middleware.https.HttpsRedirectMiddleware",
    "plain.csrf.middleware.CsrfViewMiddleware",
]

# Framework middleware placed after the user's middleware, i.e. closest to the
# view. Its after_response runs before any user after_response, so a response
# it replaces (e.g. a slash redirect) is what user middleware then sees and
# decorates (e.g. with session cookies).
BUILTIN_AFTER_MIDDLEWARE = ["plain.internal.middleware.slash.RedirectSlashMiddleware"]
41
42
# Module-level OpenTelemetry tracer used to open the per-request server span.
tracer = trace.get_tracer(instrumenting_module_name="plain")
44
45
@dataclasses.dataclass
class _AsyncViewPending:
    """Hand-off record produced when a dispatched view is async.

    ``_run_sync_pipeline`` returns this instead of a response when the view's
    ``get_response()`` yielded a coroutine, so that ``handle`` can await it on
    the event loop and then run the remaining pipeline.
    """

    # Un-awaited coroutine returned by the view's get_response().
    coroutine: Any
    # Class of the view that produced the coroutine; used to name the view
    # when it turns out to have returned None.
    view_class: type
    # Middleware whose before_request already ran; their after_response hooks
    # are run (in reverse) once the view's response is available.
    ran_before: list[HttpMiddleware]
53
54
class BaseHandler:
    """Run requests through the middleware chain and dispatch them to views.

    ``load_middleware()`` must be called once before ``handle()``. ``handle()``
    is the single async entry point; all synchronous work (signals,
    middleware, URL resolution, sync views) is run in an executor.
    """

    # Instantiated middleware in before_request order. Stays None until
    # load_middleware() completes, so it doubles as an "initialized" flag.
    _middleware_chain: list[HttpMiddleware] | None = None

    def load_middleware(self) -> None:
        """
        Populate the middleware chain from settings.MIDDLEWARE.

        The builtin "before" middleware is placed ahead of the user's
        middleware and the builtin "after" middleware behind it.
        ``settings.MIDDLEWARE`` may be any iterable of dotted paths (a list or
        a tuple).

        Must be called after the environment is fixed (see __call__ in subclasses).
        """
        # Unpack instead of using ``+``: ``list + tuple`` raises TypeError, so
        # a tuple-valued MIDDLEWARE setting would otherwise crash here.
        middleware_paths = [
            *BUILTIN_BEFORE_MIDDLEWARE,
            *settings.MIDDLEWARE,
            *BUILTIN_AFTER_MIDDLEWARE,
        ]

        chain: list[HttpMiddleware] = []
        for middleware_path in middleware_paths:
            middleware_class = import_string(middleware_path)
            chain.append(middleware_class())

        # Only assign once initialization is complete, because a non-None
        # value is used as the flag that initialization succeeded.
        self._middleware_chain = chain

    def _build_request_span(
        self, request: Request
    ) -> tuple[dict[str, str], context.Context]:
        """Build OpenTelemetry span attributes and baggage context for a request.

        Returns:
            ``(span_attributes, span_context)``. ``span_context`` is the current
            context with the request cookies and headers stored as baggage;
            ``handle()`` passes it as the span's parent ``context``.
        """
        span_attributes: dict[str, str] = {
            "plain.request.id": request.unique_id,
            http_attributes.HTTP_REQUEST_METHOD: request.method or "",
            url_attributes.URL_PATH: request.path_info,
            url_attributes.URL_SCHEME: request.scheme,
        }

        # Best-effort: omit url.full if the absolute URI cannot be built
        # (only KeyError/AttributeError are tolerated).
        try:
            span_attributes[url_attributes.URL_FULL] = request.build_absolute_uri()
        except (KeyError, AttributeError):
            pass

        if request.query_string:
            span_attributes[url_attributes.URL_QUERY] = request.query_string

        span_context = baggage.set_baggage("http.request.cookies", request.cookies)
        span_context = baggage.set_baggage(
            "http.request.headers", request.headers, span_context
        )
        return span_attributes, span_context

    def _finalize_span(self, span: trace.Span, response: ResponseBase) -> None:
        """Set span status and record exceptions from the response.

        Status codes below 400 mark the span OK; 400 and above mark it ERROR.
        If the response carries an exception, it is recorded on the span.
        """
        span.set_attribute(
            http_attributes.HTTP_RESPONSE_STATUS_CODE, response.status_code
        )
        # NOTE(review): the OTel HTTP semantic conventions suggest leaving the
        # status unset for 1xx-4xx on SERVER spans; kept as-is in case
        # downstream tooling relies on OK/ERROR here.
        span.set_status(
            trace.StatusCode.OK
            if response.status_code < 400
            else trace.StatusCode.ERROR
        )
        if response.exception:
            span.record_exception(response.exception)

    async def _run_in_executor(
        self,
        executor: concurrent.futures.Executor,
        fn: Any,
        *args: Any,
        **kwargs: Any,
    ) -> Any:
        """Run a sync function in the executor, propagating OTel context.

        The OTel context current on the calling task is attached in the worker
        thread for the duration of the call and detached afterwards.
        ``fn``'s return value (or exception) is passed through.
        """
        loop = asyncio.get_running_loop()
        ctx = context.get_current()

        def _wrapper() -> Any:
            token = context.attach(ctx)
            try:
                return fn(*args, **kwargs)
            finally:
                context.detach(token)

        return await loop.run_in_executor(executor, _wrapper)

    async def handle(
        self,
        request: Request,
        executor: concurrent.futures.Executor,
    ) -> ResponseBase:
        """Single entry point for handling a request.

        Creates OTel span and runs the full pipeline: signal → before
        middleware → resolve/dispatch view → after middleware → signal.

        For sync views, the entire pipeline runs in a single executor call
        so that signals, middleware, and the view all execute on the same
        thread (preserving thread-local DB connection assumptions).

        For async views, the sync portion (signal + before middleware +
        URL resolution) runs in one executor call, the coroutine is awaited
        on the event loop, then after-middleware + request_finished runs in
        a second executor call.
        """
        assert self._middleware_chain is not None, (
            "load_middleware() must be called before handle()"
        )

        span_attributes, span_context = self._build_request_span(request)

        with tracer.start_as_current_span(
            f"{request.method} {request.path_info}",
            context=span_context,
            attributes=span_attributes,
            kind=trace.SpanKind.SERVER,
        ) as span:
            result = await self._run_in_executor(
                executor, self._run_sync_pipeline, request
            )

            if isinstance(result, _AsyncViewPending):
                # Async view: await the coroutine on the event loop, then
                # run after-middleware + request_finished back in the executor.
                # NOTE(review): only Exception is caught, so a CancelledError
                # (a BaseException) propagates and skips after-middleware and
                # request_finished for this request.
                try:
                    response = await result.coroutine
                    self._check_response(response, result.view_class)
                except Exception as exc:
                    response = response_for_exception(request, exc)

                response = await self._run_in_executor(
                    executor,
                    self._finish_pipeline,
                    request,
                    response,
                    result.ran_before,
                )
            else:
                response = result

            # Close the request together with the response's other resources.
            response._resource_closers.append(request.close)
            self._finalize_span(span, response)
            return response

    def _run_sync_pipeline(self, request: Request) -> ResponseBase | _AsyncViewPending:
        """Run the entire sync request pipeline on a single thread.

        Sends request_started, runs before-middleware, resolves and dispatches
        the view, runs after-middleware, and sends request_finished.

        If the view is async, returns an _AsyncViewPending so the caller
        can await the coroutine on the event loop; in that case neither
        after-middleware nor request_finished has run yet.
        """
        signals.request_started.send(sender=self.__class__, request=request)

        # 1. Before middleware (may short-circuit with a response)
        response, ran_before = self._run_before_request(request)

        # 2. Resolve and dispatch the view
        if response is None:
            try:
                resolver_match = self._resolve_request(request)
                view = resolver_match.view_class(
                    request=request,
                    url_args=resolver_match.args,
                    url_kwargs=resolver_match.kwargs,
                )
                response = view.get_response()
                view_class = type(view)

                # Async views return a coroutine that must be awaited
                if inspect.iscoroutine(response):
                    return _AsyncViewPending(
                        coroutine=response,
                        view_class=view_class,
                        ran_before=ran_before,
                    )

                self._check_response(response, view_class)
            except Exception as exc:
                response = response_for_exception(request, exc)

        # 3. After middleware + request_finished signal
        return self._finish_pipeline(request, response, ran_before)

    def _finish_pipeline(
        self,
        request: Request,
        response: ResponseBase,
        ran_before: list[HttpMiddleware],
    ) -> ResponseBase:
        """Run after-middleware and send request_finished signal.

        request_finished is sent here (on the same thread as request_started)
        rather than from response.close(), which runs after the response body
        is written and may land on a different thread. This means the signal
        fires before streaming responses are iterated — handlers like
        close_old_connections should not affect in-progress streams since
        request_started on the next request also handles stale connections.
        """
        response = self._run_after_response(request, response, ran_before)
        signals.request_finished.send(sender=self.__class__)
        return response

    def _resolve_request(self, request: Request) -> ResolverMatch:
        """Resolve the URL, caching the result on request.resolver_match.

        Also records the matched route on the current span (http.route) and
        renames the span to "<METHOD> /<route>" when a route is available.
        """
        if request.resolver_match is not None:
            resolver_match = request.resolver_match
        else:
            resolver = get_resolver()
            resolver_match = resolver.resolve(request.path_info)
            request.resolver_match = resolver_match

        # Update span with route info
        span = trace.get_current_span()
        if resolver_match.route:
            route_with_slash = f"/{resolver_match.route}"
            span.set_attribute(http_attributes.HTTP_ROUTE, route_with_slash)
            span.update_name(f"{request.method} {route_with_slash}")

        return resolver_match

    def _run_before_request(
        self, request: Request
    ) -> tuple[ResponseBase | None, list[HttpMiddleware]]:
        """Run before_request forward through the middleware chain.

        Stops at the first middleware that returns a non-None response or
        raises (an exception is converted via response_for_exception).

        Returns:
            ``(response_or_None, ran_before)`` where ``ran_before`` lists the
            middleware whose before_request completed without raising; a
            middleware that raised is not included.
        """
        chain = self._middleware_chain
        assert chain is not None

        response = None
        ran_before: list[HttpMiddleware] = []

        for mw in chain:
            try:
                result = mw.before_request(request)
            except Exception as exc:
                response = response_for_exception(request, exc)
                break

            ran_before.append(mw)

            if result is not None:
                response = result
                break

        return response, ran_before

    def _run_after_response(
        self,
        request: Request,
        response: ResponseBase,
        ran_before: list[HttpMiddleware],
    ) -> ResponseBase:
        """Run after_response in reverse through middleware that ran before_request.

        An exception from one middleware is converted into a response via
        response_for_exception, and the remaining middleware still run.
        """
        for mw in reversed(ran_before):
            try:
                response = mw.after_response(request, response)  # type: ignore[arg-type]
            except Exception as exc:
                response = response_for_exception(request, exc)

        return response

    def _check_response(
        self,
        response: ResponseBase | None,
        view_class: type,
    ) -> None:
        """Raise ValueError if the view returned None.

        Raises:
            ValueError: naming the view by ``module.qualname``.
        """
        if response is None:
            name = f"{view_class.__module__}.{view_class.__qualname__}"
            raise ValueError(
                f"The view {name} didn't return a Response object. It returned None instead."
            )