Plain is headed towards 1.0! Subscribe for development updates →

  1import datetime
  2import logging
  3from functools import cached_property
  4
  5from plain.exceptions import PermissionDenied, ValidationError
  6from plain.forms.exceptions import FormFieldMissingError
  7from plain.http import Http404, JsonResponse, Response
  8from plain.utils import timezone
  9from plain.utils.cache import patch_cache_control
 10from plain.views.base import View
 11from plain.views.csrf import CsrfExemptViewMixin
 12from plain.views.exceptions import ResponseException
 13
 14from . import openapi
 15from .schemas import ErrorSchema
 16
 17# Allow plain.api to be used without plain.models
 18try:
 19    from .models import APIKey
 20except ImportError:
 21    APIKey = None
 22
 23logger = logging.getLogger("plain.api")
 24
 25
 26# @openapi.response_typed_dict(400, ErrorSchema)
 27# @openapi.response_typed_dict(401, ErrorSchema)
 28class APIKeyView(View):
 29    api_key_required = True
 30
 31    @cached_property
 32    def api_key(self):
 33        return self.get_api_key()
 34
 35    def get_response(self) -> Response:
 36        if self.api_key:
 37            self.use_api_key()
 38        elif self.api_key_required:
 39            return JsonResponse(
 40                ErrorSchema(
 41                    id="api_key_required",
 42                    message="API key required",
 43                ),
 44                status_code=401,
 45            )
 46
 47        response = super().get_response()
 48        # Make sure it at least has private as a default
 49        patch_cache_control(response, private=True)
 50        return response
 51
 52    def use_api_key(self):
 53        """
 54        Use the API key for this request.
 55
 56        Override this to perform other actions with a valid API key.
 57        """
 58        self.api_key.last_used_at = timezone.now()
 59        self.api_key.save(update_fields=["last_used_at"])
 60
 61    def get_api_key(self):
 62        """
 63        Get the API key from the request.
 64
 65        Override this if you want to use a different input method.
 66        """
 67        if "Authorization" in self.request.headers:
 68            header_value = self.request.headers["Authorization"]
 69            try:
 70                header_token = header_value.split("Bearer ")[1]
 71            except IndexError:
 72                raise ResponseException(
 73                    JsonResponse(
 74                        ErrorSchema(
 75                            id="invalid_authorization_header",
 76                            message="Invalid Authorization header",
 77                        ),
 78                        status_code=400,
 79                    )
 80                )
 81
 82            try:
 83                api_key = APIKey.objects.get(token=header_token)
 84            except APIKey.DoesNotExist:
 85                raise ResponseException(
 86                    JsonResponse(
 87                        ErrorSchema(
 88                            id="invalid_api_token",
 89                            message="Invalid API token",
 90                        ),
 91                        status_code=400,
 92                    )
 93                )
 94
 95            if api_key.expires_at and api_key.expires_at < datetime.datetime.now():
 96                raise ResponseException(
 97                    JsonResponse(
 98                        ErrorSchema(
 99                            id="api_token_expired",
100                            message="API token has expired",
101                        ),
102                        status_code=400,
103                    )
104                )
105
106            return api_key
107
108
@openapi.response_typed_dict(400, ErrorSchema, component_name="BadRequest")
@openapi.response_typed_dict(401, ErrorSchema, component_name="Unauthorized")
@openapi.response_typed_dict(403, ErrorSchema, component_name="Forbidden")
@openapi.response_typed_dict(404, ErrorSchema, component_name="NotFound")
@openapi.response_typed_dict(
    "5XX", ErrorSchema, description="Unexpected Error", component_name="ServerError"
)
class APIView(CsrfExemptViewMixin, View):
    """
    Base view for JSON APIs.

    Mixes in CsrfExemptViewMixin (presumably so API clients are not subject
    to CSRF checks). It converts exceptions raised while producing the
    response into JSON ``ErrorSchema`` bodies whose status codes match the
    OpenAPI responses declared by the decorators above.
    """

    def get_response(self):
        """
        Return the view's response, translating exceptions into JSON errors.

        - ResponseException: its attached response is returned as-is.
        - ValidationError, FormFieldMissingError: 400.
        - PermissionDenied: 403.
        - Http404: 404.
        - Any other Exception: logged with the request, then 500.
        """
        try:
            return super().get_response()
        except ResponseException as e:
            # Catch any response exceptions in APIKeyView or elsewhere before View.get_response
            return e.response
        except ValidationError as e:
            return self._error_response(
                "validation_error",
                f"Validation error: {self._validation_error_text(e)}",
                400,
            )
        except FormFieldMissingError as e:
            return self._error_response(
                "missing_field",
                f"Missing field: {e.field_name}",
                400,
            )
        except PermissionDenied:
            return self._error_response("permission_denied", "Permission denied", 403)
        except Http404:
            return self._error_response("not_found", "Not found", 404)
        except Exception:
            logger.exception("Internal server error", extra={"request": self.request})
            return self._error_response("server_error", "Internal server error", 500)

    @staticmethod
    def _error_response(error_id, message, status_code):
        """Build a JSON ``ErrorSchema`` response with the given id, message and status."""
        return JsonResponse(
            ErrorSchema(
                id=error_id,
                message=message,
            ),
            status_code=status_code,
        )

    @staticmethod
    def _validation_error_text(error):
        """
        Return a human-readable description of a ValidationError.

        A single-message ValidationError exposes ``.message``. One built from
        a list or dict of errors may not (NOTE(review): this mirrors Django's
        ValidationError, which plain's appears to follow; confirm). Without
        this fallback, reading ``.message`` would raise AttributeError inside
        the except handler and escape as an unhandled error. Falls back to
        joining ``.messages``, then to ``str(error)``.
        """
        message = getattr(error, "message", None)
        if message is not None:
            return message
        messages = getattr(error, "messages", None)
        if messages:
            return "; ".join(str(m) for m in messages)
        return str(error)
163                status_code=500,
164            )