1import json
2from functools import cached_property
3
4from plain.http import ResponseBadRequest
5from plain.views import View
6from plain.views.exceptions import ResponseException
7
8
class APIVersionChange:
    """
    One change introduced by an API version.

    Subclasses override one or both hooks to translate JSON data between the
    format of the version that introduced the change and the format used
    before it. Both hooks do nothing by default.
    """

    # Free-form description of the change; not read by the code in this module.
    description: str = ""

    def transform_request_forward(self, request, data):
        """
        Upgrade incoming request data across this change.

        If this version of the API changed how a request is processed
        (ex. the name of an input changed), override this method to rewrite
        ``data`` from the older format into the newer one.

        ``data`` is the decoded JSON request body. Modify it in place; the
        return value is ignored.
        """

    def transform_response_backward(self, response, data):
        """
        Downgrade outgoing response data across this change.

        Response data is always produced in the latest version's format, so it
        only needs transforming when moving backward to an older version.

        ``data`` is the decoded JSON response body. Modify it in place; the
        return value is ignored.
        """
27
28
class VersionedAPIView(View):
    """
    A view that lets clients pin an API version.

    ``api_versions`` maps each version name to the ``APIVersionChange``
    subclasses introduced by that version, ordered from newest to oldest.
    JSON request bodies are transformed forward through every change newer
    than the requested version. JSON responses, which the view produces in
    the latest format, are transformed backward through the same changes in
    reverse order. The resolved version is echoed in ``api_version_header``.
    """

    # API versions from newest to oldest, each mapped to the change classes
    # introduced in that version. These are classes, not instances; a fresh
    # instance is created for every transform.
    api_versions: dict[str, list[type[APIVersionChange]]] = {}
    api_version_header = "API-Version"
    default_api_version: str = ""

    @cached_property
    def api_version(self) -> str:
        """The resolved API version for this request (computed once)."""
        return self.get_api_version()

    def get_api_version(self) -> str:
        """
        Resolve the API version for the current request.

        The ``api_version_header`` request header wins when non-empty;
        otherwise ``get_default_api_version()`` is used.

        Raises:
            ResponseException: wrapping a ``ResponseBadRequest`` when no
                version is supplied and there is no default, or when the
                resolved version is not a key of ``api_versions``.
        """
        version = self.request.headers.get(self.api_version_header, "")
        if not version:
            version = self.get_default_api_version()

        if not version:
            raise ResponseException(
                ResponseBadRequest(
                    f"Missing API version header '{self.api_version_header}'"
                )
            )

        if version not in self.api_versions:
            # Report the resolved version, which may have come from the
            # default rather than from the (possibly absent) header.
            raise ResponseException(
                ResponseBadRequest(
                    f"Invalid API version '{version}'. Valid versions are: {', '.join(self.api_versions.keys())}"
                )
            )

        return version

    def get_default_api_version(self) -> str:
        """
        Return the version to use when the request carries no version header.

        If the view has a truthy ``api_key`` attribute whose ``api_version``
        is truthy, that version is used; otherwise ``default_api_version``.
        """
        # If this view has an api_key, use its version name
        if api_key := getattr(self, "api_key", None):
            if api_key.api_version:
                # If the API key has a version, use that
                return api_key.api_version

        return self.default_api_version

    def get_response(self):
        """
        Run the view, applying version transforms around it.

        A JSON request body is upgraded before the view runs, and a JSON
        response is downgraded afterward. The resolved version is always set
        on the response's ``api_version_header``.
        """
        if self.request.content_type == "application/json":
            self.transform_request(self.request)

        # Process the request normally
        response = super().get_response()

        # Only an exact "application/json" Content-Type is transformed; a
        # value with parameters (e.g. "; charset=utf-8") passes through as-is.
        if response.headers.get("Content-Type") == "application/json":
            self.transform_response(response)

        # NOTE(review): for non-JSON requests the version is first resolved
        # here, after the view has already run, so a missing or invalid
        # version only surfaces once the handler's side effects have happened.
        # Consider resolving it up front if api_key is available by then.
        response.headers[self.api_version_header] = self.api_version

        return response

    def _get_newer_version_changes(self) -> list[type[APIVersionChange]]:
        """
        Return the change classes of every version newer than the current one.

        ``api_versions`` is ordered newest to oldest, so it is walked in
        reverse, and the result is ordered oldest to newest. The current
        version's own changes are excluded because a client on that version
        already uses its format. The result is empty if the current version
        is the newest one or is not a key of ``api_versions``.
        """
        newer_changes: list[type[APIVersionChange]] = []
        found_current = False
        for version, changes in reversed(self.api_versions.items()):
            if found_current:
                newer_changes.extend(changes)
            elif version == self.api_version:
                found_current = True
        return newer_changes

    def transform_request(self, request):
        """
        Upgrade a JSON request body from the requested version to the latest.

        Calls ``transform_request_forward`` for every change introduced after
        ``self.api_version``, oldest first, then stores the re-encoded JSON on
        ``request._body``. An empty body is left untouched.

        Raises:
            ResponseException: wrapping a ``ResponseBadRequest`` if there are
                changes to apply and the body is not valid JSON.
        """
        request_changes = self._get_newer_version_changes()
        if not request_changes:
            return

        # Nothing to transform; let the view decide how to handle it.
        if not request.body:
            return

        # Get the original request JSON. The body is client-controlled, so a
        # decode failure is a client error rather than a server crash.
        # (json.JSONDecodeError and UnicodeDecodeError both subclass ValueError.)
        try:
            request_data = json.loads(request.body)
        except ValueError as e:
            raise ResponseException(
                ResponseBadRequest("Invalid JSON in request body")
            ) from e

        # Apply each newer version's changes in order, oldest to newest
        for change in request_changes:
            change().transform_request_forward(request, request_data)

        # NOTE(review): relies on the request caching its body in `_body`;
        # confirm against the plain Request implementation.
        request._body = json.dumps(request_data).encode("utf-8")

    def transform_response(self, response):
        """
        Downgrade a JSON response from the latest version to the requested one.

        Calls ``transform_response_backward`` for every change introduced
        after ``self.api_version``, newest first, then replaces
        ``response.content`` with the re-encoded JSON.
        """
        response_changes = self._get_newer_version_changes()
        if not response_changes:
            # No changes to apply, just return
            return

        # Get the original response JSON
        response_data = json.loads(response.content)

        # Undo the newest changes first to walk back toward the requested version
        for change in reversed(response_changes):
            change().transform_response_backward(response, response_data)

        # Update the response body with the transformed data
        response.content = json.dumps(response_data).encode("utf-8")