Plain is headed towards 1.0! Subscribe for development updates →

  1import json
  2from functools import cached_property
  3from typing import Any
  4
  5from plain.http import Request, ResponseBadRequest, ResponseBase
  6from plain.views import View
  7from plain.views.exceptions import ResponseException
  8
  9
 10class APIVersionChange:
 11    description: str = ""
 12
 13    def transform_request_forward(self, request: Request, data: dict[str, Any]) -> None:
 14        """
 15        If this version of the API made a change in how a request is processed,
 16        (ex. the name of an input changed) then you can
 17        """
 18        pass
 19
 20    def transform_response_backward(
 21        self, response: ResponseBase, data: dict[str, Any]
 22    ) -> None:
 23        """
 24        Transform the response data for this version.
 25
 26        We only transform the response data if we are moving backward to an older version.
 27        This is because the response data is always in the latest version.
 28        """
 29        pass
 30
 31
 32class VersionedAPIView(View):
 33    # API versions from newest to oldest
 34    api_versions: dict[str, list[APIVersionChange]] = {}
 35    api_version_header = "API-Version"
 36    default_api_version: str = ""
 37
 38    @cached_property
 39    def api_version(self) -> str:
 40        return self.get_api_version()
 41
 42    def get_api_version(self) -> str:
 43        version = ""
 44
 45        if version_name := self.request.headers.get(self.api_version_header, ""):
 46            version = version_name
 47        elif default_version := self.get_default_api_version():
 48            version = default_version
 49        else:
 50            raise ResponseException(
 51                ResponseBadRequest(
 52                    f"Missing API version header '{self.api_version_header}'"
 53                )
 54            )
 55
 56        if version in self.api_versions:
 57            return version
 58        else:
 59            raise ResponseException(
 60                ResponseBadRequest(
 61                    f"Invalid API version '{version_name}'. Valid versions are: {', '.join(self.api_versions.keys())}"
 62                )
 63            )
 64
 65    def get_default_api_version(self) -> str:
 66        # If this view has an api_key, use its version name
 67        if api_key := getattr(self, "api_key", None):
 68            if api_key.api_version:
 69                # If the API key has a version, use that
 70                return api_key.api_version
 71
 72        return self.default_api_version
 73
 74    def get_response(self) -> ResponseBase:
 75        if self.request.content_type == "application/json":
 76            self.transform_request(self.request)
 77
 78        # Process the request normally
 79        response = super().get_response()
 80
 81        if response.headers.get("Content-Type") == "application/json":
 82            self.transform_response(response)
 83
 84        # Put the API version on the response
 85        response.headers[self.api_version_header] = self.api_version
 86
 87        return response
 88
 89    def transform_request(self, request: Request) -> None:
 90        request_changes = []
 91
 92        # Find the version being requested,
 93        # then get every change after that up to the latest
 94        changing = False
 95        for version, changes in reversed(self.api_versions.items()):
 96            if version == self.api_version:
 97                changing = True
 98
 99            if changing:
100                request_changes.extend(changes)
101
102        if not request_changes:
103            return
104
105        # Get the original request JSON
106        request_data = json.loads(request.body)
107
108        # Transform the request data for this version
109        for change in changes:
110            change().transform_request_forward(request, request_data)
111
112        # Update the request body with the transformed data
113        request._body = json.dumps(request_data).encode("utf-8")
114
115    def transform_response(self, response: ResponseBase) -> None:
116        response_changes = []
117
118        # Get the changes starting AFTER the current version
119        matching = False
120        for version, changes in reversed(self.api_versions.items()):
121            if matching:
122                response_changes.extend(changes)
123
124            if version == self.api_version:
125                matching = True
126
127        if not response_changes:
128            # No changes to apply, just return
129            return
130
131        # Get the original response JSON
132        response_data = json.loads(response.content)  # type: ignore[attr-defined]
133
134        for change in reversed(response_changes):
135            # Transform the response data for this version
136            change().transform_response_backward(response, response_data)
137
138        # Update the response body with the transformed data
139        response.content = json.dumps(response_data).encode("utf-8")  # type: ignore[attr-defined]