1from typing import Any
2from uuid import UUID
3
4from plain.forms import fields
5from plain.urls import get_resolver
6from plain.urls.converters import get_converters
7from plain.views import View
8
9from .responses import JsonResponse, JsonResponseCreated, JsonResponseList
10from .views import APIBaseView
11
12
class OpenAPISchemaView(View):
    """Serve the project's OpenAPI document as a JSON response.

    ``openapi_title`` and ``openapi_version`` have no defaults here and must
    be provided (e.g. on a subclass) before ``get`` is called.
    """

    openapi_title: str
    openapi_version: str
    # NOTE: selecting a specific URL router (``openapi_urlrouter``) is not
    # supported yet; the default resolver is always used.

    def get(self):
        """Build the schema on each request and return it with sorted keys."""
        # TODO: the schema is rebuilt per request; it could be cached
        # (HTTP caching headers or a cached schema object).
        schema = OpenAPISchema(
            title=self.openapi_title,
            version=self.openapi_version,
        )
        # Sorting keys keeps the JSON output stable between requests.
        return JsonResponse(schema, json_dumps_params={"sort_keys": True})
28
29
30class OpenAPISchema(dict):
31 def __init__(self, *, title: str, version: str):
32 self.url_converters = {
33 class_instance.__class__: key
34 for key, class_instance in get_converters().items()
35 }
36 paths = self.get_paths()
37 super().__init__(
38 openapi="3.0.0",
39 info={
40 "title": title,
41 "version": version,
42 # **moreinfo, or info is a dict?
43 },
44 paths=paths,
45 # "404": {
46 # "$ref": "#/components/responses/not_found"
47 # },
48 # "422": {
49 # "$ref": "#/components/responses/validation_failed_simple"
50 # }
51 )
52
53 # def extract_components(self, paths):
54 # """Look through the paths and find and repeated definitions
55 # that can be pulled out as components."""
56 # from collections import Counter
57 # components = Counter()
58 # for path in paths.values():
59
60 def get_paths(self) -> dict[str, dict[str, Any]]:
61 resolver = get_resolver() # (self.urlrouter)
62 paths = {}
63
64 for url_pattern in resolver.url_patterns:
65 for pattern, root_path in self.url_patterns_from_url_pattern(
66 url_pattern, "/"
67 ):
68 path = self.path_from_url_pattern(pattern, root_path)
69 if operations := self.operations_from_url_pattern(pattern):
70 paths[path] = operations
71 if parameters := self.parameters_from_url_patterns(
72 [url_pattern, pattern]
73 ):
74 # Assume all methods have the same parameters for now (path params)
75 for method in operations:
76 operations[method]["parameters"] = parameters
77
78 return paths
79
80 def url_patterns_from_url_pattern(self, url_pattern, root_path) -> list:
81 if hasattr(url_pattern, "url_patterns"):
82 include_path = self.path_from_url_pattern(url_pattern, root_path)
83 url_patterns = []
84 for u in url_pattern.url_patterns:
85 url_patterns.extend(self.url_patterns_from_url_pattern(u, include_path))
86 return url_patterns
87 else:
88 return [(url_pattern, root_path)]
89
90 def path_from_url_pattern(self, url_pattern, root_path) -> str:
91 path = root_path + str(url_pattern.pattern)
92
93 for name, converter in url_pattern.pattern.converters.items():
94 key = self.url_converters[converter.__class__]
95 path = path.replace(f"<{key}:{name}>", f"{{{name}}}")
96 return path
97
98 def parameters_from_url_patterns(self, url_patterns) -> list[dict[str, Any]]:
99 """Need to process any parent/included url patterns too"""
100 parameters = []
101
102 for url_pattern in url_patterns:
103 for name, converter in url_pattern.pattern.converters.items():
104 parameters.append(
105 {
106 "name": name,
107 "in": "path",
108 "required": True,
109 "schema": {
110 "type": "string",
111 "pattern": converter.regex,
112 # "format": "uuid",
113 },
114 }
115 )
116
117 return parameters
118
119 def operations_from_url_pattern(self, url_pattern) -> dict[str, Any]:
120 view_class = url_pattern.callback.view_class
121
122 if not issubclass(view_class, APIBaseView):
123 return {}
124
125 operations = {}
126
127 for method in view_class.allowed_http_methods:
128 if responses := self.responses_from_class_method(view_class, method):
129 operations[method] = {
130 "responses": responses,
131 }
132
133 if parameters := self.request_body_from_class_method(view_class, method):
134 operations[method]["requestBody"] = parameters
135
136 return operations
137
138 def request_body_from_class_method(self, view_class, method: str) -> dict:
139 """Gets parameters from the form_class on a view"""
140
141 if method not in ("post", "put", "patch"):
142 return {}
143
144 form_class = view_class.form_class
145 if not form_class:
146 return {}
147
148 parameters = []
149 # Any args or kwargs in form.__init__ need to be optional
150 # for this to work...
151 for name, field in form_class().fields.items():
152 parameters.append(
153 {
154 "name": name,
155 # "in": "query",
156 # "required": field.required,
157 "schema": self.type_to_schema_obj(field),
158 }
159 )
160
161 return {
162 "content": {
163 "application/json": {
164 "schema": {
165 "type": "object",
166 "properties": {p["name"]: p["schema"] for p in parameters},
167 }
168 },
169 },
170 }
171
172 def responses_from_class_method(
173 self, view_class, method: str
174 ) -> dict[str, dict[str, Any]]:
175 class_method = getattr(view_class, method)
176 return_type = class_method.__annotations__["return"]
177
178 if hasattr(return_type, "status_code"):
179 return_types = [return_type]
180 else:
181 # Assume union...
182 return_types = return_type.__args__
183
184 responses: dict[str, dict[str, Any]] = {}
185
186 for return_type in return_types:
187 if return_type is JsonResponse or return_type is JsonResponseCreated:
188 schema = self.type_to_schema_obj(
189 view_class.object_to_dict.__annotations__["return"]
190 )
191
192 content = {"application/json": {"schema": schema}}
193 elif return_type is JsonResponseList:
194 schema = self.type_to_schema_obj(
195 view_class.object_to_dict.__annotations__["return"]
196 )
197
198 content = {
199 "application/json": {
200 "schema": {
201 "type": "array",
202 "items": schema,
203 }
204 }
205 }
206 else:
207 content = None
208
209 response_key = str(return_type.status_code)
210 responses[response_key] = {}
211
212 if description := getattr(return_type, "openapi_description", ""):
213 responses[response_key]["description"] = description
214
215 responses["5XX"] = {
216 "description": "Server error",
217 }
218
219 if content:
220 responses[response_key]["content"] = content
221
222 return responses
223
224 def type_to_schema_obj(self, t) -> dict[str, Any]:
225 # if it's a union with None, add nullable: true
226
227 # if t has a comment, add description
228 # import inspect
229 # if description := inspect.getdoc(t):
230 # extra_fields = {"description": description}
231 # else:
232 extra_fields: dict[str, Any] = {}
233
234 if hasattr(t, "__annotations__") and t.__annotations__:
235 # It's a TypedDict...
236 return {
237 "type": "object",
238 "properties": {
239 k: self.type_to_schema_obj(v) for k, v in t.__annotations__.items()
240 },
241 **extra_fields,
242 }
243
244 if hasattr(t, "__origin__"):
245 if t.__origin__ is list:
246 return {
247 "type": "array",
248 "items": self.type_to_schema_obj(t.__args__[0]),
249 **extra_fields,
250 }
251 elif t.__origin__ is dict:
252 return {
253 "type": "object",
254 "properties": {
255 k: self.type_to_schema_obj(v)
256 for k, v in t.__args__[1].__annotations__.items()
257 },
258 **extra_fields,
259 }
260 else:
261 raise ValueError(f"Unknown type: {t}")
262
263 if hasattr(t, "__args__") and len(t.__args__) == 2 and type(None) in t.__args__:
264 return {
265 **self.type_to_schema_obj(t.__args__[0]),
266 "nullable": True,
267 **extra_fields,
268 }
269
270 type_mappings: dict[Any, dict] = {
271 str: {
272 "type": "string",
273 },
274 int: {
275 "type": "integer",
276 },
277 float: {
278 "type": "number",
279 },
280 bool: {
281 "type": "boolean",
282 },
283 dict: {
284 "type": "object",
285 },
286 list: {
287 "type": "array",
288 },
289 UUID: {
290 "type": "string",
291 "format": "uuid",
292 },
293 fields.IntegerField: {
294 "type": "integer",
295 },
296 fields.FloatField: {
297 "type": "number",
298 },
299 fields.DateTimeField: {
300 "type": "string",
301 "format": "date-time",
302 },
303 fields.DateField: {
304 "type": "string",
305 "format": "date",
306 },
307 fields.TimeField: {
308 "type": "string",
309 "format": "time",
310 },
311 fields.EmailField: {
312 "type": "string",
313 "format": "email",
314 },
315 fields.URLField: {
316 "type": "string",
317 "format": "uri",
318 },
319 fields.UUIDField: {
320 "type": "string",
321 "format": "uuid",
322 },
323 fields.DecimalField: {
324 "type": "number",
325 },
326 # fields.FileField: {
327 # "type": "string",
328 # "format": "binary",
329 # },
330 fields.ImageField: {
331 "type": "string",
332 "format": "binary",
333 },
334 fields.BooleanField: {
335 "type": "boolean",
336 },
337 fields.NullBooleanField: {
338 "type": "boolean",
339 "nullable": True,
340 },
341 fields.CharField: {
342 "type": "string",
343 },
344 fields.EmailField: {
345 "type": "string",
346 "format": "email",
347 },
348 }
349
350 schema = type_mappings.get(t, {})
351 if not schema:
352 schema = type_mappings.get(t.__class__, {})
353 if not schema:
354 raise ValueError(f"Unknown type: {t}")
355
356 return {**schema, **extra_fields}