import re
import types
import typing
from http import HTTPStatus
from typing import Any
from uuid import UUID

from plain.forms import fields
from plain.urls import get_resolver
from plain.urls.converters import get_converters
from plain.views import View

from .responses import JsonResponse, JsonResponseCreated, JsonResponseList
from .views import APIBaseView
11
12
class OpenAPISchemaView(View):
    """Serve the OpenAPI document for an API URLconf as JSON.

    The three ``openapi_*`` attributes are declared but not given values
    here, so they must be provided (e.g. on a subclass).
    """

    openapi_title: str  # passed to OpenAPISchema as ``title``
    openapi_version: str  # passed to OpenAPISchema as ``version``
    openapi_urlconf: str  # URLconf module that OpenAPISchema walks

    def get(self):
        # TODO: the schema is rebuilt on every request; it could be heavily
        # cached (HTTP caching headers, or memoizing the schema object).
        document = OpenAPISchema(
            title=self.openapi_title,
            version=self.openapi_version,
            urlconf=self.openapi_urlconf,
        )
        # Sorted keys keep the JSON output stable between requests.
        dumps_options = {"sort_keys": True}
        return JsonResponse(document, json_dumps_params=dumps_options)
28
29
class OpenAPISchema(dict):
    """An OpenAPI 3.0 document generated from a URLconf, stored as a ``dict``.

    Construction walks every URL pattern reachable from ``urlconf`` and adds
    an entry under ``paths`` for each pattern whose view is an
    ``APIBaseView`` subclass.
    """

    def __init__(self, *, title: str, version: str, urlconf="app.api.urls"):
        self.urlconf = urlconf
        # Reverse of get_converters(): converter class -> the key it is
        # registered under. Kept as a public attribute; path templating
        # matches placeholders by parameter name instead.
        self.url_converters = {
            class_instance.__class__: key
            for key, class_instance in get_converters().items()
        }
        super().__init__(
            openapi="3.0.0",
            info={
                "title": title,
                "version": version,
            },
            paths=self.get_paths(),
        )
        # TODO: pull repeated definitions (e.g. shared 404/422 responses)
        # into "components" and reference them with "$ref".

    def get_paths(self) -> dict[str, dict[str, Any]]:
        """Return the OpenAPI ``paths`` object for ``self.urlconf``.

        Each key is a templated path (placeholders rewritten to ``{name}``).
        Each value is that path's operations. Patterns with no documentable
        operation are skipped.
        """
        resolver = get_resolver(self.urlconf)
        paths: dict[str, dict[str, Any]] = {}

        for url_pattern in resolver.url_patterns:
            for pattern, root_path, lineage in self._iter_url_patterns(
                url_pattern, "/", ()
            ):
                operations = self.operations_from_url_pattern(pattern)
                if not operations:
                    continue

                paths[self.path_from_url_pattern(pattern, root_path)] = operations

                # Path parameters come from the leaf AND every enclosing
                # include. They are path params, so every method shares them.
                if parameters := self.parameters_from_url_patterns(lineage):
                    for operation in operations.values():
                        operation["parameters"] = parameters

        return paths

    def _iter_url_patterns(self, url_pattern, root_path, parents=()):
        """Yield ``(leaf, root_path, lineage)`` for each leaf under url_pattern.

        Patterns with a ``url_patterns`` attribute are treated as includes and
        recursed into. ``lineage`` is the tuple of enclosing include patterns
        followed by the leaf itself.
        """
        lineage = (*parents, url_pattern)
        if hasattr(url_pattern, "url_patterns"):
            include_path = self.path_from_url_pattern(url_pattern, root_path)
            for child in url_pattern.url_patterns:
                yield from self._iter_url_patterns(child, include_path, lineage)
        else:
            yield url_pattern, root_path, lineage

    def url_patterns_from_url_pattern(self, url_pattern, root_path) -> list:
        """Flatten includes into a list of ``(leaf_pattern, root_path)`` pairs."""
        return [
            (leaf, leaf_root)
            for leaf, leaf_root, _ in self._iter_url_patterns(url_pattern, root_path)
        ]

    def path_from_url_pattern(self, url_pattern, root_path) -> str:
        """Return root_path + the pattern's route, with placeholders templated.

        Both ``<converter:name>`` and the converter-less ``<name>`` form are
        rewritten to the OpenAPI template ``{name}``.
        """
        path = root_path + str(url_pattern.pattern)

        for name in url_pattern.pattern.converters:
            placeholder = re.compile(rf"<(?:[^<>:]+:)?{re.escape(name)}>")
            # A callable replacement avoids any escape processing of the name.
            path = placeholder.sub(lambda _match, name=name: f"{{{name}}}", path)
        return path

    def parameters_from_url_patterns(self, url_patterns) -> list[dict[str, Any]]:
        """Build OpenAPI path parameters from every pattern in url_patterns.

        Pass the whole chain (enclosing includes plus the leaf) so that
        parameters captured by parent routes are documented too. Each name is
        emitted once, because OpenAPI forbids duplicate parameters (same name
        and location).
        """
        parameters: list[dict[str, Any]] = []
        seen: set[str] = set()

        for url_pattern in url_patterns:
            for name, converter in url_pattern.pattern.converters.items():
                if name in seen:
                    continue
                seen.add(name)
                parameters.append(
                    {
                        "name": name,
                        "in": "path",
                        "required": True,
                        "schema": {
                            "type": "string",
                            "pattern": converter.regex,
                        },
                    }
                )

        return parameters

    def operations_from_url_pattern(self, url_pattern) -> dict[str, Any]:
        """Return ``{http_method: operation}`` for an APIBaseView pattern.

        Returns {} when the pattern's callback has no ``view_class`` or the
        view is not an APIBaseView subclass.
        """
        view_class = getattr(url_pattern.callback, "view_class", None)
        if not (isinstance(view_class, type) and issubclass(view_class, APIBaseView)):
            return {}

        # "head", "options" and "trace" are deliberately not documented.
        known_http_method_names = ("get", "post", "put", "patch", "delete")

        operations: dict[str, Any] = {}
        for method in known_http_method_names:
            if not hasattr(view_class, method):
                continue

            responses = self.responses_from_class_method(view_class, method)
            if not responses:
                continue

            operation: dict[str, Any] = {"responses": responses}
            if request_body := self.request_body_from_class_method(view_class, method):
                operation["requestBody"] = request_body
            operations[method] = operation

        return operations

    def request_body_from_class_method(self, view_class, method: str) -> dict:
        """Describe a JSON request body built from the view's ``form_class``.

        Only applies to post/put/patch. Returns {} when the view has no
        (truthy) ``form_class``.
        """
        if method not in ("post", "put", "patch"):
            return {}

        form_class = getattr(view_class, "form_class", None)
        if not form_class:
            return {}

        # The form is instantiated with no arguments, so any args/kwargs its
        # __init__ takes must be optional for this to work.
        properties = {
            name: self.type_to_schema_obj(field)
            for name, field in form_class().fields.items()
        }

        return {
            "content": {
                "application/json": {
                    "schema": {
                        "type": "object",
                        "properties": properties,
                    }
                },
            },
        }

    def responses_from_class_method(
        self, view_class, method: str
    ) -> dict[str, dict[str, Any]]:
        """Build the OpenAPI ``responses`` object for ``view_class.<method>``.

        The handler's return annotation must be a response class with a
        ``status_code`` attribute, or a Union of such classes. A handler with
        no return annotation yields {}, so it is left out of the operations.
        """
        class_method = getattr(view_class, method)
        return_annotation = getattr(class_method, "__annotations__", {}).get("return")
        if return_annotation is None:
            return {}

        if hasattr(return_annotation, "status_code"):
            return_types = [return_annotation]
        else:
            # Assume a Union of response classes.
            return_types = return_annotation.__args__

        responses: dict[str, dict[str, Any]] = {}
        for return_type in return_types:
            status_code = return_type.status_code
            response: dict[str, Any] = {
                # "description" is REQUIRED on every OpenAPI Response Object.
                "description": getattr(return_type, "openapi_description", "")
                or self._status_phrase(status_code),
            }
            if content := self._response_content(view_class, return_type):
                response["content"] = content
            responses[str(status_code)] = response

        responses["5XX"] = {
            "description": "Server error",
        }
        return responses

    def _response_content(self, view_class, return_type) -> dict[str, Any]:
        """Return the ``content`` map for a response class ({} if no body is known)."""
        if return_type is JsonResponse or return_type is JsonResponseCreated:
            schema = self.type_to_schema_obj(
                view_class.object_to_dict.__annotations__["return"]
            )
        elif return_type is JsonResponseList:
            schema = {
                "type": "array",
                "items": self.type_to_schema_obj(
                    view_class.object_to_dict.__annotations__["return"]
                ),
            }
        else:
            return {}
        return {"application/json": {"schema": schema}}

    @staticmethod
    def _status_phrase(status_code) -> str:
        """Return the standard HTTP reason phrase for status_code, or ""."""
        try:
            return HTTPStatus(int(status_code)).phrase
        except (TypeError, ValueError):
            return ""

    def type_to_schema_obj(self, t) -> dict[str, Any]:
        """Convert a type annotation or form field into an OpenAPI schema object.

        Handles:
        - classes with annotations (e.g. TypedDict) -> object with properties
        - ``Optional[X]`` / ``X | None`` (None in any position) -> X + nullable
        - ``list[X]`` -> array of X; ``dict[K, V]`` -> object with V values
        - the scalar types and form field classes/instances in the mappings

        Raises:
            ValueError: if t (or a type nested inside it) is not supported.
        """
        # TODO: add a "description" taken from the type's docstring.

        # Annotated classes such as TypedDicts. Restricted to classes so that
        # form field *instances* are never mistaken for one.
        if isinstance(t, type) and getattr(t, "__annotations__", None):
            return {
                "type": "object",
                "properties": {
                    key: self.type_to_schema_obj(value)
                    for key, value in typing.get_type_hints(t).items()
                },
            }

        origin = typing.get_origin(t)

        # typing.Union/Optional and PEP 604 "X | Y" unions. This is checked
        # before list/dict, because typing.Optional also has an origin.
        if origin is typing.Union or origin is getattr(types, "UnionType", typing.Union):
            args = typing.get_args(t)
            non_null = [arg for arg in args if arg is not type(None)]
            if len(args) == 2 and len(non_null) == 1:
                return {**self.type_to_schema_obj(non_null[0]), "nullable": True}
            raise ValueError(f"Unknown type: {t}")

        if origin is list:
            args = typing.get_args(t)
            array_schema: dict[str, Any] = {"type": "array"}
            if args:
                array_schema["items"] = self.type_to_schema_obj(args[0])
            return array_schema

        if origin is dict:
            args = typing.get_args(t)
            object_schema: dict[str, Any] = {"type": "object"}
            if len(args) == 2:
                # A mapping: arbitrary keys, every value shaped like V.
                object_schema["additionalProperties"] = self.type_to_schema_obj(args[1])
            return object_schema

        if origin is not None:
            raise ValueError(f"Unknown type: {t}")

        return self._schema_from_type_mapping(t)

    def _schema_from_type_mapping(self, t) -> dict[str, Any]:
        """Look up t, then t's class, in the scalar and form-field mappings.

        Raises ValueError when neither t nor its class is mapped.
        """
        python_schemas: dict[Any, dict[str, Any]] = {
            str: {"type": "string"},
            int: {"type": "integer"},
            float: {"type": "number"},
            bool: {"type": "boolean"},
            dict: {"type": "object"},
            list: {"type": "array"},
            UUID: {"type": "string", "format": "uuid"},
        }

        schema = python_schemas.get(t)
        if schema is None:
            field_schemas: dict[Any, dict[str, Any]] = {
                fields.IntegerField: {"type": "integer"},
                fields.FloatField: {"type": "number"},
                fields.DateTimeField: {"type": "string", "format": "date-time"},
                fields.DateField: {"type": "string", "format": "date"},
                fields.TimeField: {"type": "string", "format": "time"},
                fields.EmailField: {"type": "string", "format": "email"},
                fields.URLField: {"type": "string", "format": "uri"},
                fields.UUIDField: {"type": "string", "format": "uuid"},
                fields.DecimalField: {"type": "number"},
                # TODO: FileField is not mapped yet ("string"/"binary"?).
                fields.ImageField: {"type": "string", "format": "binary"},
                fields.BooleanField: {"type": "boolean"},
                fields.NullBooleanField: {"type": "boolean", "nullable": True},
                fields.CharField: {"type": "string"},
                fields.SlugField: {"type": "string"},
            }
            # Same precedence as a single merged table: exact key first,
            # then the class of a field/value instance.
            schema = (
                field_schemas.get(t)
                or python_schemas.get(t.__class__)
                or field_schemas.get(t.__class__)
            )

        if not schema:
            raise ValueError(f"Unknown type: {t}")

        return dict(schema)