1import json
  2from functools import cached_property
  3from typing import Any
  4
  5from plain.http import Request, Response, ResponseBase
  6from plain.views import View
  7from plain.views.exceptions import ResponseException
  8
# Public API of this module: the names exported by a star-import.
__all__ = [
    "APIVersionChange",
    "VersionedAPIView",
]
 13
 14
 15class APIVersionChange:
 16    description: str = ""
 17
 18    def transform_request_forward(self, request: Request, data: dict[str, Any]) -> None:
 19        """
 20        If this version of the API made a change in how a request is processed,
 21        (ex. the name of an input changed) then you can
 22        """
 23        pass
 24
 25    def transform_response_backward(
 26        self, response: ResponseBase, data: dict[str, Any]
 27    ) -> None:
 28        """
 29        Transform the response data for this version.
 30
 31        We only transform the response data if we are moving backward to an older version.
 32        This is because the response data is always in the latest version.
 33        """
 34        pass
 35
 36
 37class VersionedAPIView(View):
 38    # API versions from newest to oldest
 39    api_versions: dict[str, list[type[APIVersionChange]]] = {}
 40    api_version_header = "API-Version"
 41    default_api_version: str = ""
 42
 43    @cached_property
 44    def api_version(self) -> str:
 45        return self.get_api_version()
 46
 47    def get_api_version(self) -> str:
 48        version = ""
 49
 50        if version_name := self.request.headers.get(self.api_version_header, ""):
 51            version = version_name
 52        elif default_version := self.get_default_api_version():
 53            version = default_version
 54        else:
 55            raise ResponseException(
 56                Response(
 57                    f"Missing API version header '{self.api_version_header}'",
 58                    status_code=400,
 59                )
 60            )
 61
 62        if version in self.api_versions:
 63            return version
 64        else:
 65            raise ResponseException(
 66                Response(
 67                    f"Invalid API version '{version_name}'. Valid versions are: {', '.join(self.api_versions.keys())}",
 68                    status_code=400,
 69                )
 70            )
 71
 72    def get_default_api_version(self) -> str:
 73        # If this view has an api_key, use its version name
 74        if api_key := getattr(self, "api_key", None):
 75            if api_key.api_version:
 76                # If the API key has a version, use that
 77                return api_key.api_version
 78
 79        return self.default_api_version
 80
 81    def get_response(self) -> ResponseBase:
 82        if self.request.content_type == "application/json":
 83            self.transform_request(self.request)
 84
 85        # Process the request normally
 86        response = super().get_response()
 87
 88        if response.headers.get("Content-Type") == "application/json":
 89            self.transform_response(response)
 90
 91        # Put the API version on the response
 92        response.headers[self.api_version_header] = self.api_version
 93
 94        return response
 95
 96    def transform_request(self, request: Request) -> None:
 97        request_changes = []
 98
 99        # Find the version being requested,
100        # then get every change after that up to the latest
101        changing = False
102        for version, changes in reversed(self.api_versions.items()):
103            if version == self.api_version:
104                changing = True
105
106            if changing:
107                request_changes.extend(changes)
108
109        if not request_changes:
110            return
111
112        # Get the original request JSON
113        request_data = json.loads(request.body)
114
115        # Transform the request data for this version
116        for change in changes:
117            change().transform_request_forward(request, request_data)
118
119        # Update the request body with the transformed data
120        request._body = json.dumps(request_data).encode("utf-8")
121
122    def transform_response(self, response: ResponseBase) -> None:
123        response_changes = []
124
125        # Get the changes starting AFTER the current version
126        matching = False
127        for version, changes in reversed(self.api_versions.items()):
128            if matching:
129                response_changes.extend(changes)
130
131            if version == self.api_version:
132                matching = True
133
134        if not response_changes:
135            # No changes to apply, just return
136            return
137
138        # Get the original response JSON
139        response_data = json.loads(response.content)  # type: ignore[attr-defined]
140
141        for change in reversed(response_changes):
142            # Transform the response data for this version
143            change().transform_response_backward(response, response_data)
144
145        # Update the response body with the transformed data
146        response.content = json.dumps(response_data).encode("utf-8")  # type: ignore[attr-defined]