1import inspect
2import json
3import re
4from typing import Any, get_type_hints
5
6from plain.urls import Router, URLPattern, URLResolver
7
8from .helpers import json_content
9from .utils import merge_data, schema_from_type, typed_dict_from_annotation
10
11# A leading `GET /path/` line in a docstring is dropped — the URL is already in `paths`.
12_LEADING_HTTP_METHOD = re.compile(
13 r"^(GET|POST|PUT|PATCH|DELETE|HEAD|OPTIONS)\s+\S+\s*$",
14 re.IGNORECASE,
15)
16
17
18def _merge_parameters(*sources: list[dict[str, Any]]) -> list[dict[str, Any]]:
19 """Merge parameter lists by `(name, in)` (or `$ref`). Later sources override earlier ones."""
20 by_key: dict[tuple[str, str], dict[str, Any]] = {}
21 for source in sources:
22 for p in source:
23 key = ("$ref", p["$ref"]) if "$ref" in p else (p["name"], p["in"])
24 by_key[key] = p
25 return list(by_key.values())
26
27
28def _build_operation_id(view_class: type, method: str) -> str:
29 return f"{view_class.__name__}_{method}"
30
31
32def _schema_for_converter(converter: Any) -> dict[str, Any]:
33 if converter.keyword == "int":
34 return {"type": "integer"}
35 if converter.keyword == "uuid":
36 return {"type": "string", "format": "uuid"}
37 return {"type": "string", "pattern": converter.regex}
38
39
40def _security_schemes_for_view(view_class: type) -> dict[str, dict[str, Any]]:
41 """Collect `openapi_security_schemes` declared on a view's MRO."""
42 schemes: dict[str, dict[str, Any]] = {}
43 for cls in reversed(view_class.__mro__):
44 declared = cls.__dict__.get("openapi_security_schemes")
45 if declared:
46 schemes.update(declared)
47 return schemes
48
49
def _parameters_for_view(view_class: type) -> list[dict[str, Any]]:
    """Collect `openapi_parameters` declared on each class of the view's MRO.

    Lists are merged base-first, so a subclass's parameter replaces a base's
    parameter with the same `(name, in)` (or `$ref`).
    """
    per_class = [
        vars(klass).get("openapi_parameters") or []
        for klass in view_class.__mro__[::-1]
    ]
    return _merge_parameters(*per_class)
58
59
def _docstring_summary_description(
    class_method: Any,
    view_class: type,
) -> dict[str, str]:
    """Derive OpenAPI `summary`/`description` from a handler's docstring.

    Uses the handler's docstring, falling back to the view class's. Note that
    `inspect.getdoc` may return a docstring inherited from a base class. A
    leading `METHOD /path` line is dropped. Following PEP 257, the first
    paragraph (whitespace-collapsed) becomes `summary` and the remainder
    becomes `description`. Empty values are omitted from the result.
    """
    text = inspect.getdoc(class_method) or inspect.getdoc(view_class) or ""
    text = text.strip()
    if not text:
        return {}

    head, _, tail = text.partition("\n")
    if _LEADING_HTTP_METHOD.match(head.strip()):
        text = tail.lstrip()

    first_paragraph, _, remainder = text.partition("\n\n")

    result: dict[str, str] = {}
    summary = " ".join(first_paragraph.split())
    if summary:
        result["summary"] = summary
    description = remainder.strip()
    if description:
        result["description"] = description
    return result
83
84
class OpenAPISchemaGenerator:
    """Build an OpenAPI document from a `Router` and the views it routes to.

    The router may supply a base document via `openapi_schema` and shared
    components via `openapi_components`. Each routed view class then
    contributes path operations plus any components it declares. The result
    is available as `self.schema`, or serialized with `as_json` / `as_yaml`.
    """

    def __init__(self, router: Router):
        import copy

        # Deep copies so nothing below reaches back into the router's own
        # dicts. `self.components` is handed to `schema_from_type` (which
        # presumably registers schemas into it in place), and the finished
        # document goes to callers. A shallow copy would share nested dicts.
        self.schema = copy.deepcopy(getattr(router, "openapi_schema", {}))
        self.components = copy.deepcopy(getattr(router, "openapi_components", {}))

        self.schema["paths"] = self.get_paths(router.urls)

        # `get_paths` may have grown `self.components`, so attach it last.
        if self.components:
            self.schema["components"] = self.components

    def as_json(self, indent: int) -> str:
        """Serialize the schema as JSON with sorted keys."""
        return json.dumps(self.schema, indent=indent, sort_keys=True)

    def as_yaml(self, indent: int) -> str:
        """Serialize the schema as YAML with sorted keys (imports `yaml` lazily)."""
        import yaml

        # Round-trip through JSON so shared objects become independent copies;
        # otherwise the YAML dumper emits anchors/aliases for them.
        cleaned = json.loads(self.as_json(indent=0))
        return yaml.safe_dump(cleaned, indent=indent, sort_keys=True)

    def get_paths(
        self,
        urls: list[URLPattern | URLResolver],
    ) -> dict[str, dict[str, Any]]:
        """Map OpenAPI path strings to their operations, recursing into resolvers.

        Patterns that yield no operations are left out.

        Raises:
            ValueError: if an entry is neither a `URLPattern` nor a `URLResolver`.
        """
        paths: dict[str, dict[str, Any]] = {}

        for url_pattern in urls:
            if isinstance(url_pattern, URLResolver):
                # NOTE(review): nested patterns are rendered from "/" without the
                # resolver's own prefix. Confirm their routes are already absolute.
                paths.update(self.get_paths(url_pattern.url_patterns))
            elif isinstance(url_pattern, URLPattern):
                if operations := self.operations_for_url_pattern(url_pattern):
                    path = self.path_from_url_pattern(url_pattern, "/")
                    # TODO could have class level summary/description?
                    paths[path] = operations
            else:
                raise ValueError(f"Unknown url pattern: {url_pattern}")

        return paths

    def path_from_url_pattern(self, url_pattern: URLPattern, root_path: str) -> str:
        """Render a pattern's route as an OpenAPI path, e.g. `<int:pk>` -> `{pk}`."""
        path = root_path + url_pattern.raw_route
        # An empty route (the root) must not gain a second slash.
        if url_pattern.trailing_slash and url_pattern.raw_route:
            path += "/"

        for name, converter in url_pattern.converters.items():
            # Handle both `<type:name>` and the `<name>` shorthand for the default `str` converter.
            path = path.replace(f"<{converter.keyword}:{name}>", f"{{{name}}}")
            path = path.replace(f"<{name}>", f"{{{name}}}")
        return path

    def extract_components(self, obj: Any) -> None:
        """Merge `obj.openapi_components` (from a view, handler, or router) into `self.components`."""
        if hasattr(obj, "openapi_components"):
            self.components = merge_data(
                self.components,
                getattr(obj, "openapi_components", {}),
            )

    def include_view(self, view_class: type) -> bool:
        """Override to drop a view from the schema."""
        return True

    @staticmethod
    def _has_response_in_range(operation: dict[str, Any], *prefixes: str) -> bool:
        """Return True if any response code in `operation` starts with one of `prefixes`.

        Codes are compared via `str()`, so both string keys (`"200"`, `"2XX"`)
        and integer keys are handled.
        """
        return any(
            str(code).startswith(prefixes) for code in operation.get("responses", {})
        )

    def _response_from_return_annotation(
        self, class_method: Any
    ) -> dict[str, Any] | None:
        """Build a 200 response fragment from the method's return annotation if it points at a TypedDict.

        Returns None when there are no usable hints or the annotation is not a
        TypedDict. The TypedDict's schema is registered via `schema_from_type`
        with `self.components`.
        """
        try:
            hints = get_type_hints(class_method)
        except Exception:
            # Best effort: e.g. unresolved forward references (NameError) make
            # `get_type_hints` raise; skip inference rather than failing.
            return None

        return_type = hints.get("return")
        if return_type is None:
            return None

        typed_dict = typed_dict_from_annotation(return_type)
        if typed_dict is None:
            return None

        top_ref = schema_from_type(typed_dict, components=self.components)
        return {
            "responses": {
                "200": {
                    "description": "OK",
                    "content": json_content(top_ref),
                }
            }
        }

    def operations_for_url_pattern(
        self,
        url_pattern: URLPattern,
    ) -> dict[str, Any]:
        """Return OpenAPI operation objects for one URL pattern, keyed by handler method.

        The view's MRO is walked from the most generic class to the leaf, so
        schema data on subclasses is merged over that of their bases. An
        operation is only kept when it declares a 2XX or 3XX response.
        """
        operations: dict[str, Any] = {}

        view_class = url_pattern.view_class
        if not self.include_view(view_class):
            return operations

        # `View` defines runtime stubs for every handler, so gating on
        # `implemented_methods` is what tells us which verbs the leaf class
        # actually handles (vs. inheriting a stub that will 405).
        # Sorted so the output order doesn't depend on set iteration order.
        implemented = sorted(getattr(view_class, "implemented_methods", frozenset()))

        inherited_params = _parameters_for_view(view_class)
        auto_params = self.parameters_from_url_patterns([url_pattern])
        schemes = _security_schemes_for_view(view_class)

        for vc in reversed(view_class.__mro__):
            self.extract_components(vc)
            for method in implemented:
                class_method = vc.__dict__.get(method)
                if not class_method:
                    continue

                self.extract_components(class_method)
                operation = merge_data(
                    getattr(vc, "openapi_schema", {}),
                    getattr(class_method, "openapi_schema", {}),
                )

                # Infer a success response from the return annotation only
                # when none was declared explicitly.
                if not self._has_response_in_range(operation, "2"):
                    inferred = self._response_from_return_annotation(class_method)
                    if inferred is not None:
                        operation = merge_data(operation, inferred)

                # Explicitly declared values win over docstring-derived ones.
                for key, value in _docstring_summary_description(class_method, vc).items():
                    operation.setdefault(key, value)

                if not operation:
                    continue

                # Later lists override earlier ones: path converters, then
                # view-level parameters, then the operation's own.
                merged_params = _merge_parameters(
                    auto_params,
                    inherited_params,
                    list(operation.get("parameters", [])),
                )
                if merged_params:
                    operation["parameters"] = merged_params

                operation.setdefault(
                    "operationId",
                    _build_operation_id(view_class, method),
                )

                # Each scheme is its own requirement entry; OpenAPI treats the
                # entries of a `security` list as alternatives.
                if "security" not in operation and schemes:
                    operation["security"] = [{name: []} for name in schemes]

                # If there are no responses in the 2XX or 3XX range, then don't return it at all.
                # Most likely the developer didn't define any actual responses for their endpoint,
                # and all we did was inherit the base error responses.
                if not self._has_response_in_range(operation, "2", "3"):
                    continue

                if method in operations:
                    # Merge operation with existing data
                    operations[method] = merge_data(operations[method], operation)
                else:
                    operations[method] = operation

        # Register the view's schemes whenever it contributes operations. This
        # includes operations that set their own `security` list, because those
        # requirements name schemes that must exist in `components`.
        if operations and schemes:
            self.components = merge_data(
                self.components,
                {"securitySchemes": schemes},
            )

        return operations

    def parameters_from_url_patterns(
        self, url_patterns: list[URLPattern]
    ) -> list[dict[str, Any]]:
        """Build required OpenAPI path parameters from the patterns' converters.

        Takes a list so parent/included url patterns can be processed too.
        """
        return [
            {
                "name": name,
                "in": "path",
                "required": True,
                "schema": _schema_for_converter(converter),
            }
            for url_pattern in url_patterns
            for name, converter in url_pattern.converters.items()
        ]