1import json
2from http import HTTPMethod
3from typing import Any
4
5from plain.urls import Router, URLPattern, URLResolver
6from plain.urls.converters import get_converters
7
8from .utils import merge_data
9
10
class OpenAPISchemaGenerator:
    """Build an OpenAPI schema document from a ``Router``'s URL patterns.

    The router's own ``openapi_schema`` / ``openapi_components`` attributes
    (if present) seed the document. Each view class and each HTTP-method
    handler on it can contribute further ``openapi_schema`` and
    ``openapi_components`` data, which is merged in while walking the URLs.
    """

    def __init__(self, router: Router):
        # Map converter class -> the key it is registered under, so route
        # segments like "<key:name>" can be rewritten to "{name}".
        self.url_converters = {
            class_instance.__class__: key
            for key, class_instance in get_converters().items()
        }

        # Seed the schema and components from the router. These are shallow
        # copies, so the top-level keys assigned below don't leak back onto
        # the router's attributes.
        self.schema = getattr(router, "openapi_schema", {}).copy()
        self.components = getattr(router, "openapi_components", {}).copy()

        # Walking the paths also collects components from views and handlers
        # (via extract_components), so this must run before the check below.
        self.schema["paths"] = self.get_paths(router.urls)

        if self.components:
            self.schema["components"] = self.components

    def as_json(self, indent):
        """Return the schema serialized as JSON with sorted keys."""
        return json.dumps(self.schema, indent=indent, sort_keys=True)

    def as_yaml(self, indent):
        """Return the schema serialized as YAML with sorted keys.

        ``yaml`` is imported lazily, so it is only required for YAML output.
        """
        import yaml

        # Round-trip through JSON so every node is a distinct object;
        # otherwise the YAML dumper emits anchors/aliases for objects that
        # are referenced from more than one place in the schema.
        cleaned = json.loads(self.as_json(indent=0))
        return yaml.safe_dump(cleaned, indent=indent, sort_keys=True)

    def get_paths(
        self, urls, root_path="/", parent_patterns=()
    ) -> dict[str, dict[str, Any]]:
        """Return the OpenAPI ``paths`` mapping for ``urls``.

        Args:
            urls: Iterable of ``URLPattern`` / ``URLResolver`` objects.
            root_path: Already-templated prefix contributed by enclosing
                resolvers; ``"/"`` at the top level.
            parent_patterns: Enclosing ``URLResolver`` objects, outermost
                first. Their converters also become path parameters.

        Returns:
            A dict of ``path -> {http_method: operation}``. Patterns whose
            view yields no documented operations are omitted.

        Raises:
            ValueError: If ``urls`` contains an unrecognized pattern type.
        """
        paths = {}

        for url_pattern in urls:
            if isinstance(url_pattern, URLResolver):
                # Carry the resolver's prefix (templated) and the resolver
                # itself down to the nested patterns instead of dropping it.
                paths.update(
                    self.get_paths(
                        url_pattern.url_patterns,
                        self.path_from_url_pattern(url_pattern, root_path),
                        (*parent_patterns, url_pattern),
                    )
                )
            elif isinstance(url_pattern, URLPattern):
                operations = self.operations_for_url_pattern(
                    url_pattern, parent_patterns
                )
                if operations:
                    path = self.path_from_url_pattern(url_pattern, root_path)
                    # TODO could have class level summary/description?
                    paths[path] = operations
            else:
                raise ValueError(f"Unknown url pattern: {url_pattern}")

        return paths

    def path_from_url_pattern(self, url_pattern, root_path) -> str:
        """Return ``root_path`` + the pattern, with parameters as ``{name}``.

        Only ``url_pattern.pattern`` is read, so this is also used to build
        the templated prefix of a resolver.
        """
        path = root_path + str(url_pattern.pattern)

        for name, converter in url_pattern.pattern.converters.items():
            key = self.url_converters[converter.__class__]
            placeholder = f"{{{name}}}"
            # A route parameter may be written with an explicit converter
            # ("<key:name>") or without one ("<name>"); template both forms.
            path = path.replace(f"<{key}:{name}>", placeholder)
            path = path.replace(f"<{name}>", placeholder)
        return path

    def extract_components(self, obj):
        """
        Merge ``obj.openapi_components`` (if defined) into ``self.components``.

        ``obj`` is a view class or a handler method.
        """
        if hasattr(obj, "openapi_components"):
            self.components = merge_data(
                self.components,
                getattr(obj, "openapi_components", {}),
            )

    def operations_for_url_pattern(
        self, url_pattern, parent_patterns=()
    ) -> dict[str, Any]:
        """Return ``{http_method: operation}`` for the pattern's view class.

        The view's MRO is walked from least- to most-derived class. For each
        class and each handler method found on it, the class-level
        ``openapi_schema`` is merged first and then overridden by the
        handler's own ``openapi_schema``. Components seen along the way are
        collected into ``self.components``.

        Args:
            url_pattern: The ``URLPattern`` whose view is inspected.
            parent_patterns: Enclosing resolvers, whose converters are added
                to the default path parameters.
        """
        operations = {}

        # These HTTP methods are never documented as operations.
        exclude_http_methods = {
            HTTPMethod.TRACE,
            HTTPMethod.OPTIONS,
            HTTPMethod.CONNECT,
        }
        # Handler attribute names, e.g. "get", "post"; computed once, not per
        # class in the MRO.
        method_names = [
            x.lower() for x in HTTPMethod if x not in exclude_http_methods
        ]

        for vc in reversed(url_pattern.view.view_class.__mro__):
            for method in method_names:
                class_method = getattr(vc, method, None)
                if not class_method:
                    continue

                # Get anything from the view class itself,
                # then override it with the method-specific data.
                self.extract_components(vc)
                operation = merge_data({}, getattr(vc, "openapi_schema", {}))

                # Get the schema that applies to the specific method.
                self.extract_components(class_method)
                operation = merge_data(
                    operation,
                    getattr(class_method, "openapi_schema", {}),
                )

                # Fill in URL parameters (including those of enclosing
                # resolvers) only if nothing else was defined.
                if operation and "parameters" not in operation:
                    if parameters := self.parameters_from_url_patterns(
                        [*parent_patterns, url_pattern]
                    ):
                        operation["parameters"] = parameters

                # If there are no responses in the 2XX or 3XX range, don't
                # return the operation at all. Most likely the developer
                # didn't define any actual responses for their endpoint, and
                # all we did was inherit the base error responses. str()
                # tolerates integer status-code keys.
                keep_operation = any(
                    str(status_code).startswith(("2", "3"))
                    for status_code in operation.get("responses", {})
                )
                if not keep_operation:
                    continue

                if method in operations:
                    # Merge with data from a less-derived class in the MRO.
                    operations[method] = merge_data(operations[method], operation)
                else:
                    operations[method] = operation

        return operations

    def parameters_from_url_patterns(self, url_patterns) -> list[dict[str, Any]]:
        """Return OpenAPI path-parameter objects for every converter.

        ``url_patterns`` should include any parent/included resolvers as well
        as the final pattern, so all captured path segments are documented.
        Each parameter is a required string constrained by the converter's
        ``regex``.
        """
        return [
            {
                "name": name,
                "in": "path",
                "required": True,
                "schema": {
                    "type": "string",
                    "pattern": converter.regex,
                },
            }
            for url_pattern in url_patterns
            for name, converter in url_pattern.pattern.converters.items()
        ]