1import json
2from functools import cached_property
3from typing import Any
4
5from plain.http import Request, Response
6from plain.views.exceptions import ResponseException
7
8from .views import APIView
9
# Public names exported by this module.
__all__ = ["APIVersionChange", "VersionedAPIView"]
14
15
class APIVersionChange:
    """
    One change that a given API version introduced.

    Subclass this and override one or both hooks. The hooks return nothing;
    subclasses are expected to modify the passed-in ``data`` dict in place.
    The base implementations are no-ops.
    """

    # Human-readable summary of what this change does.
    description: str = ""

    def transform_request_forward(
        self, request: Request, data: dict[str, Any]
    ) -> None:
        """
        Adapt incoming request data to account for this version's change.

        Override this when this version of the API changed how a request is
        processed (e.g. an input was renamed). Rewrite ``data`` in place so
        that older-format input is turned into what newer code expects.
        Does nothing by default.
        """

    def transform_response_backward(
        self, response: Response, data: dict[str, Any]
    ) -> None:
        """
        Adapt outgoing response data to undo this version's change.

        Response data is always produced in the latest version, so it only
        needs transforming when moving backward to an older version. Rewrite
        ``data`` in place. Does nothing by default.
        """
36
37
class VersionedAPIView(APIView):
    """
    An APIView that serves several API versions from one implementation.

    ``api_versions`` maps each version name to the ``APIVersionChange``
    classes listed for it, ordered from newest to oldest. JSON request bodies
    are passed forward through the changes of every version newer than the
    client's version. JSON response bodies are passed backward through those
    same changes, newest first. The resolved version is echoed back in the
    ``api_version_header`` response header.
    """

    # API versions from newest to oldest
    api_versions: dict[str, list[type[APIVersionChange]]] = {}
    api_version_header = "API-Version"
    default_api_version: str = ""

    @cached_property
    def api_version(self) -> str:
        """The validated API version for this request (resolved once, then cached)."""
        return self.get_api_version()

    def get_api_version(self) -> str:
        """
        Resolve the API version for the current request.

        The ``api_version_header`` request header wins. Otherwise
        ``get_default_api_version()`` is used.

        Raises:
            ResponseException: wrapping a 400 response when no version can be
                determined, or when the version is not a key of
                ``api_versions``.
        """
        if version_name := self.request.headers.get(self.api_version_header, ""):
            version = version_name
        elif default_version := self.get_default_api_version():
            version = default_version
        else:
            raise ResponseException(
                Response(
                    f"Missing API version header '{self.api_version_header}'",
                    status_code=400,
                )
            )

        if version not in self.api_versions:
            # Report the version actually being checked. It may have come from
            # the default rather than the header.
            raise ResponseException(
                Response(
                    f"Invalid API version '{version}'. Valid versions are: {', '.join(self.api_versions.keys())}",
                    status_code=400,
                )
            )

        return version

    def get_default_api_version(self) -> str:
        """
        Return the version to use when the request sends no version header.

        Prefers a truthy ``api_version`` on ``self.api_key`` (if the view has
        one), otherwise ``default_api_version``.
        """
        # If this view has an api_key, use its version name
        if api_key := getattr(self, "api_key", None):
            if api_key.api_version:
                # If the API key has a version, use that
                return api_key.api_version

        return self.default_api_version

    @staticmethod
    def _is_json_content_type(content_type: str) -> bool:
        """Return True if *content_type*'s media type is application/json, ignoring parameters such as charset."""
        if not content_type:
            return False
        media_type = content_type.split(";", 1)[0].strip().lower()
        return media_type == "application/json"

    def _changes_after_current_version(self) -> list[type[APIVersionChange]]:
        """Return the changes of every version newer than ``self.api_version``, oldest first."""
        newer_changes: list[type[APIVersionChange]] = []
        found_current = False
        # api_versions is newest -> oldest, so walk it reversed (oldest -> newest).
        for version, changes in reversed(self.api_versions.items()):
            if found_current:
                newer_changes.extend(changes)
            elif version == self.api_version:
                # The current version's own changes are already reflected in
                # the client's data, so collection starts with the next version.
                found_current = True
        return newer_changes

    def before_request(self) -> None:
        """Run the parent hook, then upgrade a JSON request body to the latest version."""
        super().before_request()
        if self._is_json_content_type(self.request.content_type):
            self.transform_request(self.request)

    def after_response(self, response: Response) -> Response:
        """Downgrade a JSON response body to the client's version and add the version header."""
        response = super().after_response(response)
        if self._is_json_content_type(response.headers.get("Content-Type", "")):
            self.transform_response(response)

        # Put the API version on the response
        response.headers[self.api_version_header] = self.api_version

        return response

    def transform_request(self, request: Request) -> None:
        """
        Upgrade the JSON body of *request* from ``self.api_version`` to the latest version.

        Calls ``transform_request_forward`` for each change of every newer
        version, oldest first, then re-encodes the result into the request
        body. Does nothing when there are no newer changes or the body is
        empty.
        """
        request_changes = self._changes_after_current_version()
        if not request_changes:
            return

        # A bodiless request that still claims JSON has nothing to transform,
        # and json.loads(b"") would raise.
        if not request.body:
            return

        # Get the original request JSON
        request_data = json.loads(request.body)

        # Apply every collected change in order, oldest to newest
        for change in request_changes:
            change().transform_request_forward(request, request_data)

        # Update the request body with the transformed data
        request._body = json.dumps(request_data).encode("utf-8")

    def transform_response(self, response: Response) -> None:
        """
        Downgrade the JSON body of *response* from the latest version to ``self.api_version``.

        Calls ``transform_response_backward`` for each change of every newer
        version, newest first, then re-encodes the result into the response
        content. Does nothing when there are no newer changes or the content
        is empty.
        """
        response_changes = self._changes_after_current_version()
        if not response_changes:
            # No changes to apply, just return
            return

        # Empty content (e.g. a bodiless JSON response) cannot be parsed.
        if not response.content:
            return

        # Get the original response JSON
        response_data = json.loads(response.content)

        # Undo changes newest-first so each step sees the format it produced.
        for change in reversed(response_changes):
            change().transform_response_backward(response, response_data)

        # Update the response body with the transformed data
        response.content = json.dumps(response_data).encode("utf-8")