1import json
  2from functools import cached_property
  3from typing import Any
  4
  5from plain.http import Request, Response, ResponseBase
  6from plain.views import View
  7from plain.views.exceptions import ResponseException
  8
  9
 10class APIVersionChange:
 11    description: str = ""
 12
 13    def transform_request_forward(self, request: Request, data: dict[str, Any]) -> None:
 14        """
 15        If this version of the API made a change in how a request is processed,
 16        (ex. the name of an input changed) then you can
 17        """
 18        pass
 19
 20    def transform_response_backward(
 21        self, response: ResponseBase, data: dict[str, Any]
 22    ) -> None:
 23        """
 24        Transform the response data for this version.
 25
 26        We only transform the response data if we are moving backward to an older version.
 27        This is because the response data is always in the latest version.
 28        """
 29        pass
 30
 31
 32class VersionedAPIView(View):
 33    # API versions from newest to oldest
 34    api_versions: dict[str, list[APIVersionChange]] = {}
 35    api_version_header = "API-Version"
 36    default_api_version: str = ""
 37
 38    @cached_property
 39    def api_version(self) -> str:
 40        return self.get_api_version()
 41
 42    def get_api_version(self) -> str:
 43        version = ""
 44
 45        if version_name := self.request.headers.get(self.api_version_header, ""):
 46            version = version_name
 47        elif default_version := self.get_default_api_version():
 48            version = default_version
 49        else:
 50            raise ResponseException(
 51                Response(
 52                    f"Missing API version header '{self.api_version_header}'",
 53                    status_code=400,
 54                )
 55            )
 56
 57        if version in self.api_versions:
 58            return version
 59        else:
 60            raise ResponseException(
 61                Response(
 62                    f"Invalid API version '{version_name}'. Valid versions are: {', '.join(self.api_versions.keys())}",
 63                    status_code=400,
 64                )
 65            )
 66
 67    def get_default_api_version(self) -> str:
 68        # If this view has an api_key, use its version name
 69        if api_key := getattr(self, "api_key", None):
 70            if api_key.api_version:
 71                # If the API key has a version, use that
 72                return api_key.api_version
 73
 74        return self.default_api_version
 75
 76    def get_response(self) -> ResponseBase:
 77        if self.request.content_type == "application/json":
 78            self.transform_request(self.request)
 79
 80        # Process the request normally
 81        response = super().get_response()
 82
 83        if response.headers.get("Content-Type") == "application/json":
 84            self.transform_response(response)
 85
 86        # Put the API version on the response
 87        response.headers[self.api_version_header] = self.api_version
 88
 89        return response
 90
 91    def transform_request(self, request: Request) -> None:
 92        request_changes = []
 93
 94        # Find the version being requested,
 95        # then get every change after that up to the latest
 96        changing = False
 97        for version, changes in reversed(self.api_versions.items()):
 98            if version == self.api_version:
 99                changing = True
100
101            if changing:
102                request_changes.extend(changes)
103
104        if not request_changes:
105            return
106
107        # Get the original request JSON
108        request_data = json.loads(request.body)
109
110        # Transform the request data for this version
111        for change in changes:
112            change().transform_request_forward(request, request_data)
113
114        # Update the request body with the transformed data
115        request._body = json.dumps(request_data).encode("utf-8")
116
117    def transform_response(self, response: ResponseBase) -> None:
118        response_changes = []
119
120        # Get the changes starting AFTER the current version
121        matching = False
122        for version, changes in reversed(self.api_versions.items()):
123            if matching:
124                response_changes.extend(changes)
125
126            if version == self.api_version:
127                matching = True
128
129        if not response_changes:
130            # No changes to apply, just return
131            return
132
133        # Get the original response JSON
134        response_data = json.loads(response.content)  # type: ignore[attr-defined]
135
136        for change in reversed(response_changes):
137            # Transform the response data for this version
138            change().transform_response_backward(response, response_data)
139
140        # Update the response body with the transformed data
141        response.content = json.dumps(response_data).encode("utf-8")  # type: ignore[attr-defined]