1import json
2from functools import cached_property
3from typing import Any
4
5from plain.http import Request, Response, ResponseBase
6from plain.views import View
7from plain.views.exceptions import ResponseException
8
# Public API of this module: the change hook base class and the versioned view.
__all__ = ["APIVersionChange", "VersionedAPIView"]
13
14
class APIVersionChange:
    """
    Base class for one change registered under a version in
    ``VersionedAPIView.api_versions``.

    ``VersionedAPIView`` instantiates each registered change class and calls
    the hooks below, passing the decoded JSON body as ``data``. Hooks mutate
    ``data`` in place and return nothing. The default implementations are
    no-ops, so a subclass only overrides the direction(s) it affects.
    """

    # Human-readable summary of what this change does.
    description: str = ""

    def transform_request_forward(self, request: Request, data: dict[str, Any]) -> None:
        """
        Upgrade incoming request data to this version's format.

        If this version of the API changed how a request is processed
        (e.g. the name of an input changed), override this method and rewrite
        ``data`` in place so that a request written against an older version
        is understood by the current handler code.

        The default implementation does nothing.
        """

    def transform_response_backward(
        self, response: ResponseBase, data: dict[str, Any]
    ) -> None:
        """
        Downgrade outgoing response data from this version's format.

        Response data is only ever transformed backward, because handlers
        always produce data in the latest version's format. Override this
        method and rewrite ``data`` in place to undo this version's change
        for clients pinned to an older version.

        The default implementation does nothing.
        """
35
36
class VersionedAPIView(View):
    """
    A view that lets clients pin an API version and translates JSON bodies
    between that version and the latest one.

    Handler code is always written against the newest version. Incoming JSON
    request bodies are upgraded with each change's
    ``transform_request_forward`` and outgoing JSON response bodies are
    downgraded with ``transform_response_backward``. In both directions only
    the changes registered for versions newer than the client's version are
    applied.
    """

    # API versions from newest to oldest, mapping each version name to the
    # change classes registered for that version.
    api_versions: dict[str, list[type[APIVersionChange]]] = {}
    # Request header used to select a version; also set on every response.
    api_version_header = "API-Version"
    # Fallback version when the header is absent and no API key provides one.
    default_api_version: str = ""

    @cached_property
    def api_version(self) -> str:
        """The validated API version for this request (resolved once)."""
        return self.get_api_version()

    def get_api_version(self) -> str:
        """
        Resolve the API version for the current request.

        The ``api_version_header`` request header takes precedence; when it is
        missing or empty, ``get_default_api_version()`` is consulted.

        Raises:
            ResponseException: wrapping a 400 response when no version can be
                determined, or when the version is not a key of
                ``api_versions``.
        """
        # ``or`` only calls get_default_api_version() when the header is empty.
        version = (
            self.request.headers.get(self.api_version_header, "")
            or self.get_default_api_version()
        )

        if not version:
            raise ResponseException(
                Response(
                    f"Missing API version header '{self.api_version_header}'",
                    status_code=400,
                )
            )

        if version not in self.api_versions:
            # Report the version actually resolved (it may have come from the
            # default rather than the header).
            raise ResponseException(
                Response(
                    f"Invalid API version '{version}'. Valid versions are: {', '.join(self.api_versions)}",
                    status_code=400,
                )
            )

        return version

    def get_default_api_version(self) -> str:
        """
        Return the version to use when the request has no version header.

        If the view has an ``api_key`` attribute whose ``api_version`` is set,
        that version is used; otherwise ``default_api_version``.
        """
        if api_key := getattr(self, "api_key", None):
            if api_key.api_version:
                return api_key.api_version

        return self.default_api_version

    def get_response(self) -> ResponseBase:
        """
        Run the view with version translation applied.

        The API version is resolved first, so a missing or invalid version is
        rejected before the handler runs. A JSON request body is upgraded
        before the handler runs, and a JSON response body is downgraded
        afterwards. The resolved version is set on the response under
        ``api_version_header``.
        """
        # Validate up front; otherwise requests without a JSON body would only
        # be rejected after the handler (and its side effects) had run.
        api_version = self.api_version

        if self._is_json_content_type(self.request.content_type):
            self.transform_request(self.request)

        # Process the request normally
        response = super().get_response()

        if self._is_json_content_type(response.headers.get("Content-Type")):
            self.transform_response(response)

        response.headers[self.api_version_header] = api_version

        return response

    @staticmethod
    def _is_json_content_type(content_type: str | None) -> bool:
        """Return True if the media type (ignoring parameters) is JSON."""
        if not content_type:
            return False
        media_type = content_type.split(";", 1)[0].strip().lower()
        return media_type == "application/json"

    def _get_changes_after_version(self) -> list[type[APIVersionChange]]:
        """
        Return the change classes registered for versions newer than
        ``api_version``, ordered from oldest to newest.
        """
        newer_changes: list[type[APIVersionChange]] = []
        found = False
        # api_versions is declared newest -> oldest, so walk it backwards and
        # collect everything after the client's own version.
        for version, changes in reversed(self.api_versions.items()):
            if found:
                newer_changes.extend(changes)
            elif version == self.api_version:
                found = True
        return newer_changes

    def transform_request(self, request: Request) -> None:
        """
        Upgrade the JSON request body from the client's version to the latest.

        Changes from versions newer than the requested one are applied oldest
        first. The client's own version is excluded because the request is
        already in that version's format. The body is left untouched when
        there is nothing to apply or the body is empty.
        """
        request_changes = self._get_changes_after_version()
        if not request_changes or not request.body:
            return

        request_data = json.loads(request.body)

        for change in request_changes:
            change().transform_request_forward(request, request_data)

        # Update the request body with the transformed data
        request._body = json.dumps(request_data).encode("utf-8")

    def transform_response(self, response: ResponseBase) -> None:
        """
        Downgrade the JSON response body from the latest version to the
        client's version.

        Changes from versions newer than the requested one are undone newest
        first. The body is left untouched when there is nothing to apply or
        the body is empty.
        """
        response_changes = self._get_changes_after_version()
        if not response_changes or not response.content:  # type: ignore[attr-defined]
            return

        response_data = json.loads(response.content)  # type: ignore[attr-defined]

        # Walk back from the newest change to the client's version.
        for change in reversed(response_changes):
            change().transform_response_backward(response, response_data)

        # Update the response body with the transformed data
        response.content = json.dumps(response_data).encode("utf-8")  # type: ignore[attr-defined]