v0.139.0
  1import json
  2from functools import cached_property
  3from typing import Any
  4
  5from plain.http import Request, Response
  6from plain.views.exceptions import ResponseException
  7
  8from .views import APIView
  9
 10__all__ = [
 11    "APIVersionChange",
 12    "VersionedAPIView",
 13]
 14
 15
 16class APIVersionChange:
 17    description: str = ""
 18
 19    def transform_request_forward(self, request: Request, data: dict[str, Any]) -> None:
 20        """
 21        If this version of the API made a change in how a request is processed,
 22        (ex. the name of an input changed) then you can
 23        """
 24        pass
 25
 26    def transform_response_backward(
 27        self, response: Response, data: dict[str, Any]
 28    ) -> None:
 29        """
 30        Transform the response data for this version.
 31
 32        We only transform the response data if we are moving backward to an older version.
 33        This is because the response data is always in the latest version.
 34        """
 35        pass
 36
 37
 38class VersionedAPIView(APIView):
 39    # API versions from newest to oldest
 40    api_versions: dict[str, list[type[APIVersionChange]]] = {}
 41    api_version_header = "API-Version"
 42    default_api_version: str = ""
 43
 44    @cached_property
 45    def api_version(self) -> str:
 46        return self.get_api_version()
 47
 48    def get_api_version(self) -> str:
 49        version = ""
 50
 51        if version_name := self.request.headers.get(self.api_version_header, ""):
 52            version = version_name
 53        elif default_version := self.get_default_api_version():
 54            version = default_version
 55        else:
 56            raise ResponseException(
 57                Response(
 58                    f"Missing API version header '{self.api_version_header}'",
 59                    status_code=400,
 60                )
 61            )
 62
 63        if version in self.api_versions:
 64            return version
 65        else:
 66            raise ResponseException(
 67                Response(
 68                    f"Invalid API version '{version_name}'. Valid versions are: {', '.join(self.api_versions.keys())}",
 69                    status_code=400,
 70                )
 71            )
 72
 73    def get_default_api_version(self) -> str:
 74        # If this view has an api_key, use its version name
 75        if api_key := getattr(self, "api_key", None):
 76            if api_key.api_version:
 77                # If the API key has a version, use that
 78                return api_key.api_version
 79
 80        return self.default_api_version
 81
 82    def before_request(self) -> None:
 83        super().before_request()
 84        if self.request.content_type == "application/json":
 85            self.transform_request(self.request)
 86
 87    def after_response(self, response: Response) -> Response:
 88        response = super().after_response(response)
 89        if response.headers.get("Content-Type") == "application/json":
 90            self.transform_response(response)
 91
 92        # Put the API version on the response
 93        response.headers[self.api_version_header] = self.api_version
 94
 95        return response
 96
 97    def transform_request(self, request: Request) -> None:
 98        request_changes = []
 99
100        # Find the version being requested,
101        # then get every change after that up to the latest
102        changing = False
103        for version, changes in reversed(self.api_versions.items()):
104            if version == self.api_version:
105                changing = True
106
107            if changing:
108                request_changes.extend(changes)
109
110        if not request_changes:
111            return
112
113        # Get the original request JSON
114        request_data = json.loads(request.body)
115
116        # Transform the request data for this version
117        for change in changes:
118            change().transform_request_forward(request, request_data)
119
120        # Update the request body with the transformed data
121        request._body = json.dumps(request_data).encode("utf-8")
122
123    def transform_response(self, response: Response) -> None:
124        response_changes = []
125
126        # Get the changes starting AFTER the current version
127        matching = False
128        for version, changes in reversed(self.api_versions.items()):
129            if matching:
130                response_changes.extend(changes)
131
132            if version == self.api_version:
133                matching = True
134
135        if not response_changes:
136            # No changes to apply, just return
137            return
138
139        # Get the original response JSON
140        response_data = json.loads(response.content)
141
142        for change in reversed(response_changes):
143            # Transform the response data for this version
144            change().transform_response_backward(response, response_data)
145
146        # Update the response body with the transformed data
147        response.content = json.dumps(response_data).encode("utf-8")