Plain is headed towards 1.0! Subscribe for development updates →

  1import json
  2from functools import cached_property
  3
  4from plain.http import ResponseBadRequest
  5from plain.views import View
  6from plain.views.exceptions import ResponseException
  7
  8
  9class APIVersionChange:
 10    description: str = ""
 11
 12    def transform_request_forward(self, request, data):
 13        """
 14        If this version of the API made a change in how a request is processed,
 15        (ex. the name of an input changed) then you can
 16        """
 17        pass
 18
 19    def transform_response_backward(self, response, data):
 20        """
 21        Transform the response data for this version.
 22
 23        We only transform the response data if we are moving backward to an older version.
 24        This is because the response data is always in the latest version.
 25        """
 26        pass
 27
 28
 29class VersionedAPIView(View):
 30    # API versions from newest to oldest
 31    api_versions: dict[str, list[APIVersionChange]] = {}
 32    api_version_header = "API-Version"
 33    default_api_version: str = ""
 34
 35    @cached_property
 36    def api_version(self) -> str:
 37        return self.get_api_version()
 38
 39    def get_api_version(self) -> str:
 40        version = ""
 41
 42        if version_name := self.request.headers.get(self.api_version_header, ""):
 43            version = version_name
 44        elif default_version := self.get_default_api_version():
 45            version = default_version
 46        else:
 47            raise ResponseException(
 48                ResponseBadRequest(
 49                    f"Missing API version header '{self.api_version_header}'"
 50                )
 51            )
 52
 53        if version in self.api_versions:
 54            return version
 55        else:
 56            raise ResponseException(
 57                ResponseBadRequest(
 58                    f"Invalid API version '{version_name}'. Valid versions are: {', '.join(self.api_versions.keys())}"
 59                )
 60            )
 61
 62    def get_default_api_version(self) -> str:
 63        # If this view has an api_key, use its version name
 64        if api_key := getattr(self, "api_key", None):
 65            if api_key.api_version:
 66                # If the API key has a version, use that
 67                return api_key.api_version
 68
 69        return self.default_api_version
 70
 71    def get_response(self):
 72        if self.request.content_type == "application/json":
 73            self.transform_request(self.request)
 74
 75        # Process the request normally
 76        response = super().get_response()
 77
 78        if response.headers.get("Content-Type") == "application/json":
 79            self.transform_response(response)
 80
 81        # Put the API version on the response
 82        response.headers[self.api_version_header] = self.api_version
 83
 84        return response
 85
 86    def transform_request(self, request):
 87        request_changes = []
 88
 89        # Find the version being requested,
 90        # then get every change after that up to the latest
 91        changing = False
 92        for version, changes in reversed(self.api_versions.items()):
 93            if version == self.api_version:
 94                changing = True
 95
 96            if changing:
 97                request_changes.extend(changes)
 98
 99        if not request_changes:
100            return
101
102        # Get the original request JSON
103        request_data = json.loads(request.body)
104
105        # Transform the request data for this version
106        for change in changes:
107            change().transform_request_forward(request, request_data)
108
109        # Update the request body with the transformed data
110        request._body = json.dumps(request_data).encode("utf-8")
111
112    def transform_response(self, response):
113        response_changes = []
114
115        # Get the changes starting AFTER the current version
116        matching = False
117        for version, changes in reversed(self.api_versions.items()):
118            if matching:
119                response_changes.extend(changes)
120
121            if version == self.api_version:
122                matching = True
123
124        if not response_changes:
125            # No changes to apply, just return
126            return
127
128        # Get the original response JSON
129        response_data = json.loads(response.content)
130
131        for change in reversed(response_changes):
132            # Transform the response data for this version
133            change().transform_response_backward(response, response_data)
134
135        # Update the response body with the transformed data
136        response.content = json.dumps(response_data).encode("utf-8")