Plain is headed towards 1.0! Subscribe for development updates →

  1import datetime
  2import logging
  3from functools import cached_property
  4from typing import Any
  5
  6from plain.exceptions import PermissionDenied, ValidationError
  7from plain.forms.exceptions import FormFieldMissingError
  8from plain.http import Http404, JsonResponse, ResponseBase
  9from plain.utils import timezone
 10from plain.utils.cache import patch_cache_control
 11from plain.views.base import View
 12from plain.views.exceptions import ResponseException
 13
 14from . import openapi
 15from .schemas import ErrorSchema
 16
 17# Allow plain.api to be used without plain.models
 18try:
 19    from .models import APIKey
 20except ImportError:
 21    APIKey = None  # type: ignore[assignment]
 22
 23logger = logging.getLogger("plain.api")
 24
 25
 26# @openapi.response_typed_dict(400, ErrorSchema)
 27# @openapi.response_typed_dict(401, ErrorSchema)
 28class APIKeyView(View):
 29    api_key_required = True
 30
 31    @cached_property
 32    def api_key(self) -> Any:
 33        return self.get_api_key()
 34
 35    def get_response(self) -> ResponseBase:
 36        if self.api_key:
 37            self.use_api_key()
 38        elif self.api_key_required:
 39            return JsonResponse(
 40                ErrorSchema(
 41                    id="api_key_required",
 42                    message="API key required",
 43                    url="",
 44                ),
 45                status_code=401,
 46            )
 47
 48        response = super().get_response()
 49        # Make sure it at least has private as a default
 50        patch_cache_control(response, private=True)
 51        return response
 52
 53    def use_api_key(self) -> None:
 54        """
 55        Use the API key for this request.
 56
 57        Override this to perform other actions with a valid API key.
 58        """
 59        self.api_key.last_used_at = timezone.now()
 60        self.api_key.save(update_fields=["last_used_at"])
 61
 62    def get_api_key(self) -> Any:
 63        """
 64        Get the API key from the request.
 65
 66        Override this if you want to use a different input method.
 67        """
 68        if "Authorization" in self.request.headers:
 69            header_value = self.request.headers["Authorization"]
 70            try:
 71                header_token = header_value.split("Bearer ")[1]
 72            except IndexError:
 73                raise ResponseException(
 74                    JsonResponse(
 75                        ErrorSchema(
 76                            id="invalid_authorization_header",
 77                            message="Invalid Authorization header",
 78                            url="",
 79                        ),
 80                        status_code=400,
 81                    )
 82                )
 83
 84            try:
 85                api_key = APIKey.query.get(token=header_token)
 86            except APIKey.DoesNotExist:
 87                raise ResponseException(
 88                    JsonResponse(
 89                        ErrorSchema(
 90                            id="invalid_api_token",
 91                            message="Invalid API token",
 92                            url="",
 93                        ),
 94                        status_code=400,
 95                    )
 96                )
 97
 98            if api_key.expires_at and api_key.expires_at < datetime.datetime.now():
 99                raise ResponseException(
100                    JsonResponse(
101                        ErrorSchema(
102                            id="api_token_expired",
103                            message="API token has expired",
104                            url="",
105                        ),
106                        status_code=400,
107                    )
108                )
109
110            return api_key
111
112
@openapi.response_typed_dict(400, ErrorSchema, component_name="BadRequest")
@openapi.response_typed_dict(401, ErrorSchema, component_name="Unauthorized")
@openapi.response_typed_dict(403, ErrorSchema, component_name="Forbidden")
@openapi.response_typed_dict(404, ErrorSchema, component_name="NotFound")
@openapi.response_typed_dict(
    "5XX", ErrorSchema, description="Unexpected Error", component_name="ServerError"
)
class APIView(View):
    """
    View that converts errors raised while responding into JSON error bodies.

    ``ResponseException`` responses are returned as-is. ``ValidationError`` and
    ``FormFieldMissingError`` become 400, ``PermissionDenied`` becomes 403, and
    ``Http404`` becomes 404. Any other exception is logged and returned as a
    500. Every generated body is an ``ErrorSchema``.
    """

    def get_response(self) -> ResponseBase:
        try:
            return super().get_response()
        except ResponseException as e:
            # Catch any response exceptions in APIKeyView or elsewhere before View.get_response
            return e.response
        except ValidationError as e:
            # Only single-message ValidationErrors have ``.message``; errors built
            # from a list or dict do not, and the AttributeError would escape
            # this handler. ``.messages`` is available for every form.
            # NOTE: field-level detail is flattened into the message because the
            # ErrorSchema built here carries only id/message/url.
            return self._error_response(
                "validation_error",
                f"Validation error: {'; '.join(e.messages)}",
                status_code=400,
            )
        except FormFieldMissingError as e:
            return self._error_response(
                "missing_field", f"Missing field: {e.field_name}", status_code=400
            )
        except PermissionDenied:
            return self._error_response(
                "permission_denied", "Permission denied", status_code=403
            )
        except Http404:
            return self._error_response("not_found", "Not found", status_code=404)
        except Exception:
            # Last-resort boundary: log with the request attached, and never
            # leak exception details to the client.
            logger.exception("Internal server error", extra={"request": self.request})
            return self._error_response(
                "server_error", "Internal server error", status_code=500
            )

    @staticmethod
    def _error_response(
        error_id: str, message: str, *, status_code: int
    ) -> JsonResponse:
        """Build a JSON ``ErrorSchema`` response with an empty ``url``."""
        return JsonResponse(
            ErrorSchema(id=error_id, message=message, url=""),
            status_code=status_code,
        )