v0.150.0
  1from __future__ import annotations
  2
  3import datetime
  4import io
  5import json
  6import mimetypes
  7import os
  8import re
  9import sys
 10import time
 11from collections.abc import AsyncIterator, Iterator
 12from email.header import Header
 13from http.client import responses
 14from http.cookies import SimpleCookie
 15from typing import IO, Any
 16
 17from plain.http.cookie import sign_cookie_value
 18from plain.json import PlainJSONEncoder
 19from plain.utils import timezone
 20from plain.utils.datastructures import CaseInsensitiveMapping
 21from plain.utils.encoding import iri_to_uri
 22from plain.utils.http import content_disposition_header, http_date
 23from plain.utils.regex_helper import _lazy_re_compile
 24
 25_charset_from_content_type_re = _lazy_re_compile(
 26    r";\s*charset=(?P<charset>[^\s;]+)", re.I
 27)
 28
 29
 30class ResponseHeaders(CaseInsensitiveMapping):
 31    def __init__(self, data: dict[str, Any] | None = None):
 32        """
 33        Populate the initial data using __setitem__ to ensure values are
 34        correctly encoded.
 35        """
 36        self._store = {}
 37        if data:
 38            for header, value in self._unpack_items(data):
 39                self[header] = value
 40
 41    def _convert_to_charset(
 42        self, value: str | bytes, charset: str, mime_encode: bool = False
 43    ) -> str:
 44        """
 45        Convert headers key/value to ascii/latin-1 native strings.
 46        `charset` must be 'ascii' or 'latin-1'. If `mime_encode` is True and
 47        `value` can't be represented in the given charset, apply MIME-encoding.
 48        """
 49        try:
 50            if isinstance(value, str):
 51                # Ensure string is valid in given charset
 52                value.encode(charset)
 53            elif isinstance(value, bytes):
 54                # Convert bytestring using given charset
 55                value = value.decode(charset)
 56            else:
 57                value = str(value)
 58                # Ensure string is valid in given charset.
 59                value.encode(charset)
 60            if "\n" in value or "\r" in value:
 61                raise BadHeaderError(
 62                    f"Header values can't contain newlines (got {value!r})"
 63                )
 64        except UnicodeError as e:
 65            # Encoding to a string of the specified charset failed, but we
 66            # don't know what type that value was, or if it contains newlines,
 67            # which we may need to check for before sending it to be
 68            # encoded for multiple character sets.
 69            if (isinstance(value, bytes) and (b"\n" in value or b"\r" in value)) or (
 70                isinstance(value, str) and ("\n" in value or "\r" in value)
 71            ):
 72                raise BadHeaderError(
 73                    f"Header values can't contain newlines (got {value!r})"
 74                ) from e
 75            if mime_encode:
 76                value = Header(value, "utf-8", maxlinelen=sys.maxsize).encode()
 77            else:
 78                if hasattr(e, "reason") and isinstance(e.reason, str):
 79                    e.reason += f", HTTP response headers must be in {charset} format"
 80                raise
 81        return value
 82
 83    def __delitem__(self, key: str) -> None:
 84        self.pop(key)
 85
 86    def __setitem__(self, key: str, value: str | bytes | None) -> None:
 87        key = self._convert_to_charset(key, "ascii")
 88        if value is None:
 89            self._store[key.lower()] = (key, None)
 90        else:
 91            value = self._convert_to_charset(value, "latin-1", mime_encode=True)
 92            self._store[key.lower()] = (key, value)
 93
 94    def pop(self, key: str, default: Any = None) -> Any:
 95        return self._store.pop(key.lower(), default)
 96
 97    def setdefault(self, key: str, value: str | bytes) -> None:
 98        if key not in self:
 99            self[key] = value
100
101
class BadHeaderError(ValueError):
    """Raised when a response header name or value contains a newline (CR or LF)."""
104
105
106# Private sentinel streaming subclasses pass to skip bytes-body setup in
107# Response.__init__. Using a dedicated object (not None) keeps Response(None)
108# working as before — it goes through the content setter and becomes b"None".
109_NO_CONTENT: Any = object()
110
111
class Response:
    """
    An HTTP response class with a bytes body.

    Base class for all response types — streaming variants subclass this
    and swap the body for an iterator. Users annotate handler returns and
    middleware with `Response` to cover all response shapes.
    """

    status_code = 200
    streaming = False

    def __init__(
        self,
        content: bytes | str | Iterator[bytes] = b"",
        *,
        content_type: str | None = None,
        status_code: int | None = None,
        reason: str | None = None,
        charset: str | None = None,
        headers: dict[str, Any] | None = None,
    ):
        """
        Build the response headers, cookies, status and body.

        Args:
            content: Body as bytes/str, or an iterable of chunks that is
                consumed immediately (see the `content` setter). Subclasses
                pass the private `_NO_CONTENT` sentinel to skip the body.
            content_type: Content-Type header value. Defaults to
                "text/html; charset=<charset>" when `headers` doesn't
                already provide one.
            status_code: Overrides the class-level `status_code`.
            reason: Custom reason phrase; otherwise derived from the status.
            charset: Explicit body charset; otherwise parsed from the
                Content-Type header, falling back to "utf-8".
            headers: Initial response headers.

        Raises:
            ValueError: If `headers` contains Content-Type while
                `content_type` is also given, or if the status code is
                outside 100-599.
            TypeError: If `status_code` can't be converted to an int.
        """
        self.headers = ResponseHeaders(headers)
        self._charset = charset
        if "Content-Type" not in self.headers:
            if content_type is None:
                content_type = f"text/html; charset={self.charset}"
            self.headers["Content-Type"] = content_type
        elif content_type:
            raise ValueError(
                "'headers' must not contain 'Content-Type' when the "
                "'content_type' parameter is provided."
            )
        # Callables run (once) by close().
        self._resource_closers: list[Any] = []
        self.cookies = SimpleCookie()
        self.closed = False
        if status_code is not None:
            try:
                self.status_code = int(status_code)
            except (ValueError, TypeError) as e:
                raise TypeError("HTTP status code must be an integer.") from e

            if not 100 <= self.status_code <= 599:
                raise ValueError("HTTP status code must be an integer from 100 to 599.")
        self._reason_phrase = reason
        # Exception that caused this response, if any (primarily for 500 errors)
        self.exception: Exception | None = None
        # Whether the server should log this response in the access log
        self.log_access: bool = True
        if content is not _NO_CONTENT:
            self.content = content

    @property
    def reason_phrase(self) -> str:
        """The explicit reason phrase, or the standard one for `status_code`."""
        if self._reason_phrase is not None:
            return self._reason_phrase
        # Leave self._reason_phrase unset in order to use the default
        # reason phrase for status code.
        return responses.get(self.status_code, "Unknown Status Code")

    @reason_phrase.setter
    def reason_phrase(self, value: str) -> None:
        self._reason_phrase = value

    @property
    def charset(self) -> str:
        """The explicit charset, else the one in Content-Type, else "utf-8"."""
        if self._charset is not None:
            return self._charset
        # The Content-Type header may not yet be set, because the charset is
        # being inserted *into* it.
        if content_type := self.headers.get("Content-Type"):
            if matched := _charset_from_content_type_re.search(content_type):
                # Extract the charset and strip its double quotes.
                # Note that having parsed it from the Content-Type, we don't
                # store it back into the _charset for later intentionally, to
                # allow for the Content-Type to be switched again later.
                return matched["charset"].replace('"', "")
        return "utf-8"

    @charset.setter
    def charset(self, value: str) -> None:
        self._charset = value

    @property
    def _content_type_for_repr(self) -> str:
        return (
            ', "{}"'.format(self.headers["Content-Type"])
            if "Content-Type" in self.headers
            else ""
        )

    def set_cookie(
        self,
        key: str,
        value: str = "",
        max_age: int | float | datetime.timedelta | None = None,
        expires: str | datetime.datetime | None = None,
        path: str | None = "/",
        domain: str | None = None,
        secure: bool = False,
        httponly: bool = False,
        samesite: str | None = None,
    ) -> None:
        """
        Set a cookie.

        ``expires`` can be:
        - a string in the correct format,
        - a naive ``datetime.datetime`` object in UTC,
        - an aware ``datetime.datetime`` object in any time zone.
        If it is a ``datetime.datetime`` object then calculate ``max_age``.

        ``max_age`` can be:
        - int/float specifying seconds,
        - ``datetime.timedelta`` object.

        Raises ``ValueError`` if a ``datetime`` ``expires`` is combined with
        ``max_age``, or if ``samesite`` isn't "lax", "none" or "strict"
        (case-insensitive). These checks run before ``self.cookies`` is
        modified, so a rejected call leaves no partial cookie behind.
        """
        # Validate up front so a ValueError can't leave a half-configured
        # cookie in self.cookies.
        if isinstance(expires, datetime.datetime) and max_age is not None:
            raise ValueError("'expires' and 'max_age' can't be used together.")
        if samesite and samesite.lower() not in ("lax", "none", "strict"):
            raise ValueError('samesite must be "lax", "none", or "strict".')

        self.cookies[key] = value
        if expires is not None:
            if isinstance(expires, datetime.datetime):
                if timezone.is_naive(expires):
                    expires = timezone.make_aware(expires, datetime.UTC)
                delta = expires - datetime.datetime.now(tz=datetime.UTC)
                # Add one second so the date matches exactly (a fraction of
                # time gets lost between converting to a timedelta and
                # then the date string).
                delta += datetime.timedelta(seconds=1)
                # Just set max_age - the max_age logic will set expires.
                expires = None
                max_age = max(0, delta.days * 86400 + delta.seconds)
            else:
                self.cookies[key]["expires"] = expires
        else:
            self.cookies[key]["expires"] = ""
        if max_age is not None:
            if isinstance(max_age, datetime.timedelta):
                max_age = max_age.total_seconds()
            self.cookies[key]["max-age"] = int(max_age)
            # IE requires expires, so set it if hasn't been already.
            if not expires:
                self.cookies[key]["expires"] = http_date(time.time() + max_age)
        if path is not None:
            self.cookies[key]["path"] = path
        if domain is not None:
            self.cookies[key]["domain"] = domain
        if secure:
            self.cookies[key]["secure"] = True
        if httponly:
            self.cookies[key]["httponly"] = True
        if samesite:
            self.cookies[key]["samesite"] = samesite

    def set_signed_cookie(
        self, key: str, value: str, salt: str = "", **kwargs: Any
    ) -> None:
        """Set a cookie signed with the SECRET_KEY."""

        signed_value = sign_cookie_value(key, value, salt)
        return self.set_cookie(key, signed_value, **kwargs)

    def delete_cookie(
        self,
        key: str,
        path: str = "/",
        domain: str | None = None,
        samesite: str | None = None,
    ) -> None:
        """Expire cookie `key` immediately (max-age 0 and an expiry in 1970)."""
        # Browsers can ignore the Set-Cookie header if the cookie doesn't use
        # the secure flag and:
        # - the cookie name starts with "__Host-" or "__Secure-", or
        # - the samesite is "none".
        secure = key.startswith(("__Secure-", "__Host-")) or bool(
            samesite and samesite.lower() == "none"
        )
        self.set_cookie(
            key,
            max_age=0,
            path=path,
            domain=domain,
            secure=secure,
            expires="Thu, 01 Jan 1970 00:00:00 GMT",
            samesite=samesite,
        )

    # Common methods used by subclasses

    def make_bytes(self, value: str | bytes) -> bytes:
        """Turn a value into a bytestring encoded in the output charset."""
        # Per PEP 3333, this response body must be bytes. To avoid returning
        # an instance of a subclass, this function returns `bytes(value)`.
        # This doesn't make a copy when `value` already contains bytes.

        # Handle string types -- we can't rely on force_bytes here because:
        # - Python attempts str conversion first
        # - when self._charset != 'utf-8' it re-encodes the content
        if isinstance(value, bytes | memoryview):
            return bytes(value)
        if isinstance(value, str):
            return bytes(value.encode(self.charset))
        # Handle non-string types.
        return str(value).encode(self.charset)

    # The server must call this method upon completion of the request.
    # See http://blog.dscpl.com.au/2012/10/obligations-for-calling-close-on.html
    def close(self) -> None:
        """
        Run every registered resource closer, then mark the response closed.

        Errors raised by individual closers are ignored so that one failing
        closer doesn't prevent the others from running. Calling close() again
        after it has completed is a no-op.
        """
        if self.closed:
            return
        for closer in self._resource_closers:
            try:
                closer()
            except Exception:
                pass
        # Free resources that were still referenced.
        self._resource_closers.clear()
        self.closed = True

    def __repr__(self) -> str:
        return "<%(cls)s status_code=%(status_code)d%(content_type)s>" % {  # noqa: UP031
            "cls": self.__class__.__name__,
            "status_code": self.status_code,
            "content_type": self._content_type_for_repr,
        }

    @property
    def content(self) -> bytes:
        """The full response body as bytes."""
        return b"".join(self._container)

    @content.setter
    def content(self, value: bytes | str | Iterator[bytes]) -> None:
        # Consume iterators upon assignment to allow repeated iteration.
        if hasattr(value, "__iter__") and not isinstance(
            value, bytes | memoryview | str
        ):
            try:
                content = b"".join(self.make_bytes(chunk) for chunk in value)
            finally:
                # Close the iterable even when consuming it failed, so that
                # files/generators backing it aren't leaked. Errors from
                # close() itself are ignored (best effort).
                close = getattr(value, "close", None)
                if callable(close):
                    try:
                        close()
                    except Exception:
                        pass
        else:
            content = self.make_bytes(value)
        self._container = [content]

    def __iter__(self) -> Iterator[bytes]:
        return iter(self._container)
360
361
class StreamingResponse(Response):
    """
    A streaming HTTP response class with an iterator as content.

    The body is produced lazily from an iterator and is meant to be
    consumed exactly once, while the response is being sent. It can still
    be replaced, e.g. by an iterator that wraps the original one. Any
    `close` method found on the assigned content is registered so that
    Response.close() calls it.
    """

    streaming = True

    def __init__(
        self,
        streaming_content: Any = (),
        *,
        content_type: str | None = None,
        status_code: int | None = None,
        reason: str | None = None,
        charset: str | None = None,
        headers: dict[str, Any] | None = None,
    ):
        super().__init__(
            content=_NO_CONTENT,
            headers=headers,
            charset=charset,
            reason=reason,
            status_code=status_code,
            content_type=content_type,
        )
        # Routed through the property setter below, which wraps the value
        # in an iterator and registers its closer.
        self.streaming_content = streaming_content

    @property
    def content(self) -> bytes:
        """Always raises: a streaming body has no materialized `content`."""
        cls_name = self.__class__.__name__
        raise AttributeError(
            f"This {cls_name} instance has no `content` attribute. Use "
            "`streaming_content` instead."
        )

    @property
    def streaming_content(self) -> Iterator[bytes]:
        """Lazy iterator over the remaining chunks, each converted to bytes."""
        to_bytes = self.make_bytes
        return map(to_bytes, self._iterator)

    @streaming_content.setter
    def streaming_content(self, value: Iterator[bytes | str]) -> None:
        self._set_streaming_content(value)

    def _set_streaming_content(self, value: Iterator[bytes | str]) -> None:
        """Store a single-pass iterator over `value` and register its closer."""
        # Wrapping in iter() means even a re-iterable container (a list,
        # say) can only be walked once.
        self._iterator = iter(value)
        if not hasattr(value, "close"):
            return
        self._resource_closers.append(value.close)

    def __iter__(self) -> Iterator[bytes]:
        chunks = self.streaming_content
        return iter(chunks)
418
419
class AsyncStreamingResponse(Response):
    """
    A streaming HTTP response class with an async iterator as content.

    Used for long-lived connections like Server-Sent Events (SSE) where
    data arrives asynchronously and should be streamed to the client
    without buffering the entire response. The body is consumed with
    `async for`, each chunk passing through `make_bytes`; synchronous
    iteration and `content` access both raise.
    """

    streaming = True

    def __init__(
        self,
        streaming_content: AsyncIterator[bytes | str],
        *,
        content_type: str | None = None,
        status_code: int | None = None,
        reason: str | None = None,
        charset: str | None = None,
        headers: dict[str, Any] | None = None,
    ):
        super().__init__(
            content=_NO_CONTENT,
            headers=headers,
            charset=charset,
            reason=reason,
            status_code=status_code,
            content_type=content_type,
        )
        # Read lazily by __aiter__ and closed (if supported) by aclose().
        self._async_iterator = streaming_content

    @property
    def content(self) -> bytes:
        """Always raises: the body is only available via async iteration."""
        cls_name = self.__class__.__name__
        raise AttributeError(
            f"This {cls_name} instance has no `content` attribute. Use "
            "`streaming_content` instead."
        )

    def __iter__(self) -> Iterator[bytes]:
        """Refuse synchronous iteration of an async body."""
        cls_name = self.__class__.__name__
        raise TypeError(
            f"{cls_name} is async — use `async for` / `__aiter__` instead."
        )

    async def __aiter__(self) -> AsyncIterator[bytes]:
        """Yield every chunk of the wrapped async iterator as bytes."""
        source = self._async_iterator
        async for piece in source:
            yield self.make_bytes(piece)

    async def aclose(self) -> None:
        """Close the underlying async iterator if it supports it."""
        closer = getattr(self._async_iterator, "aclose", None)
        if closer is None:
            return
        await closer()
472
473
class FileResponse(StreamingResponse):
    """
    A streaming HTTP response class optimized for files.

    When the content is file-like (has ``read``), it is streamed in
    ``block_size`` chunks, registered for closing, and used to derive
    Content-Length, Content-Type and Content-Disposition where possible.
    Any other iterable is streamed unchanged.
    """

    block_size = 4096

    def __init__(
        self,
        streaming_content: Any = (),
        *,
        as_attachment: bool = False,
        filename: str = "",
        content_type: str | None = None,
        status_code: int | None = None,
        reason: str | None = None,
        charset: str | None = None,
        headers: dict[str, Any] | None = None,
    ):
        # The parent constructor assigns streaming_content, which reaches
        # _set_streaming_content() and then set_headers(); set_headers()
        # reads these attributes, so they must exist beforehand.
        self.as_attachment = as_attachment
        self.filename = filename
        self._no_explicit_content_type = content_type is None
        super().__init__(
            streaming_content,
            headers=headers,
            charset=charset,
            reason=reason,
            status_code=status_code,
            content_type=content_type,
        )

    def _set_streaming_content(self, value: Any) -> None:
        """
        Accept a file-like object or a plain iterable as the body.

        File-like objects are remembered as ``file_to_stream``, have their
        ``close`` registered, populate the response headers, and are read
        in ``block_size`` chunks. Plain iterables go straight to the parent.
        """
        if hasattr(value, "read"):
            filelike = value
            self.file_to_stream = filelike
            if hasattr(filelike, "close"):
                self._resource_closers.append(filelike.close)
            # iter(callable, sentinel): keep calling read() until it
            # returns b"" (end of file).
            blocks = iter(lambda: filelike.read(self.block_size), b"")
            self.set_headers(filelike)
            super()._set_streaming_content(blocks)
        else:
            self.file_to_stream = None
            super()._set_streaming_content(value)

    def set_headers(self, filelike: IO[bytes]) -> None:
        """
        Set some common response headers (Content-Length, Content-Type, and
        Content-Disposition) based on the `filelike` response content.

        Content-Length is the number of bytes from the file's current
        position to its end, when that can be determined. Content-Type is
        only guessed (from the file name) if no explicit content type was
        passed to the constructor.
        """
        name = getattr(filelike, "name", "")
        if not isinstance(name, str):
            # Only a str name is usable as a path or download filename.
            name = ""

        can_seek = False
        if hasattr(filelike, "seek"):
            can_seek = filelike.seekable() if hasattr(filelike, "seekable") else True

        length = None
        if hasattr(filelike, "tell"):
            if can_seek:
                # Measure from the current position to EOF, then rewind.
                start = filelike.tell()
                filelike.seek(0, io.SEEK_END)
                length = filelike.tell() - start
                filelike.seek(start)
            elif callable(getbuffer := getattr(filelike, "getbuffer", None)):
                length = getbuffer().nbytes - filelike.tell()
            elif os.path.exists(name):
                length = os.path.getsize(name) - filelike.tell()
        elif can_seek:
            # No tell(): count bytes by reading to EOF, then seek back from
            # the end by the same amount.
            length = sum(iter(lambda: len(filelike.read(self.block_size)), 0))
            filelike.seek(-length, io.SEEK_END)
        if length is not None:
            self.headers["Content-Length"] = str(length)

        basename = os.path.basename(self.filename or name)
        if self._no_explicit_content_type:
            guessed_type = None
            if basename:
                guessed_type, encoding = mimetypes.guess_type(basename)
                # Encoding isn't set to prevent browsers from automatically
                # uncompressing files; the compressed container type is
                # reported as the Content-Type instead.
                guessed_type = {
                    "br": "application/x-brotli",
                    "bzip2": "application/x-bzip",
                    "compress": "application/x-compress",
                    "gzip": "application/gzip",
                    "xz": "application/x-xz",
                }.get(encoding, guessed_type)
            self.headers["Content-Type"] = guessed_type or "application/octet-stream"

        disposition = content_disposition_header(self.as_attachment, basename)
        if disposition:
            self.headers["Content-Disposition"] = disposition
574
575
576def _is_external_url(url: str) -> bool:
577    """Check if a URL would redirect to an external host."""
578    if not url:
579        return False
580    # Browsers strip leading whitespace from Location headers
581    url = url.strip()
582    # Browsers normalize backslashes to forward slashes in URLs,
583    # so \\ and /\ are equivalent to //
584    if url[:2].replace("\\", "/") == "//":
585        return True
586    colon_pos = url.find("://")
587    if colon_pos > 0 and url[:colon_pos].isalpha():
588        return True
589    return False
590
591
class RedirectResponse(Response):
    """
    HTTP redirect response (302 unless `status_code` says otherwise).

    The target goes into the Location header after `iri_to_uri`
    conversion. Targets that `_is_external_url` flags are rejected with a
    ValueError unless `allow_external=True` is passed.
    """

    status_code = 302

    def __init__(
        self,
        redirect_to: str,
        *,
        allow_external: bool = False,
        content_type: str | None = None,
        status_code: int | None = None,
        reason: str | None = None,
        charset: str | None = None,
        headers: dict[str, Any] | None = None,
    ):
        # The external-URL check is skipped entirely when it's allowed.
        unsafe = not allow_external and _is_external_url(redirect_to)
        if unsafe:
            raise ValueError(
                f"Unsafe redirect URL: {redirect_to!r}. "
                "RedirectResponse does not allow external URLs by default. "
                "Use allow_external=True if you intentionally want to redirect "
                "to an external URL."
            )
        super().__init__(
            headers=headers,
            charset=charset,
            reason=reason,
            status_code=status_code,
            content_type=content_type,
        )
        location = iri_to_uri(redirect_to)
        self.headers["Location"] = location if location else ""

    @property
    def url(self) -> str:
        """The redirect target as stored in the Location header."""
        return self.headers["Location"]

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        return (
            f"<{cls_name} status_code={self.status_code:d}"
            f'{self._content_type_for_repr}, url="{self.url}">'
        )
638
639
class NotModifiedResponse(Response):
    """
    HTTP 304 response.

    Carries no body: the Content-Type header set by Response is removed,
    and assigning any truthy content raises AttributeError.
    """

    status_code = 304

    def __init__(
        self,
        *,
        reason: str | None = None,
        charset: str | None = None,
        headers: dict[str, Any] | None = None,
    ):
        super().__init__(headers=headers, charset=charset, reason=reason)
        # Response.__init__ always sets a Content-Type; a 304 shouldn't.
        del self.headers["content-type"]

    @Response.content.setter
    def content(self, value: bytes | str | Iterator[bytes]) -> None:
        # Only empty values are accepted; they leave the body empty.
        if not value:
            self._container = []
            return
        raise AttributeError(
            "You cannot set content to a 304 (Not Modified) response"
        )
666
667
class NotAllowedResponse(Response):
    """HTTP 405 response; the permitted methods are listed in the Allow header."""

    status_code = 405

    def __init__(
        self,
        permitted_methods: list[str],
        *,
        content_type: str | None = None,
        status_code: int | None = None,
        reason: str | None = None,
        charset: str | None = None,
        headers: dict[str, Any] | None = None,
    ):
        super().__init__(
            headers=headers,
            charset=charset,
            reason=reason,
            status_code=status_code,
            content_type=content_type,
        )
        allowed = ", ".join(permitted_methods)
        self.headers["Allow"] = allowed

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        methods = self.headers["Allow"]
        return (
            f"<{cls_name} [{methods}] status_code={self.status_code:d}"
            f"{self._content_type_for_repr}>"
        )
699
700
class JsonResponse(Response):
    """An HTTP response class that consumes data to be serialized to JSON."""

    def __init__(
        self,
        data: Any,
        *,
        encoder: type[json.JSONEncoder] = PlainJSONEncoder,
        json_dumps_params: dict[str, Any] | None = None,
        content_type: str | None = None,
        status_code: int | None = None,
        reason: str | None = None,
        charset: str | None = None,
        headers: dict[str, Any] | None = None,
    ):
        """
        Serialize `data` to JSON and use the result as the response body.

        Args:
            data: Object passed to `json.dumps`.
            encoder: JSON encoder class, passed as `cls` to `json.dumps`.
            json_dumps_params: Extra keyword arguments for `json.dumps`.
            content_type: Content-Type header value. Defaults to
                "application/json" unless `headers` already contains a
                Content-Type (matched case-insensitively), in which case that
                header is used as-is.
            status_code, reason, charset, headers: As for `Response`.

        Raises:
            ValueError: If both `content_type` and a Content-Type entry in
                `headers` are given (raised by `Response`).
        """
        if json_dumps_params is None:
            json_dumps_params = {}
        # Only fall back to the JSON default when the caller didn't already
        # choose a Content-Type through `headers`; otherwise Response would
        # reject the combination.
        if content_type is None and not any(
            isinstance(name, str) and name.lower() == "content-type"
            for name in (headers or ())
        ):
            content_type = "application/json"
        body = json.dumps(data, cls=encoder, **json_dumps_params)
        super().__init__(
            content=body,
            content_type=content_type,
            status_code=status_code,
            reason=reason,
            charset=charset,
            headers=headers,
        )