v0.148.0
  1from collections.abc import Mapping
  2from functools import cached_property
  3from http.client import responses as http_status_phrases
  4from typing import Any, cast
  5
  6from plain.exceptions import ValidationError
  7from plain.forms.exceptions import FormFieldMissingError
  8from plain.http import (
  9    HTTPException,
 10    JsonResponse,
 11    NotFoundError404,
 12    Response,
 13)
 14from plain.utils import timezone
 15from plain.utils.cache import patch_cache_control
 16from plain.views.base import View
 17from plain.views.exceptions import ResponseException
 18
 19from . import openapi
 20from .schemas import ErrorSchema, FieldError
 21
 22# Allow plain.api to be used without plain.postgres
 23try:
 24    from .models import APIKey
 25except ImportError:
 26    APIKey: Any = None
 27
 28__all__ = [
 29    "APIKeyView",
 30    "APIResult",
 31    "APIView",
 32    "JsonNotFoundView",
 33]
 34
# Values a view handler may return; APIView.convert_result_to_response turns
# each shape into a response.
# `Mapping[str, Any]` (vs `dict[str, Any]`) lets `def get(self) -> MyTypedDict:`
# satisfy Liskov against the base view — TypedDicts aren't `dict` per PEP 589.
type APIResult = (
    Response  # returned unchanged
    | None  # APIView raises NotFoundError404 for this
    | Mapping[str, Any]  # JSON object body, status 200
    | list[Any]  # JSON array body, status 200
    | tuple[int, dict[str, Any] | list[Any]]  # (status_code, body)
)
 44
 45
 46def _error_response(
 47    *,
 48    error_id: str,
 49    message: str,
 50    status_code: int,
 51    errors: list[FieldError] | None = None,
 52) -> JsonResponse:
 53    body: ErrorSchema = {"id": error_id, "message": message}
 54    if errors is not None:
 55        body["errors"] = errors
 56    return JsonResponse(body, status_code=status_code)
 57
 58
 59def _validation_field_errors(exc: ValidationError) -> list[FieldError] | None:
 60    """Flatten a field-dict ValidationError into a list of `{field, message}`.
 61
 62    Returns None for string- or list-shaped errors that have no field context.
 63    """
 64    if not hasattr(exc, "error_dict"):
 65        return None
 66    return [
 67        {"field": field, "message": message}
 68        for field, messages in exc
 69        for message in messages
 70    ]
 71
 72
 73# Snake-case ids are part of the public API surface — client libs key off them.
 74_STATUS_ERROR_IDS = {
 75    400: "bad_request",
 76    401: "unauthorized",
 77    403: "permission_denied",
 78    404: "not_found",
 79    405: "method_not_allowed",
 80    409: "conflict",
 81    415: "unsupported_media_type",
 82    429: "rate_limited",
 83}
 84
 85
 86# @openapi.response_typed_dict(400, ErrorSchema)
 87# @openapi.response_typed_dict(401, ErrorSchema)
 88class APIKeyView(View[APIResult]):
 89    api_key_required = True
 90
 91    # Picked up by the OpenAPI generator: each entry is added to
 92    # `components.securitySchemes` and required on every operation served by
 93    # this view. Subclasses can override to declare a different scheme.
 94    openapi_security_schemes: dict[str, dict[str, Any]] = {
 95        "BearerAuth": {
 96            "type": "http",
 97            "scheme": "bearer",
 98        }
 99    }
100
101    @cached_property
102    def api_key(self) -> Any:
103        return self.get_api_key()
104
105    def before_request(self) -> None:
106        if self.api_key:
107            self.use_api_key()
108        elif self.api_key_required:
109            raise ResponseException(
110                _error_response(
111                    error_id="api_key_required",
112                    message="API key required",
113                    status_code=401,
114                )
115            )
116
117    def after_response(self, response: Response) -> Response:
118        response = super().after_response(response)
119        # Make sure it at least has private as a default
120        patch_cache_control(response, private=True)
121        return response
122
123    def use_api_key(self) -> None:
124        """
125        Use the API key for this request.
126
127        Override this to perform other actions with a valid API key.
128        """
129        self.api_key.last_used_at = timezone.now()
130        self.api_key.save(update_fields=["last_used_at"])
131
132    def get_api_key(self) -> Any:
133        """
134        Get the API key from the request.
135
136        Override this if you want to use a different input method.
137        """
138        if "Authorization" in self.request.headers:
139            header_value = self.request.headers["Authorization"]
140            try:
141                header_token = header_value.split("Bearer ")[1]
142            except IndexError:
143                raise ResponseException(
144                    _error_response(
145                        error_id="invalid_authorization_header",
146                        message="Invalid Authorization header",
147                        status_code=400,
148                    )
149                )
150
151            try:
152                api_key = APIKey.query.get(token=header_token)
153            except APIKey.DoesNotExist:
154                raise ResponseException(
155                    _error_response(
156                        error_id="invalid_api_token",
157                        message="Invalid API token",
158                        status_code=400,
159                    )
160                )
161
162            if api_key.is_expired():
163                raise ResponseException(
164                    _error_response(
165                        error_id="api_token_expired",
166                        message="API token has expired",
167                        status_code=400,
168                    )
169                )
170
171            return api_key
172
173
@openapi.response_typed_dict(400, ErrorSchema, component_name="BadRequest")
@openapi.response_typed_dict(401, ErrorSchema, component_name="Unauthorized")
@openapi.response_typed_dict(403, ErrorSchema, component_name="Forbidden")
@openapi.response_typed_dict(404, ErrorSchema, component_name="NotFound")
@openapi.response_typed_dict(
    "5XX", ErrorSchema, description="Unexpected Error", component_name="ServerError"
)
class APIView(View[APIResult]):
    """View base for JSON APIs.

    Handler results are turned into JSON responses by
    `convert_result_to_response`, and exceptions are rendered as
    `ErrorSchema` JSON bodies by `handle_exception`.
    """

    def convert_result_to_response(self, result: APIResult) -> Response:
        """Turn a handler's return value into a response.

        - `Response`: returned unchanged.
        - `None`: raises `NotFoundError404`.
        - any Mapping (dict, TypedDict, ...): JSON object, status 200.
        - list: JSON array, status 200.
        - `(status_code, mapping_or_list)`: as above, with that status.

        Raises:
            ValueError: a tuple result is not exactly `(status_code, data)`.
            TypeError: the result, or a tuple's payload, is any other type.
        """
        if isinstance(result, Response):
            return result

        if result is None:
            raise NotFoundError404

        status_code = 200

        if isinstance(result, tuple):
            if len(result) != 2:
                raise ValueError(
                    "Tuple response must be of length 2 (status_code, data)"
                )
            status_code, result = cast(tuple[int, dict[str, Any] | list[Any]], result)

        if isinstance(result, Mapping):
            # APIResult admits any Mapping, but JsonResponse presumably (like
            # Django's) only serializes a real dict as a JSON object.
            data = result if isinstance(result, dict) else dict(result)
            return JsonResponse(data, status_code=status_code)

        if isinstance(result, list):
            return JsonResponse(result, status_code=status_code, safe=False)

        raise TypeError(f"Unexpected APIView return type: {type(result).__name__}")

    def handle_exception(self, exc: Exception) -> Response:
        """Render `exc` as an `ErrorSchema` JSON error response.

        - `ValidationError` -> 400 `validation_error`, with per-field
          `errors` when the error is field-keyed.
        - `FormFieldMissingError` -> 400 `missing_field`.
        - `HTTPException` -> its own status, with an id from
          `_STATUS_ERROR_IDS` (falling back to `http_error`).
        - anything else -> 500 `server_error`, logged with its traceback.
        """
        import logging

        if isinstance(exc, ValidationError):
            errors = _validation_field_errors(exc)
            if errors is not None:
                message = "Validation error"
            else:
                detail = "; ".join(exc.messages) if exc.messages else str(exc)
                message = f"Validation error: {detail}"
            return _error_response(
                error_id="validation_error",
                message=message,
                status_code=400,
                errors=errors,
            )
        if isinstance(exc, FormFieldMissingError):
            return _error_response(
                error_id="missing_field",
                message=f"Missing field: {exc.field_name}",
                status_code=400,
            )
        if isinstance(exc, HTTPException):
            error_id = _STATUS_ERROR_IDS.get(exc.status_code, "http_error")
            return _error_response(
                error_id=error_id,
                message=str(exc)
                or http_status_phrases.get(exc.status_code, "HTTP error"),
                status_code=exc.status_code,
            )
        # Anything else is an unexpected failure. The JSON body deliberately
        # hides details from the client, so record the traceback here —
        # otherwise the error would vanish without a trace. `exc_info=exc`
        # works even when this isn't called from inside an `except` block.
        logging.getLogger(__name__).error(
            "Unhandled exception in %s", type(self).__name__, exc_info=exc
        )
        return _error_response(
            error_id="server_error",
            message="Internal server error",
            status_code=500,
        )
239
240
class JsonNotFoundView(APIView):
    """Fallback view that answers every request with a JSON 404.

    Register it as the last, regex catch-all route of an API router. Paths
    under the API prefix that match nothing else then get an `ErrorSchema`
    JSON body rather than the framework's HTML 404 page.
    """

    def before_request(self) -> None:
        # Fail before any handler can run; the inherited APIView exception
        # handling is presumably what renders this as the JSON 404 body.
        raise NotFoundError404()