1import datetime
2import logging
3from functools import cached_property
4from typing import Any
5
6from plain.exceptions import PermissionDenied, ValidationError
7from plain.forms.exceptions import FormFieldMissingError
8from plain.http import Http404, JsonResponse, ResponseBase
9from plain.utils import timezone
10from plain.utils.cache import patch_cache_control
11from plain.views.base import View
12from plain.views.exceptions import ResponseException
13
14from . import openapi
15from .schemas import ErrorSchema
16
# Allow plain.api to be used without plain.models
try:
    from .models import APIKey
except ImportError:
    # Without plain.models there is no APIKey model. The default
    # APIKeyView.get_api_key() queries APIKey, so it will fail in that setup
    # unless a subclass overrides it with a different lookup.
    APIKey = None  # type: ignore[assignment]

# Logger for this package; APIView uses it to record unexpected exceptions
# before returning a generic 500 JSON error.
logger = logging.getLogger("plain.api")
24
25
26# @openapi.response_typed_dict(400, ErrorSchema)
27# @openapi.response_typed_dict(401, ErrorSchema)
class APIKeyView(View):
    """
    Base view that authenticates requests with an API key.

    By default the key is read from an ``Authorization: Bearer <token>``
    header and looked up in the ``APIKey`` model (see ``get_api_key``).
    When ``api_key_required`` is true and no key is supplied, a 401 JSON
    error is returned and the handler is never dispatched.
    """

    # Set to False to let requests without an API key through; in that case
    # ``self.api_key`` is None inside the handler.
    api_key_required = True

    @cached_property
    def api_key(self) -> Any:
        """The API key for this request (or None), resolved once per view instance."""
        return self.get_api_key()

    def get_response(self) -> ResponseBase:
        """
        Resolve the API key, enforce ``api_key_required``, then dispatch.

        Errors found while resolving the key surface as ``ResponseException``
        from ``get_api_key``; they are not caught here and must be handled
        further up (``APIView.get_response`` does this).
        """
        if self.api_key:
            self.use_api_key()
        elif self.api_key_required:
            return self._api_key_error_response(
                "api_key_required", "API key required", status_code=401
            )

        response = super().get_response()
        # Make sure it at least has private as a default
        patch_cache_control(response, private=True)
        return response

    def use_api_key(self) -> None:
        """
        Use the API key for this request.

        By default this stamps ``last_used_at`` and saves only that field.
        Override this to perform other actions with a valid API key.
        """
        self.api_key.last_used_at = timezone.now()
        self.api_key.save(update_fields=["last_used_at"])

    def get_api_key(self) -> Any:
        """
        Get the API key from the request.

        Override this if you want to use a different input method.

        Returns:
            The matching ``APIKey``, or None when the request has no
            ``Authorization`` header.

        Raises:
            ResponseException: wrapping a 400 JSON error when the header is
                not a usable bearer token, the token matches no key, or the
                key has expired.
            RuntimeError: if plain.models is not installed (``APIKey`` is
                unavailable) and this method has not been overridden.
        """
        if "Authorization" not in self.request.headers:
            return None

        header_token = self._parse_bearer_token(self.request.headers["Authorization"])

        if APIKey is None:
            raise RuntimeError(
                "APIKeyView.get_api_key() requires plain.models; "
                "install it or override get_api_key()."
            )

        try:
            api_key = APIKey.query.get(token=header_token)
        except APIKey.DoesNotExist:
            raise ResponseException(
                self._api_key_error_response(
                    "invalid_api_token", "Invalid API token", status_code=400
                )
            ) from None

        # Compare against an aware "now" from the same clock use_api_key()
        # writes with; a naive datetime.now() raises TypeError when compared
        # with an aware expires_at.
        if api_key.expires_at and api_key.expires_at < timezone.now():
            raise ResponseException(
                self._api_key_error_response(
                    "api_token_expired", "API token has expired", status_code=400
                )
            )

        return api_key

    def _parse_bearer_token(self, header_value: str) -> str:
        """
        Extract the token from an ``Authorization: Bearer <token>`` value.

        The scheme must be ``Bearer`` (matched case-insensitively, as HTTP
        authentication schemes are) and the token must be non-empty.

        Raises:
            ResponseException: wrapping a 400 ``invalid_authorization_header``
                JSON error for any other shape.
        """
        scheme, _, token = header_value.strip().partition(" ")
        token = token.strip()
        if scheme.lower() != "bearer" or not token:
            raise ResponseException(
                self._api_key_error_response(
                    "invalid_authorization_header",
                    "Invalid Authorization header",
                    status_code=400,
                )
            )
        return token

    @staticmethod
    def _api_key_error_response(
        error_id: str, message: str, *, status_code: int
    ) -> JsonResponse:
        """Build the ErrorSchema JSON response used for every API-key failure."""
        return JsonResponse(
            ErrorSchema(id=error_id, message=message, url=""),
            status_code=status_code,
        )
111
112
@openapi.response_typed_dict(400, ErrorSchema, component_name="BadRequest")
@openapi.response_typed_dict(401, ErrorSchema, component_name="Unauthorized")
@openapi.response_typed_dict(403, ErrorSchema, component_name="Forbidden")
@openapi.response_typed_dict(404, ErrorSchema, component_name="NotFound")
@openapi.response_typed_dict(
    "5XX", ErrorSchema, description="Unexpected Error", component_name="ServerError"
)
class APIView(View):
    """
    Base view for JSON APIs that turns common exceptions into JSON errors.

    Every error body is an ``ErrorSchema``; the ``openapi`` decorators above
    describe those error responses (400/401/403/404/5XX) for this view.
    """

    def get_response(self) -> ResponseBase:
        """
        Dispatch the request, mapping exceptions to JSON error responses.

        - ``ResponseException``      -> the response it carries
        - ``ValidationError``        -> 400 ``validation_error``
        - ``FormFieldMissingError``  -> 400 ``missing_field``
        - ``PermissionDenied``       -> 403 ``permission_denied``
        - ``Http404``                -> 404 ``not_found``
        - any other ``Exception``    -> logged, then 500 ``server_error``
        """
        try:
            return super().get_response()
        except ResponseException as e:
            # Catch any response exceptions in APIKeyView or elsewhere before View.get_response
            return e.response
        except ValidationError as e:
            # TODO: expose per-field errors (an "errors" mapping) once
            # ErrorSchema has a place for them.
            return self._error_response(
                "validation_error",
                f"Validation error: {self._validation_error_message(e)}",
                status_code=400,
            )
        except FormFieldMissingError as e:
            return self._error_response(
                "missing_field", f"Missing field: {e.field_name}", status_code=400
            )
        except PermissionDenied:
            return self._error_response(
                "permission_denied", "Permission denied", status_code=403
            )
        except Http404:
            return self._error_response("not_found", "Not found", status_code=404)
        except Exception:
            logger.exception("Internal server error", extra={"request": self.request})
            return self._error_response(
                "server_error", "Internal server error", status_code=500
            )

    @staticmethod
    def _validation_error_message(error: ValidationError) -> str:
        """
        Flatten a ValidationError into a single human-readable string.

        NOTE(review): assumes plain's ValidationError mirrors Django's, where
        ``.message`` exists only for single-message errors (and is the raw,
        un-interpolated template) while ``.messages`` works for single, list
        and dict errors. Reading ``.message`` on a list/dict error would raise
        AttributeError inside the except handler and escape get_response.
        """
        messages = getattr(error, "messages", None)
        if messages:
            return "; ".join(str(message) for message in messages)
        return str(getattr(error, "message", error))

    @staticmethod
    def _error_response(error_id: str, message: str, *, status_code: int) -> JsonResponse:
        """Build an ErrorSchema JSON response with the given id, message and status."""
        return JsonResponse(
            ErrorSchema(id=error_id, message=message, url=""),
            status_code=status_code,
        )