Plain is headed towards 1.0! Subscribe for development updates →

  1import datetime
  2import logging
  3from functools import cached_property
  4
  5from plain.exceptions import PermissionDenied, ValidationError
  6from plain.forms.exceptions import FormFieldMissingError
  7from plain.http import Http404, JsonResponse, Response
  8from plain.utils import timezone
  9from plain.utils.cache import patch_cache_control
 10from plain.views.base import View
 11from plain.views.exceptions import ResponseException
 12
 13from . import openapi
 14from .schemas import ErrorSchema
 15
 16# Allow plain.api to be used without plain.models
 17try:
 18    from .models import APIKey
 19except ImportError:
 20    APIKey = None
 21
 22logger = logging.getLogger("plain.api")
 23
 24
 25# @openapi.response_typed_dict(400, ErrorSchema)
 26# @openapi.response_typed_dict(401, ErrorSchema)
 27class APIKeyView(View):
 28    api_key_required = True
 29
 30    @cached_property
 31    def api_key(self):
 32        return self.get_api_key()
 33
 34    def get_response(self) -> Response:
 35        if self.api_key:
 36            self.use_api_key()
 37        elif self.api_key_required:
 38            return JsonResponse(
 39                ErrorSchema(
 40                    id="api_key_required",
 41                    message="API key required",
 42                ),
 43                status_code=401,
 44            )
 45
 46        response = super().get_response()
 47        # Make sure it at least has private as a default
 48        patch_cache_control(response, private=True)
 49        return response
 50
 51    def use_api_key(self):
 52        """
 53        Use the API key for this request.
 54
 55        Override this to perform other actions with a valid API key.
 56        """
 57        self.api_key.last_used_at = timezone.now()
 58        self.api_key.save(update_fields=["last_used_at"])
 59
 60    def get_api_key(self):
 61        """
 62        Get the API key from the request.
 63
 64        Override this if you want to use a different input method.
 65        """
 66        if "Authorization" in self.request.headers:
 67            header_value = self.request.headers["Authorization"]
 68            try:
 69                header_token = header_value.split("Bearer ")[1]
 70            except IndexError:
 71                raise ResponseException(
 72                    JsonResponse(
 73                        ErrorSchema(
 74                            id="invalid_authorization_header",
 75                            message="Invalid Authorization header",
 76                        ),
 77                        status_code=400,
 78                    )
 79                )
 80
 81            try:
 82                api_key = APIKey.query.get(token=header_token)
 83            except APIKey.DoesNotExist:
 84                raise ResponseException(
 85                    JsonResponse(
 86                        ErrorSchema(
 87                            id="invalid_api_token",
 88                            message="Invalid API token",
 89                        ),
 90                        status_code=400,
 91                    )
 92                )
 93
 94            if api_key.expires_at and api_key.expires_at < datetime.datetime.now():
 95                raise ResponseException(
 96                    JsonResponse(
 97                        ErrorSchema(
 98                            id="api_token_expired",
 99                            message="API token has expired",
100                        ),
101                        status_code=400,
102                    )
103                )
104
105            return api_key
106
107
@openapi.response_typed_dict(400, ErrorSchema, component_name="BadRequest")
@openapi.response_typed_dict(401, ErrorSchema, component_name="Unauthorized")
@openapi.response_typed_dict(403, ErrorSchema, component_name="Forbidden")
@openapi.response_typed_dict(404, ErrorSchema, component_name="NotFound")
@openapi.response_typed_dict(
    "5XX", ErrorSchema, description="Unexpected Error", component_name="ServerError"
)
class APIView(View):
    """
    View that turns exceptions raised while handling a request into JSON
    ``ErrorSchema`` responses.

    - ``ResponseException``: its wrapped response is returned as-is.
    - ``ValidationError`` / ``FormFieldMissingError``: 400.
    - ``PermissionDenied``: 403.
    - ``Http404``: 404.
    - Anything else: logged with its traceback, then a generic 500.
    """

    def get_response(self):
        try:
            return super().get_response()
        except ResponseException as e:
            # Catch any response exceptions in APIKeyView or elsewhere before View.get_response
            return e.response
        except ValidationError as e:
            return self._error_response(
                "validation_error",
                f"Validation error: {self._validation_error_message(e)}",
                status_code=400,
            )
        except FormFieldMissingError as e:
            return self._error_response(
                "missing_field", f"Missing field: {e.field_name}", status_code=400
            )
        except PermissionDenied:
            return self._error_response(
                "permission_denied", "Permission denied", status_code=403
            )
        except Http404:
            return self._error_response("not_found", "Not found", status_code=404)
        except Exception:
            # Last-resort boundary: log the full traceback server-side but
            # return only a generic message to the client.
            logger.exception("Internal server error", extra={"request": self.request})
            return self._error_response(
                "server_error", "Internal server error", status_code=500
            )

    @staticmethod
    def _validation_error_message(error):
        """Return a human-readable summary of a ValidationError's messages."""
        # In a Django-style ValidationError (which plain.exceptions appears to
        # mirror — verify), only a single-message error has ``.message``; one
        # built from a list or dict does not, so reading ``e.message`` raised
        # AttributeError inside this handler. ``.messages`` covers every form
        # and interpolates ``params``; fall back defensively if it is absent.
        messages = getattr(error, "messages", None)
        if messages:
            return "; ".join(str(message) for message in messages)
        return str(getattr(error, "message", error))

    @staticmethod
    def _error_response(error_id, message, *, status_code):
        """Build a JsonResponse whose body is an ``ErrorSchema``."""
        return JsonResponse(
            ErrorSchema(id=error_id, message=message),
            status_code=status_code,
        )