1import json
2from functools import cached_property
3from typing import Any
4
5from plain.http import Request, Response, ResponseBase
6from plain.views import View
7from plain.views.exceptions import ResponseException
8
9
class APIVersionChange:
    """
    A single change introduced in an API version.

    Subclasses override one or both hooks below. Versioned views instantiate
    the change class with no arguments each time they apply it, and both hooks
    are expected to modify ``data`` in place (their return value is ignored).
    """

    # Human-readable summary of what this change does.
    description: str = ""

    def transform_request_forward(self, request: Request, data: dict[str, Any]) -> None:
        """
        Upgrade request data from the previous API version to this one.

        If this version changed how a request is processed (e.g. the name of an
        input changed), override this to rewrite ``data`` in place so that a
        request written for the previous version matches this version's format.

        The default implementation does nothing.
        """

    def transform_response_backward(
        self, response: ResponseBase, data: dict[str, Any]
    ) -> None:
        """
        Downgrade response data from this API version to the previous one.

        Response data is always produced in the latest version, so it is only
        transformed when moving backward to an older version. Override this to
        rewrite ``data`` in place.

        The default implementation does nothing.
        """
30
31
class VersionedAPIView(View):
    """
    A view whose handlers are written against the latest API version.

    Clients select a version with the ``api_version_header`` request header
    (falling back to ``get_default_api_version()``). JSON request bodies from
    older versions are transformed forward to the latest version before the
    handler runs, and JSON responses are transformed backward to the requested
    version afterwards. The resolved version is echoed in the response header.
    """

    # API versions from newest to oldest. Each value lists the change classes
    # for that version; a client on version V receives the transforms of every
    # version newer than V (V's own changes are already part of V's format).
    api_versions: dict[str, list[type[APIVersionChange]]] = {}
    api_version_header = "API-Version"
    default_api_version: str = ""

    @cached_property
    def api_version(self) -> str:
        return self.get_api_version()

    def get_api_version(self) -> str:
        """
        Return the validated API version for the current request.

        Raises:
            ResponseException: 400 response if no version was given and there is
                no default, or if the version is not a key of ``api_versions``.
        """
        version = (
            self.request.headers.get(self.api_version_header, "")
            or self.get_default_api_version()
        )

        if not version:
            raise ResponseException(
                Response(
                    f"Missing API version header '{self.api_version_header}'",
                    status_code=400,
                )
            )

        if version not in self.api_versions:
            raise ResponseException(
                Response(
                    f"Invalid API version '{version}'. Valid versions are: {', '.join(self.api_versions.keys())}",
                    status_code=400,
                )
            )

        return version

    def get_default_api_version(self) -> str:
        # If this view has an api_key, use its version name
        if api_key := getattr(self, "api_key", None):
            if api_key.api_version:
                # If the API key has a version, use that
                return api_key.api_version

        return self.default_api_version

    def get_response(self) -> ResponseBase:
        # Resolve (and validate) the version before the handler runs, so a
        # missing/invalid version is rejected without executing the view.
        # NOTE(review): this assumes any `api_key` consulted by
        # get_default_api_version() is available at this point — the JSON
        # request path already required that before this change.
        api_version = self.api_version

        if self._is_json_content_type(self.request.content_type):
            self.transform_request(self.request)

        # Process the request normally
        response = super().get_response()

        if self._is_json_content_type(response.headers.get("Content-Type")):
            self.transform_response(response)

        # Put the API version on the response
        response.headers[self.api_version_header] = api_version

        return response

    @staticmethod
    def _is_json_content_type(content_type: "str | None") -> bool:
        """Return True if a Content-Type value is JSON, ignoring parameters such as charset."""
        if not content_type:
            return False
        media_type = content_type.partition(";")[0].strip().lower()
        return media_type == "application/json"

    def _get_changes_after_api_version(self) -> list[type[APIVersionChange]]:
        """Return the changes of every version newer than ``self.api_version``, oldest version first."""
        newer_changes: list[type[APIVersionChange]] = []
        found_current = False
        # api_versions is ordered newest -> oldest, so walk it in reverse.
        for version, changes in reversed(self.api_versions.items()):
            if found_current:
                newer_changes.extend(changes)
            elif version == self.api_version:
                found_current = True
        return newer_changes

    def transform_request(self, request: Request) -> None:
        """
        Upgrade a JSON request body from the requested API version to the latest.

        Applies ``transform_request_forward`` for every change newer than the
        requested version, oldest first, and writes the result back to the
        request body. Does nothing if no changes apply or the body is empty.

        Raises:
            ResponseException: 400 response if the body is not valid JSON.
        """
        request_changes = self._get_changes_after_api_version()
        if not request_changes:
            return

        if not request.body:
            # A JSON content type with no body: nothing to transform.
            return

        # Get the original request JSON
        try:
            request_data = json.loads(request.body)
        except ValueError as err:  # JSONDecodeError and UnicodeDecodeError
            raise ResponseException(
                Response("Invalid JSON request body", status_code=400)
            ) from err

        # Step the data forward one version at a time, oldest change first
        for change in request_changes:
            change().transform_request_forward(request, request_data)

        # Update the request body with the transformed data
        request._body = json.dumps(request_data).encode("utf-8")

    def transform_response(self, response: ResponseBase) -> None:
        """
        Downgrade a JSON response body from the latest API version to the requested one.

        Applies ``transform_response_backward`` for every change newer than the
        requested version, newest first, and writes the result back to the
        response content. Does nothing if no changes apply or the body is empty.
        """
        response_changes = self._get_changes_after_api_version()
        if not response_changes:
            # No changes to apply, just return
            return

        content = response.content  # type: ignore[attr-defined]
        if not content:
            # e.g. an empty body with a JSON content type: nothing to transform.
            return

        # Get the original response JSON
        response_data = json.loads(content)

        # Undo changes newest first, stepping back one version at a time
        for change in reversed(response_changes):
            change().transform_response_backward(response, response_data)

        # Update the response body with the transformed data
        response.content = json.dumps(response_data).encode("utf-8")  # type: ignore[attr-defined]