1import json
2from functools import cached_property
3from typing import Any
4
5from plain.http import Request, ResponseBadRequest, ResponseBase
6from plain.views import View
7from plain.views.exceptions import ResponseException
8
9
class APIVersionChange:
    """
    A single change introduced by one API version.

    Subclass this and list the subclass itself (not an instance) under the
    version that introduced it in ``VersionedAPIView.api_versions``; the view
    instantiates it each time it is applied. Both hooks receive the decoded
    JSON body as ``data`` and must modify it in place -- their return values
    are ignored. The default implementations do nothing.
    """

    # Free-form description of this change; presumably for docs/changelogs
    # (it is not read by VersionedAPIView).
    description: str = ""

    def transform_request_forward(
        self, request: Request, data: dict[str, Any]
    ) -> None:
        """
        Upgrade request data across this change.

        If this version of the API made a change in how a request is processed
        (ex. the name of an input changed), override this to rewrite ``data``
        from the format clients used before this version into the format this
        version expects (ex. move the value from the old key to the new key).

        Args:
            request: The incoming request.
            data: The decoded JSON request body; modify it in place.
        """
        pass

    def transform_response_backward(
        self, response: ResponseBase, data: dict[str, Any]
    ) -> None:
        """
        Downgrade response data across this change.

        Response data is always produced in the latest version's format, so
        this is only called when moving backward to a version older than the
        one that introduced this change. Override it to undo this change's
        effect on ``data``.

        Args:
            response: The outgoing response.
            data: The decoded JSON response body; modify it in place.
        """
        pass
30
31
class VersionedAPIView(View):
    """
    A view that lets clients pin an API version.

    The view itself is written against the latest API version. Incoming JSON
    request bodies are upgraded to the latest format before the view runs, and
    JSON responses are downgraded to the requested version afterwards. The
    resolved version is echoed back on the response under
    ``api_version_header``.
    """

    # API versions from newest to oldest. Each version maps to the
    # APIVersionChange classes it introduced (the classes are instantiated
    # when applied). The transforms rely on this dict's insertion order.
    api_versions: dict[str, list[type[APIVersionChange]]] = {}

    # Request header the client uses to choose a version.
    api_version_header = "API-Version"

    # Version used when the header is absent (empty means no fallback,
    # unless get_default_api_version finds one on an api_key).
    default_api_version: str = ""

    @cached_property
    def api_version(self) -> str:
        """The validated API version for this request, resolved once."""
        return self.get_api_version()

    def get_api_version(self) -> str:
        """
        Resolve and validate the API version requested by the client.

        The version header takes precedence; otherwise the value from
        ``get_default_api_version()`` is used.

        Raises:
            ResponseException: wrapping a ``ResponseBadRequest`` when no
                version can be determined, or when the resolved version is
                not a key of ``api_versions``.
        """
        if version_name := self.request.headers.get(self.api_version_header, ""):
            version = version_name
        elif default_version := self.get_default_api_version():
            version = default_version
        else:
            raise ResponseException(
                ResponseBadRequest(
                    f"Missing API version header '{self.api_version_header}'"
                )
            )

        if version not in self.api_versions:
            # Report the version that was actually resolved: it may have come
            # from a default rather than from the header.
            raise ResponseException(
                ResponseBadRequest(
                    f"Invalid API version '{version}'. Valid versions are: {', '.join(self.api_versions.keys())}"
                )
            )

        return version

    def get_default_api_version(self) -> str:
        """
        Return the version to use when the request has no version header.

        If the view has a truthy ``api_key`` attribute whose ``api_version``
        is set, that version is used; otherwise ``default_api_version``.
        """
        # api_key is not defined on this class; presumably a subclass or
        # mixin sets it during authentication -- hence the getattr.
        if api_key := getattr(self, "api_key", None):
            if api_key.api_version:
                return api_key.api_version

        return self.default_api_version

    def get_response(self) -> ResponseBase:
        """
        Process the request, translating JSON bodies between the requested
        API version and the latest one.

        Raises:
            ResponseException: if the API version is missing or invalid.
        """
        # Resolve (and validate) the version before the view runs, so a bad
        # version is rejected without the view executing at all, whatever the
        # request's content type.
        api_version = self.api_version

        if self._is_json_content_type(self.request.content_type or ""):
            self.transform_request(self.request)

        # Process the request normally
        response = super().get_response()

        if self._is_json_content_type(response.headers.get("Content-Type") or ""):
            self.transform_response(response)

        # Put the API version on the response
        response.headers[self.api_version_header] = api_version

        return response

    @staticmethod
    def _is_json_content_type(content_type: str) -> bool:
        """Return True if a Content-Type value is JSON, ignoring parameters and case."""
        # "application/json; charset=utf-8" must match too.
        return content_type.split(";", 1)[0].strip().lower() == "application/json"

    def _get_newer_version_changes(self) -> list[type[APIVersionChange]]:
        """
        Return the changes introduced by versions newer than ``api_version``.

        The result is ordered oldest to newest. The requested version's own
        changes are excluded: data in that version's format already reflects
        them.
        """
        newer_changes: list[type[APIVersionChange]] = []
        found_current = False
        # api_versions is newest-to-oldest, so walk it in reverse.
        for version, changes in reversed(self.api_versions.items()):
            if found_current:
                newer_changes.extend(changes)
            elif version == self.api_version:
                found_current = True
        return newer_changes

    def transform_request(self, request: Request) -> None:
        """
        Upgrade a JSON request body from the requested version to the latest.

        Each change introduced after the requested version is applied forward,
        oldest first. The transformed JSON is written back to
        ``request._body``.
        """
        request_changes = self._get_newer_version_changes()
        if not request_changes:
            return

        # A JSON Content-Type with an empty body has nothing to transform,
        # and json.loads would raise on it.
        if not request.body:
            return

        # Get the original request JSON
        request_data = json.loads(request.body)

        # Apply every collected change, not just the last version's list
        for change in request_changes:
            change().transform_request_forward(request, request_data)

        # Update the request body with the transformed data
        request._body = json.dumps(request_data).encode("utf-8")

    def transform_response(self, response: ResponseBase) -> None:
        """
        Downgrade a JSON response body from the latest version to the
        requested one.

        Response data is always produced in the latest format, so the changes
        introduced after the requested version are undone, newest first.
        """
        response_changes = self._get_newer_version_changes()
        if not response_changes:
            # No changes to apply, just return
            return

        # Get the original response JSON
        response_data = json.loads(response.content)  # type: ignore[attr-defined]

        for change in reversed(response_changes):
            # Transform the response data for this version
            change().transform_response_backward(response, response_data)

        # Update the response body with the transformed data
        response.content = json.dumps(response_data).encode("utf-8")  # type: ignore[attr-defined]