1import datetime
  2import logging
  3from functools import cached_property
  4from typing import Any
  5
  6from plain.exceptions import ValidationError
  7from plain.forms.exceptions import FormFieldMissingError
  8from plain.http import ForbiddenError403, JsonResponse, NotFoundError404, ResponseBase
  9from plain.utils import timezone
 10from plain.utils.cache import patch_cache_control
 11from plain.views.base import View
 12from plain.views.exceptions import ResponseException
 13
 14from . import openapi
 15from .schemas import ErrorSchema
 16
 17# Allow plain.api to be used without plain.models
 18try:
 19    from .models import APIKey
 20except ImportError:
 21    APIKey = None  # type: ignore[misc, assignment]
 22
 23__all__ = [
 24    "APIKeyView",
 25    "APIView",
 26]
 27
 28logger = logging.getLogger("plain.api")
 29
 30
 31# @openapi.response_typed_dict(400, ErrorSchema)
 32# @openapi.response_typed_dict(401, ErrorSchema)
 33class APIKeyView(View):
 34    api_key_required = True
 35
 36    @cached_property
 37    def api_key(self) -> Any:
 38        return self.get_api_key()
 39
 40    def get_response(self) -> ResponseBase:
 41        if self.api_key:
 42            self.use_api_key()
 43        elif self.api_key_required:
 44            return JsonResponse(
 45                ErrorSchema(
 46                    id="api_key_required",
 47                    message="API key required",
 48                    url="",
 49                ),
 50                status_code=401,
 51            )
 52
 53        response = super().get_response()
 54        # Make sure it at least has private as a default
 55        patch_cache_control(response, private=True)
 56        return response
 57
 58    def use_api_key(self) -> None:
 59        """
 60        Use the API key for this request.
 61
 62        Override this to perform other actions with a valid API key.
 63        """
 64        self.api_key.last_used_at = timezone.now()
 65        self.api_key.save(update_fields=["last_used_at"])
 66
 67    def get_api_key(self) -> Any:
 68        """
 69        Get the API key from the request.
 70
 71        Override this if you want to use a different input method.
 72        """
 73        if "Authorization" in self.request.headers:
 74            header_value = self.request.headers["Authorization"]
 75            try:
 76                header_token = header_value.split("Bearer ")[1]
 77            except IndexError:
 78                raise ResponseException(
 79                    JsonResponse(
 80                        ErrorSchema(
 81                            id="invalid_authorization_header",
 82                            message="Invalid Authorization header",
 83                            url="",
 84                        ),
 85                        status_code=400,
 86                    )
 87                )
 88
 89            try:
 90                api_key = APIKey.query.get(token=header_token)
 91            except APIKey.DoesNotExist:
 92                raise ResponseException(
 93                    JsonResponse(
 94                        ErrorSchema(
 95                            id="invalid_api_token",
 96                            message="Invalid API token",
 97                            url="",
 98                        ),
 99                        status_code=400,
100                    )
101                )
102
103            if api_key.expires_at and api_key.expires_at < datetime.datetime.now():
104                raise ResponseException(
105                    JsonResponse(
106                        ErrorSchema(
107                            id="api_token_expired",
108                            message="API token has expired",
109                            url="",
110                        ),
111                        status_code=400,
112                    )
113                )
114
115            return api_key
116
117
@openapi.response_typed_dict(400, ErrorSchema, component_name="BadRequest")
@openapi.response_typed_dict(401, ErrorSchema, component_name="Unauthorized")
@openapi.response_typed_dict(403, ErrorSchema, component_name="Forbidden")
@openapi.response_typed_dict(404, ErrorSchema, component_name="NotFound")
@openapi.response_typed_dict(
    "5XX", ErrorSchema, description="Unexpected Error", component_name="ServerError"
)
class APIView(View):
    """
    Base view that turns exceptions into JSON ``ErrorSchema`` responses.

    Exceptions raised while producing the response are mapped as follows:

    - ``ResponseException``: its wrapped ``response`` is returned unchanged.
    - ``ValidationError``: 400 ``validation_error``.
    - ``FormFieldMissingError``: 400 ``missing_field``.
    - ``ForbiddenError403``: 403 ``permission_denied``.
    - ``NotFoundError404``: 404 ``not_found``.
    - any other ``Exception``: logged with ``logger.exception`` and returned
      as 500 ``server_error``.
    """

    @staticmethod
    def _error_response(error_id: str, message: str, status_code: int) -> JsonResponse:
        """Build a JSON ``ErrorSchema`` response with an empty ``url``."""
        return JsonResponse(
            ErrorSchema(id=error_id, message=message, url=""),
            status_code=status_code,
        )

    @staticmethod
    def _validation_message(error: ValidationError) -> str:
        """
        Return a human-readable message for a ``ValidationError``.

        NOTE(review): assumes plain's ValidationError follows Django's design,
        where only an error built from a single message has ``.message``, and
        list/dict-based errors expose ``.messages`` instead. Reading
        ``error.message`` unconditionally would raise AttributeError inside
        the except handler, escaping this view's JSON error handling.
        """
        if hasattr(error, "message"):
            return error.message
        messages = getattr(error, "messages", None)
        if messages:
            return "; ".join(str(m) for m in messages)
        return str(error)

    def get_response(self) -> ResponseBase:
        try:
            return super().get_response()
        except ResponseException as e:
            # Catch any response exceptions in APIKeyView or elsewhere before View.get_response
            return e.response
        except ValidationError as e:
            # TODO: consider exposing per-field errors (previously sketched as
            # an "errors" key); ErrorSchema's fields are not visible here.
            return self._error_response(
                "validation_error",
                f"Validation error: {self._validation_message(e)}",
                400,
            )
        except FormFieldMissingError as e:
            return self._error_response(
                "missing_field", f"Missing field: {e.field_name}", 400
            )
        except ForbiddenError403:
            return self._error_response("permission_denied", "Permission denied", 403)
        except NotFoundError404:
            return self._error_response("not_found", "Not found", 404)
        except Exception:
            # Top-level boundary: log the traceback, return a generic JSON 500.
            logger.exception("Internal server error", extra={"request": self.request})
            return self._error_response("server_error", "Internal server error", 500)