1import datetime
2import logging
3from functools import cached_property
4
5from plain.exceptions import PermissionDenied, ValidationError
6from plain.forms.exceptions import FormFieldMissingError
7from plain.http import Http404, JsonResponse, Response
8from plain.utils import timezone
9from plain.utils.cache import patch_cache_control
10from plain.views.base import View
11from plain.views.exceptions import ResponseException
12
13from . import openapi
14from .schemas import ErrorSchema
15
# Allow plain.api to be used without plain.models: if importing APIKey from
# .models fails with ImportError, APIKey is left as None.
# NOTE(review): APIKeyView.get_api_key references APIKey.query /
# APIKey.DoesNotExist directly, so with APIKey = None it presumably only works
# if get_api_key is overridden — confirm.
try:
    from .models import APIKey
except ImportError:
    APIKey = None

# Module logger; APIView uses it to log unhandled exceptions.
logger = logging.getLogger("plain.api")
23
24
# @openapi.response_typed_dict(400, ErrorSchema)
# @openapi.response_typed_dict(401, ErrorSchema)
class APIKeyView(View):
    """
    A view that authenticates requests with an ``APIKey`` bearer token.

    The key is looked up once per view instance (see ``api_key``) from the
    ``Authorization`` header. When no key is found and ``api_key_required`` is
    true, the view returns a 401 JSON error. Problems with a supplied key
    (malformed header, unknown token, expired key) are raised as
    ``ResponseException`` carrying a 400 JSON error. Something upstream must
    catch them, e.g. ``APIView.get_response`` when the two views are combined.
    """

    # When False, requests without an API key are let through; ``api_key`` is
    # then falsy and ``use_api_key`` is not called.
    api_key_required = True

    @cached_property
    def api_key(self):
        """The API key for this request, computed once via ``get_api_key``."""
        return self.get_api_key()

    def get_response(self) -> Response:
        """
        Authenticate the request, then delegate to the parent view.

        Returns a 401 JSON error when no API key was found and one is
        required. Otherwise records the key's use (if there is a key), builds
        the parent response and adds ``private`` to its Cache-Control.

        Raises:
            ResponseException: propagated from ``get_api_key`` when a supplied
                key is invalid or expired.
        """
        if self.api_key:
            self.use_api_key()
        elif self.api_key_required:
            return JsonResponse(
                ErrorSchema(
                    id="api_key_required",
                    message="API key required",
                ),
                status_code=401,
            )

        response = super().get_response()
        # Make sure it at least has private as a default
        patch_cache_control(response, private=True)
        return response

    def use_api_key(self):
        """
        Use the API key for this request.

        Override this to perform other actions with a valid API key. The
        default implementation stamps ``last_used_at`` with the current time
        and saves only that field.
        """
        self.api_key.last_used_at = timezone.now()
        self.api_key.save(update_fields=["last_used_at"])

    def get_api_key(self):
        """
        Get the API key from the request.

        Override this if you want to use a different input method.

        Reads an ``Authorization: Bearer <token>`` header. The scheme is
        matched case-insensitively. Returns None when the header is absent.

        Raises:
            ResponseException: with a 400 JSON error when the header is not a
                non-empty bearer token, the token matches no ``APIKey``, or
                the key's ``expires_at`` is in the past.
        """
        if "Authorization" not in self.request.headers:
            return None

        header_value = self.request.headers["Authorization"]
        # Check the scheme itself rather than splitting on "Bearer " anywhere
        # in the value. Splitting accepted headers like "xBearer abc" and sent
        # empty tokens to the database. Auth schemes are case-insensitive
        # (RFC 7235 §2.1).
        scheme, _, header_token = header_value.partition(" ")
        header_token = header_token.strip()
        if scheme.lower() != "bearer" or not header_token:
            raise self._bad_request(
                "invalid_authorization_header", "Invalid Authorization header"
            )

        try:
            api_key = APIKey.query.get(token=header_token)
        except APIKey.DoesNotExist:
            raise self._bad_request("invalid_api_token", "Invalid API token") from None

        # Use timezone.now() (as use_api_key does) instead of the naive
        # datetime.datetime.now(). Comparing a timezone-aware expires_at with a
        # naive datetime raises TypeError.
        # NOTE(review): RFC 6750 §3.1 recommends 401 for invalid/expired
        # tokens; 400 is kept here for compatibility with existing clients.
        if api_key.expires_at and api_key.expires_at < timezone.now():
            raise self._bad_request("api_token_expired", "API token has expired")

        return api_key

    @staticmethod
    def _bad_request(error_id, message):
        """Build a ResponseException wrapping a 400 JSON ``ErrorSchema``."""
        return ResponseException(
            JsonResponse(
                ErrorSchema(
                    id=error_id,
                    message=message,
                ),
                status_code=400,
            )
        )
106
107
@openapi.response_typed_dict(400, ErrorSchema, component_name="BadRequest")
@openapi.response_typed_dict(401, ErrorSchema, component_name="Unauthorized")
@openapi.response_typed_dict(403, ErrorSchema, component_name="Forbidden")
@openapi.response_typed_dict(404, ErrorSchema, component_name="NotFound")
@openapi.response_typed_dict(
    "5XX", ErrorSchema, description="Unexpected Error", component_name="ServerError"
)
class APIView(View):
    """
    A view that turns exceptions raised while handling a request into JSON
    ``ErrorSchema`` responses.

    Exceptions are mapped as follows:

    - ``ResponseException`` returns the response it carries.
    - ``ValidationError`` and ``FormFieldMissingError`` become 400s.
    - ``PermissionDenied`` becomes a 403.
    - ``Http404`` becomes a 404.
    - Any other exception is logged and becomes a 500.

    The class decorators presumably document these error responses in the
    generated OpenAPI schema.
    """

    def get_response(self):
        """
        Return the parent view's response, converting the exceptions listed
        on the class into JSON error responses.
        """
        try:
            return super().get_response()
        except ResponseException as e:
            # Catch any response exceptions in APIKeyView or elsewhere before View.get_response
            return e.response
        except ValidationError as e:
            # TODO: consider including per-field errors in the payload.
            return self._error_response(
                "validation_error",
                f"Validation error: {self._validation_error_message(e)}",
                400,
            )
        except FormFieldMissingError as e:
            return self._error_response(
                "missing_field",
                f"Missing field: {e.field_name}",
                400,
            )
        except PermissionDenied:
            return self._error_response("permission_denied", "Permission denied", 403)
        except Http404:
            return self._error_response("not_found", "Not found", 404)
        except Exception:
            # Last-resort boundary: log the traceback with the request, but
            # return only a generic message to the client.
            logger.exception("Internal server error", extra={"request": self.request})
            return self._error_response("server_error", "Internal server error", 500)

    @staticmethod
    def _error_response(error_id, message, status_code):
        """Build a JSON ``ErrorSchema`` response with the given status code."""
        return JsonResponse(
            ErrorSchema(
                id=error_id,
                message=message,
            ),
            status_code=status_code,
        )

    @staticmethod
    def _validation_error_message(error):
        """
        Summarize a ``ValidationError`` as one string.

        Checks three sources in order:

        1. ``error.messages``, with all entries joined by "; ".
        2. ``error.message``.
        3. ``str(error)``.

        NOTE(review): assumes plain's ValidationError mirrors Django's, where
        ``message`` is only set for single-message errors. Reading it
        unconditionally raises AttributeError for list/dict errors. Because
        that happens inside an except clause, it escapes every sibling handler
        in ``get_response``, including the 500 fallback.
        """
        messages = getattr(error, "messages", None)
        if messages:
            return "; ".join(str(m) for m in messages)
        message = getattr(error, "message", None)
        return str(message) if message is not None else str(error)