1from __future__ import annotations
  2
  3import asyncio
  4import concurrent.futures
  5import dataclasses
  6import inspect
  7from typing import TYPE_CHECKING, Any
  8
  9from opentelemetry import baggage, context, trace
 10from opentelemetry.semconv.attributes import http_attributes, url_attributes
 11
 12from plain import signals
 13from plain.runtime import settings
 14from plain.urls import get_resolver
 15from plain.utils.module_loading import import_string
 16
 17from .exception import response_for_exception
 18
 19if TYPE_CHECKING:
 20    from plain.http import Request, ResponseBase
 21    from plain.http.middleware import HttpMiddleware
 22    from plain.urls import ResolverMatch
 23
 24
# Builtin middleware placed ahead of settings.MIDDLEWARE in the chain.
# before_request runs top-down through the chain and after_response runs
# bottom-up, so these entries form the outermost layer: they see the request
# first and the response last.
BUILTIN_BEFORE_MIDDLEWARE = [
    "plain.internal.middleware.headers.DefaultHeadersMiddleware",
    "plain.internal.middleware.healthcheck.HealthcheckMiddleware",
    "plain.internal.middleware.hosts.HostValidationMiddleware",
    "plain.internal.middleware.https.HttpsRedirectMiddleware",
    "plain.csrf.middleware.CsrfViewMiddleware",
]

# Builtin middleware placed after settings.MIDDLEWARE (closest to the view).
# Their after_response runs first, so replacements (e.g. slash redirect) happen
# before user middleware modifies the response (e.g. session cookies).
BUILTIN_AFTER_MIDDLEWARE = [
    "plain.internal.middleware.slash.RedirectSlashMiddleware",
]


# Module-level tracer used to open the per-request server span.
tracer = trace.get_tracer("plain")
 44
 45
@dataclasses.dataclass
class _AsyncViewPending:
    """Returned by _run_sync_pipeline when an async view needs to be awaited.

    Bundles everything the caller needs to finish the request once the
    coroutine has been awaited on the event loop.
    """

    # The not-yet-awaited coroutine returned by the view's get_response().
    coroutine: Any
    # Class of the view that produced the coroutine; used to name the view in
    # the error raised when the awaited result is None.
    view_class: type
    # Middleware whose before_request already ran; after_response is later run
    # for exactly these, in reverse order.
    ran_before: list[HttpMiddleware]
 53
 54
 55class BaseHandler:
 56    _middleware_chain: list[HttpMiddleware] | None = None
 57
 58    def load_middleware(self) -> None:
 59        """
 60        Populate middleware list from settings.MIDDLEWARE.
 61
 62        Must be called after the environment is fixed (see __call__ in subclasses).
 63        """
 64        middleware_paths = (
 65            BUILTIN_BEFORE_MIDDLEWARE + settings.MIDDLEWARE + BUILTIN_AFTER_MIDDLEWARE
 66        )
 67
 68        chain: list[HttpMiddleware] = []
 69        for middleware_path in middleware_paths:
 70            middleware_class = import_string(middleware_path)
 71            mw_instance = middleware_class()
 72            chain.append(mw_instance)
 73
 74        # We only assign to this when initialization is complete as it is used
 75        # as a flag for initialization being complete.
 76        self._middleware_chain = chain
 77
 78    def _build_request_span(
 79        self, request: Request
 80    ) -> tuple[dict[str, str], context.Context]:
 81        """Build OpenTelemetry span attributes and baggage context for a request."""
 82        span_attributes: dict[str, str] = {
 83            "plain.request.id": request.unique_id,
 84            http_attributes.HTTP_REQUEST_METHOD: request.method or "",
 85            url_attributes.URL_PATH: request.path_info,
 86            url_attributes.URL_SCHEME: request.scheme,
 87        }
 88
 89        try:
 90            span_attributes[url_attributes.URL_FULL] = request.build_absolute_uri()
 91        except (KeyError, AttributeError):
 92            pass
 93
 94        if request.query_string:
 95            span_attributes[url_attributes.URL_QUERY] = request.query_string
 96
 97        span_context = baggage.set_baggage("http.request.cookies", request.cookies)
 98        span_context = baggage.set_baggage(
 99            "http.request.headers", request.headers, span_context
100        )
101        return span_attributes, span_context
102
103    def _finalize_span(self, span: trace.Span, response: ResponseBase) -> None:
104        """Set span status and record exceptions from the response."""
105        span.set_attribute(
106            http_attributes.HTTP_RESPONSE_STATUS_CODE, response.status_code
107        )
108        span.set_status(
109            trace.StatusCode.OK
110            if response.status_code < 400
111            else trace.StatusCode.ERROR
112        )
113        if response.exception:
114            span.record_exception(response.exception)
115
116    async def _run_in_executor(
117        self,
118        executor: concurrent.futures.Executor,
119        fn: Any,
120        *args: Any,
121        **kwargs: Any,
122    ) -> Any:
123        """Run a sync function in the executor, propagating OTel context."""
124        loop = asyncio.get_running_loop()
125        ctx = context.get_current()
126
127        def _wrapper() -> Any:
128            token = context.attach(ctx)
129            try:
130                return fn(*args, **kwargs)
131            finally:
132                context.detach(token)
133
134        return await loop.run_in_executor(executor, _wrapper)
135
136    async def handle(
137        self,
138        request: Request,
139        executor: concurrent.futures.Executor,
140    ) -> ResponseBase:
141        """Single entry point for handling a request.
142
143        Creates OTel span and runs the full pipeline: signal → before
144        middleware → resolve/dispatch view → after middleware → signal.
145
146        For sync views, the entire pipeline runs in a single executor call
147        so that signals, middleware, and the view all execute on the same
148        thread (preserving thread-local DB connection assumptions).
149
150        For async views, the sync portion (signal + before middleware +
151        URL resolution) runs in one executor call, the coroutine is awaited
152        on the event loop, then after-middleware + request_finished runs in
153        a second executor call.
154        """
155        assert self._middleware_chain is not None, (
156            "load_middleware() must be called before handle()"
157        )
158
159        span_attributes, span_context = self._build_request_span(request)
160
161        with tracer.start_as_current_span(
162            f"{request.method} {request.path_info}",
163            context=span_context,
164            attributes=span_attributes,
165            kind=trace.SpanKind.SERVER,
166        ) as span:
167            result = await self._run_in_executor(
168                executor, self._run_sync_pipeline, request
169            )
170
171            if isinstance(result, _AsyncViewPending):
172                # Async view: await the coroutine on the event loop, then
173                # run after-middleware + request_finished back in the executor.
174                try:
175                    response = await result.coroutine
176                    self._check_response(response, result.view_class)
177                except Exception as exc:
178                    response = response_for_exception(request, exc)
179
180                response = await self._run_in_executor(
181                    executor,
182                    self._finish_pipeline,
183                    request,
184                    response,
185                    result.ran_before,
186                )
187            else:
188                response = result
189
190            response._resource_closers.append(request.close)
191            self._finalize_span(span, response)
192            return response
193
194    def _run_sync_pipeline(self, request: Request) -> ResponseBase | _AsyncViewPending:
195        """Run the entire sync request pipeline on a single thread.
196
197        Sends request_started, runs before-middleware, resolves and dispatches
198        the view, runs after-middleware, and sends request_finished.
199
200        If the view is async, returns an _AsyncViewPending so the caller
201        can await the coroutine on the event loop.
202        """
203        signals.request_started.send(sender=self.__class__, request=request)
204
205        # 1. Before middleware
206        response, ran_before = self._run_before_request(request)
207
208        # 2. Resolve and dispatch the view
209        if response is None:
210            try:
211                resolver_match = self._resolve_request(request)
212                view = resolver_match.view_class(
213                    request=request,
214                    url_args=resolver_match.args,
215                    url_kwargs=resolver_match.kwargs,
216                )
217                response = view.get_response()
218                view_class = type(view)
219
220                # Async views return a coroutine that must be awaited
221                if inspect.iscoroutine(response):
222                    return _AsyncViewPending(
223                        coroutine=response,
224                        view_class=view_class,
225                        ran_before=ran_before,
226                    )
227
228                self._check_response(response, view_class)
229            except Exception as exc:
230                response = response_for_exception(request, exc)
231
232        # 3. After middleware + request_finished signal
233        return self._finish_pipeline(request, response, ran_before)
234
235    def _finish_pipeline(
236        self,
237        request: Request,
238        response: ResponseBase,
239        ran_before: list[HttpMiddleware],
240    ) -> ResponseBase:
241        """Run after-middleware and send request_finished signal.
242
243        request_finished is sent here (on the same thread as request_started)
244        rather than from response.close(), which runs after the response body
245        is written and may land on a different thread. This means the signal
246        fires before streaming responses are iterated — handlers like
247        close_old_connections should not affect in-progress streams since
248        request_started on the next request also handles stale connections.
249        """
250        response = self._run_after_response(request, response, ran_before)
251        signals.request_finished.send(sender=self.__class__)
252        return response
253
254    def _resolve_request(self, request: Request) -> ResolverMatch:
255        """Resolve the URL, caching on request.resolver_match."""
256        if request.resolver_match is not None:
257            resolver_match = request.resolver_match
258        else:
259            resolver = get_resolver()
260            resolver_match = resolver.resolve(request.path_info)
261            request.resolver_match = resolver_match
262
263        # Update span with route info
264        span = trace.get_current_span()
265        if resolver_match.route:
266            route_with_slash = f"/{resolver_match.route}"
267            span.set_attribute(http_attributes.HTTP_ROUTE, route_with_slash)
268            span.update_name(f"{request.method} {route_with_slash}")
269
270        return resolver_match
271
272    def _run_before_request(
273        self, request: Request
274    ) -> tuple[ResponseBase | None, list[HttpMiddleware]]:
275        """Run before_request forward through middleware chain."""
276        chain = self._middleware_chain
277        assert chain is not None
278
279        response = None
280        ran_before: list[HttpMiddleware] = []
281
282        for mw in chain:
283            try:
284                result = mw.before_request(request)
285            except Exception as exc:
286                response = response_for_exception(request, exc)
287                break
288
289            ran_before.append(mw)
290
291            if result is not None:
292                response = result
293                break
294
295        return response, ran_before
296
297    def _run_after_response(
298        self,
299        request: Request,
300        response: ResponseBase,
301        ran_before: list[HttpMiddleware],
302    ) -> ResponseBase:
303        """Run after_response in reverse through middleware that ran before_request."""
304        for mw in reversed(ran_before):
305            try:
306                response = mw.after_response(request, response)  # type: ignore[arg-type]
307            except Exception as exc:
308                response = response_for_exception(request, exc)
309
310        return response
311
312    def _check_response(
313        self,
314        response: ResponseBase | None,
315        view_class: type,
316    ) -> None:
317        """Raise an error if the view returned None."""
318        if response is None:
319            name = f"{view_class.__module__}.{view_class.__qualname__}"
320            raise ValueError(
321                f"The view {name} didn't return a Response object. It returned None instead."
322            )