import logging
from collections.abc import Mapping
from functools import cached_property
from http.client import responses as http_status_phrases
from typing import Any, cast

from plain.exceptions import ValidationError
from plain.forms.exceptions import FormFieldMissingError
from plain.http import (
    HTTPException,
    JsonResponse,
    NotFoundError404,
    Response,
)
from plain.utils import timezone
from plain.utils.cache import patch_cache_control
from plain.views.base import View
from plain.views.exceptions import ResponseException

from . import openapi
from .schemas import ErrorSchema, FieldError
21
22# Allow plain.api to be used without plain.postgres
23try:
24 from .models import APIKey
25except ImportError:
26 APIKey: Any = None
27
# Public names of this module; `from <module> import *` exports only these.
__all__ = [
    "APIKeyView",
    "APIResult",
    "APIView",
    "JsonNotFoundView",
]
34
35# `Mapping[str, Any]` (vs `dict[str, Any]`) lets `def get(self) -> MyTypedDict:`
36# satisfy Liskov against the base view — TypedDicts aren't `dict` per PEP 589.
37type APIResult = (
38 Response
39 | None
40 | Mapping[str, Any]
41 | list[Any]
42 | tuple[int, dict[str, Any] | list[Any]]
43)
44
45
def _error_response(
    *,
    error_id: str,
    message: str,
    status_code: int,
    errors: list[FieldError] | None = None,
) -> JsonResponse:
    """Build a JSON error response whose body follows `ErrorSchema`.

    The body always has `id` and `message`. The `errors` key is added only
    when a list is passed (an empty list still adds the key); leaving
    `errors` as None omits the key entirely.
    """
    payload: ErrorSchema = {"id": error_id, "message": message}
    if errors is None:
        return JsonResponse(payload, status_code=status_code)
    payload["errors"] = errors
    return JsonResponse(payload, status_code=status_code)
57
58
def _validation_field_errors(exc: ValidationError) -> list[FieldError] | None:
    """Turn a field-keyed ValidationError into a flat `{field, message}` list.

    Only errors carrying an `error_dict` have field context; for string- or
    list-shaped errors this returns None. Iterating a field-keyed error
    yields `(field, messages)` pairs, and each message becomes one entry.
    """
    if not hasattr(exc, "error_dict"):
        return None

    flattened: list[FieldError] = []
    for field_name, field_messages in exc:
        flattened.extend(
            {"field": field_name, "message": text} for text in field_messages
        )
    return flattened
71
72
# Snake-case ids are part of the public API surface — client libs key off them.
# Maps an HTTPException's status code to the `id` field of the ErrorSchema
# body built in APIView.handle_exception; any status not listed here falls
# back to the id "http_error".
_STATUS_ERROR_IDS: dict[int, str] = {
    400: "bad_request",
    401: "unauthorized",
    403: "permission_denied",
    404: "not_found",
    405: "method_not_allowed",
    409: "conflict",
    415: "unsupported_media_type",
    429: "rate_limited",
}
84
85
# @openapi.response_typed_dict(400, ErrorSchema)
# @openapi.response_typed_dict(401, ErrorSchema)
# NOTE(review): the decorators above are disabled and the reason isn't visible
# here (possibly to avoid duplicating APIView's response components) — confirm
# before re-enabling or deleting them.
class APIKeyView(View[APIResult]):
    """Authenticate requests with an `APIKey` sent as a bearer token.

    `get_api_key()` reads the key from an `Authorization: Bearer <token>`
    header. When a key is found, `use_api_key()` runs before the handler.
    When no key is sent and `api_key_required` is true, the request is
    rejected with a 401 `api_key_required` error.
    """

    # Set to False to let requests without a key through (`self.api_key`
    # is then falsy and `use_api_key()` is skipped).
    api_key_required = True

    # Picked up by the OpenAPI generator: each entry is added to
    # `components.securitySchemes` and required on every operation served by
    # this view. Subclasses can override to declare a different scheme.
    openapi_security_schemes: dict[str, dict[str, Any]] = {
        "BearerAuth": {
            "type": "http",
            "scheme": "bearer",
        }
    }

    @cached_property
    def api_key(self) -> Any:
        """The request's API key, resolved once via `get_api_key()`."""
        return self.get_api_key()

    def before_request(self) -> None:
        """Record use of a valid key, or reject the request if one is required.

        Raises:
            ResponseException: 401 `api_key_required` when no key was sent and
                `api_key_required` is true. Resolving `self.api_key` may also
                raise the 400 ResponseExceptions documented on `get_api_key()`.
        """
        if self.api_key:
            self.use_api_key()
        elif self.api_key_required:
            raise ResponseException(
                _error_response(
                    error_id="api_key_required",
                    message="API key required",
                    status_code=401,
                )
            )

    def after_response(self, response: Response) -> Response:
        """Run the parent hook, then ensure Cache-Control includes `private`."""
        response = super().after_response(response)
        # Responses depend on the caller's API key, so make sure it at least
        # has private as a default.
        patch_cache_control(response, private=True)
        return response

    def use_api_key(self) -> None:
        """
        Use the API key for this request.

        By default this stamps `last_used_at` with the current time and saves
        only that field. Override this to perform other actions with a valid
        API key.
        """
        self.api_key.last_used_at = timezone.now()
        self.api_key.save(update_fields=["last_used_at"])

    def get_api_key(self) -> Any:
        """
        Get the API key from the request.

        Override this if you want to use a different input method.

        The default reads `Authorization: <scheme> <token>`. It requires the
        scheme to be `Bearer`, matched case-insensitively as RFC 7235 requires
        for auth-scheme names, and then looks the token up in `APIKey`.

        Returns:
            The matching `APIKey`, or None when no Authorization header is sent.

        Raises:
            ResponseException: 400 `invalid_authorization_header` when the
                header is not a non-empty bearer token; 400 `invalid_api_token`
                when no key has that token; 400 `api_token_expired` when the
                matching key is expired.
        """
        if "Authorization" not in self.request.headers:
            return None

        header_value = self.request.headers["Authorization"]
        # Parse exactly "<scheme> <token>" rather than searching for "Bearer "
        # anywhere in the value, which would also accept e.g. "Basic Bearer x".
        scheme, _, header_token = header_value.strip().partition(" ")
        header_token = header_token.strip()
        if scheme.lower() != "bearer" or not header_token:
            raise ResponseException(
                _error_response(
                    error_id="invalid_authorization_header",
                    message="Invalid Authorization header",
                    status_code=400,
                )
            )

        # NOTE: APIKey is None when the optional models import fails (see the
        # module's imports); subclasses without it must override this method.
        try:
            api_key = APIKey.query.get(token=header_token)
        except APIKey.DoesNotExist:
            # An unknown token is an expected client error, not a chained failure.
            raise ResponseException(
                _error_response(
                    error_id="invalid_api_token",
                    message="Invalid API token",
                    status_code=400,
                )
            ) from None

        if api_key.is_expired():
            raise ResponseException(
                _error_response(
                    error_id="api_token_expired",
                    message="API token has expired",
                    status_code=400,
                )
            )

        return api_key
172
173
@openapi.response_typed_dict(400, ErrorSchema, component_name="BadRequest")
@openapi.response_typed_dict(401, ErrorSchema, component_name="Unauthorized")
@openapi.response_typed_dict(403, ErrorSchema, component_name="Forbidden")
@openapi.response_typed_dict(404, ErrorSchema, component_name="NotFound")
@openapi.response_typed_dict(
    "5XX", ErrorSchema, description="Unexpected Error", component_name="ServerError"
)
class APIView(View[APIResult]):
    """Base view for JSON APIs.

    Handler results are serialized to `JsonResponse`s, and exceptions are
    rendered as `ErrorSchema` JSON bodies.
    """

    def convert_result_to_response(self, result: APIResult) -> Response:
        """Turn a handler's return value into a `Response`.

        - A `Response` is returned unchanged.
        - None raises `NotFoundError404`.
        - A `(status_code, data)` tuple uses `status_code` for `data`.
        - Mappings (including TypedDict instances) and lists become JSON.

        Raises:
            NotFoundError404: the handler returned None.
            ValueError: a tuple result does not have exactly two items.
            TypeError: the result (or the tuple's data) has an unsupported type.
        """
        if isinstance(result, Response):
            return result

        if result is None:
            raise NotFoundError404

        status_code = 200

        if isinstance(result, tuple):
            if len(result) != 2:
                raise ValueError(
                    "Tuple response must be of length 2 (status_code, data)"
                )
            status_code, result = cast(
                tuple[int, Mapping[str, Any] | list[Any]], result
            )

        # APIResult advertises `Mapping[str, Any]`, so accept any Mapping,
        # not only dict. Non-dict mappings are copied into a dict because JSON
        # encoders serialize dicts, not arbitrary Mapping types.
        if isinstance(result, Mapping):
            data = result if isinstance(result, dict) else dict(result)
            return JsonResponse(data, status_code=status_code)

        if isinstance(result, list):
            return JsonResponse(result, status_code=status_code, safe=False)

        raise TypeError(f"Unexpected APIView return type: {type(result).__name__}")

    def handle_exception(self, exc: Exception) -> Response:
        """Render an exception raised by the view as an `ErrorSchema` response.

        - ValidationError -> 400 `validation_error`. The `errors` list is
          included when the error is keyed by field.
        - FormFieldMissingError -> 400 `missing_field`.
        - HTTPException -> its own status code, with the id looked up in
          `_STATUS_ERROR_IDS` (falling back to `http_error`).
        - Anything else -> 500 `server_error`. The exception is logged with
          its traceback because the generic body hides it from the client.
        """
        if isinstance(exc, ValidationError):
            errors = _validation_field_errors(exc)
            if errors is not None:
                message = "Validation error"
            else:
                detail = "; ".join(exc.messages) if exc.messages else str(exc)
                message = f"Validation error: {detail}"
            return _error_response(
                error_id="validation_error",
                message=message,
                status_code=400,
                errors=errors,
            )
        if isinstance(exc, FormFieldMissingError):
            return _error_response(
                error_id="missing_field",
                message=f"Missing field: {exc.field_name}",
                status_code=400,
            )
        if isinstance(exc, HTTPException):
            error_id = _STATUS_ERROR_IDS.get(exc.status_code, "http_error")
            return _error_response(
                error_id=error_id,
                message=str(exc)
                or http_status_phrases.get(exc.status_code, "HTTP error"),
                status_code=exc.status_code,
            )
        # Converting the exception into a response here means nothing further
        # up sees it, so record the traceback before it is discarded.
        logging.getLogger(__name__).error(
            "Unhandled exception in %s", type(self).__name__, exc_info=exc
        )
        return _error_response(
            error_id="server_error",
            message="Internal server error",
            status_code=500,
        )
239
240
class JsonNotFoundView(APIView):
    """Fallback view that answers every request with a JSON 404.

    Register it as the last regex catch-all route of an API router. Requests
    under the API prefix that match no other route then get an `ErrorSchema`
    JSON body instead of the framework's HTML 404 page.
    """

    def before_request(self) -> None:
        # Fail before any handler is dispatched; APIView's exception handling
        # turns this into the JSON 404 response.
        raise NotFoundError404()