1import datetime
2import logging
3from functools import cached_property
4from typing import Any
5
6from plain.exceptions import ValidationError
7from plain.forms.exceptions import FormFieldMissingError
8from plain.http import ForbiddenError403, JsonResponse, NotFoundError404, ResponseBase
9from plain.utils import timezone
10from plain.utils.cache import patch_cache_control
11from plain.views.base import View
12from plain.views.exceptions import ResponseException
13
14from . import openapi
15from .schemas import ErrorSchema
16
17# Allow plain.api to be used without plain.models
18try:
19 from .models import APIKey
20except ImportError:
21 APIKey = None # type: ignore[misc, assignment]
22
# Public API of this module.
__all__ = ["APIKeyView", "APIView"]

# Module logger; APIView.get_response reports unhandled exceptions through it.
logger = logging.getLogger("plain.api")
29
30
31# @openapi.response_typed_dict(400, ErrorSchema)
32# @openapi.response_typed_dict(401, ErrorSchema)
class APIKeyView(View):
    """
    View that authenticates requests with an ``APIKey`` bearer token.

    The token is read from an ``Authorization: Bearer <token>`` header.

    * Valid key: ``use_api_key()`` is called before the view runs.
    * No header and ``api_key_required`` is true: a 401 JSON error is returned.
    * Malformed header, unknown token or expired key: a ``ResponseException``
      wrapping a 400 JSON error is raised from ``get_api_key()``.
    """

    # Set to False to allow requests without an Authorization header
    # (``self.api_key`` is then None).
    api_key_required = True

    @cached_property
    def api_key(self) -> Any:
        """The API key for this request, or None; looked up once per instance."""
        return self.get_api_key()

    def get_response(self) -> ResponseBase:
        """Enforce the API key requirement, then delegate to the parent view."""
        if self.api_key:
            self.use_api_key()
        elif self.api_key_required:
            return self._api_key_error("api_key_required", "API key required", 401)

        response = super().get_response()
        # Make sure it at least has private as a default
        patch_cache_control(response, private=True)
        return response

    def use_api_key(self) -> None:
        """
        Use the API key for this request.

        Override this to perform other actions with a valid API key.
        By default this records ``last_used_at`` on the key.
        """
        self.api_key.last_used_at = timezone.now()
        self.api_key.save(update_fields=["last_used_at"])

    def get_api_key(self) -> Any:
        """
        Get the API key from the request.

        Override this if you want to use a different input method.

        Returns:
            The matching ``APIKey``, or None when no ``Authorization`` header
            was sent.

        Raises:
            ResponseException: wrapping a 400 JSON error when the header is
                not a non-empty ``Bearer`` credential, the token matches no
                key, or the key has expired.
            RuntimeError: if the ``APIKey`` model is unavailable (the guarded
                ``.models`` import failed).
        """
        if "Authorization" not in self.request.headers:
            return None

        scheme, _, header_token = self.request.headers["Authorization"].partition(" ")
        header_token = header_token.strip()
        # The auth scheme is case-insensitive (RFC 7235); the credential must
        # follow the scheme itself, not merely appear somewhere in the header.
        if scheme.lower() != "bearer" or not header_token:
            raise ResponseException(
                self._api_key_error(
                    "invalid_authorization_header",
                    "Invalid Authorization header",
                    400,
                )
            )

        if APIKey is None:
            raise RuntimeError("APIKeyView requires plain.models to look up API keys")

        try:
            api_key = APIKey.query.get(token=header_token)
        except APIKey.DoesNotExist:
            raise ResponseException(
                self._api_key_error("invalid_api_token", "Invalid API token", 400)
            ) from None

        # Use timezone.now() (as use_api_key does for last_used_at): a naive
        # datetime.datetime.now() raises TypeError against an aware expires_at.
        if api_key.expires_at and api_key.expires_at < timezone.now():
            raise ResponseException(
                self._api_key_error("api_token_expired", "API token has expired", 400)
            )

        return api_key

    @staticmethod
    def _api_key_error(error_id: str, message: str, status_code: int) -> JsonResponse:
        """Build an ErrorSchema JSON response (with an empty ``url``)."""
        return JsonResponse(
            ErrorSchema(id=error_id, message=message, url=""),
            status_code=status_code,
        )
116
117
@openapi.response_typed_dict(400, ErrorSchema, component_name="BadRequest")
@openapi.response_typed_dict(401, ErrorSchema, component_name="Unauthorized")
@openapi.response_typed_dict(403, ErrorSchema, component_name="Forbidden")
@openapi.response_typed_dict(404, ErrorSchema, component_name="NotFound")
@openapi.response_typed_dict(
    "5XX", ErrorSchema, description="Unexpected Error", component_name="ServerError"
)
class APIView(View):
    """
    Base view for JSON APIs that turns exceptions into ErrorSchema responses.

    Exception handling in ``get_response``:

    * ``ResponseException``: its wrapped response is returned as-is.
    * ``ValidationError`` / ``FormFieldMissingError``: 400.
    * ``ForbiddenError403``: 403.
    * ``NotFoundError404``: 404.
    * anything else: logged via ``logger.exception`` and returned as a 500.
    """

    def get_response(self) -> ResponseBase:
        try:
            return super().get_response()
        except ResponseException as e:
            # Catch any response exceptions in APIKeyView or elsewhere before View.get_response
            return e.response
        except ValidationError as e:
            # NOTE: field-level error details are not included in the payload.
            return self._api_error(
                "validation_error",
                f"Validation error: {self._validation_error_message(e)}",
                400,
            )
        except FormFieldMissingError as e:
            return self._api_error(
                "missing_field", f"Missing field: {e.field_name}", 400
            )
        except ForbiddenError403:
            return self._api_error("permission_denied", "Permission denied", 403)
        except NotFoundError404:
            return self._api_error("not_found", "Not found", 404)
        except Exception:
            logger.exception("Internal server error", extra={"request": self.request})
            return self._api_error("server_error", "Internal server error", 500)

    @staticmethod
    def _api_error(error_id: str, message: str, status_code: int) -> JsonResponse:
        """Build an ErrorSchema JSON response (with an empty ``url``)."""
        return JsonResponse(
            ErrorSchema(id=error_id, message=message, url=""),
            status_code=status_code,
        )

    @staticmethod
    def _validation_error_message(error: Any) -> str:
        """
        Return a readable summary of a ValidationError.

        A ``.message`` attribute is used as-is when present (the previous
        behavior). Presumably, like Django's, errors built from a list or dict
        of messages have no ``.message``; reading it unconditionally would
        raise AttributeError out of the except chain. For those, fall back to
        joining ``.messages``, then to ``str(error)``.
        """
        message = getattr(error, "message", None)
        if message is not None:
            return f"{message}"
        messages = getattr(error, "messages", None)
        if messages:
            return "; ".join(str(m) for m in messages)
        return str(error)