v0.148.0
  1import inspect
  2import json
  3import re
  4from typing import Any, get_type_hints
  5
  6from plain.urls import Router, URLPattern, URLResolver
  7
  8from .helpers import json_content
  9from .utils import merge_data, schema_from_type, typed_dict_from_annotation
 10
 11# A leading `GET /path/` line in a docstring is dropped — the URL is already in `paths`.
 12_LEADING_HTTP_METHOD = re.compile(
 13    r"^(GET|POST|PUT|PATCH|DELETE|HEAD|OPTIONS)\s+\S+\s*$",
 14    re.IGNORECASE,
 15)
 16
 17
 18def _merge_parameters(*sources: list[dict[str, Any]]) -> list[dict[str, Any]]:
 19    """Merge parameter lists by `(name, in)` (or `$ref`). Later sources override earlier ones."""
 20    by_key: dict[tuple[str, str], dict[str, Any]] = {}
 21    for source in sources:
 22        for p in source:
 23            key = ("$ref", p["$ref"]) if "$ref" in p else (p["name"], p["in"])
 24            by_key[key] = p
 25    return list(by_key.values())
 26
 27
 28def _build_operation_id(view_class: type, method: str) -> str:
 29    return f"{view_class.__name__}_{method}"
 30
 31
 32def _schema_for_converter(converter: Any) -> dict[str, Any]:
 33    if converter.keyword == "int":
 34        return {"type": "integer"}
 35    if converter.keyword == "uuid":
 36        return {"type": "string", "format": "uuid"}
 37    return {"type": "string", "pattern": converter.regex}
 38
 39
 40def _security_schemes_for_view(view_class: type) -> dict[str, dict[str, Any]]:
 41    """Collect `openapi_security_schemes` declared on a view's MRO."""
 42    schemes: dict[str, dict[str, Any]] = {}
 43    for cls in reversed(view_class.__mro__):
 44        declared = cls.__dict__.get("openapi_security_schemes")
 45        if declared:
 46            schemes.update(declared)
 47    return schemes
 48
 49
 50def _parameters_for_view(view_class: type) -> list[dict[str, Any]]:
 51    """Walk MRO collecting `openapi_parameters`; descendants override by `(name, in)`."""
 52    return _merge_parameters(
 53        *(
 54            cls.__dict__.get("openapi_parameters") or []
 55            for cls in reversed(view_class.__mro__)
 56        )
 57    )
 58
 59
 60def _docstring_summary_description(
 61    class_method: Any,
 62    view_class: type,
 63) -> dict[str, str]:
 64    """PEP 257 split: first paragraph → `summary`, rest → `description`. Falls back to the view class."""
 65    doc = (inspect.getdoc(class_method) or inspect.getdoc(view_class) or "").strip()
 66    if not doc:
 67        return {}
 68
 69    first, _, rest = doc.partition("\n")
 70    if _LEADING_HTTP_METHOD.match(first.strip()):
 71        doc = rest.lstrip()
 72
 73    summary, _, description = doc.partition("\n\n")
 74    summary = " ".join(summary.split())
 75    description = description.strip()
 76
 77    out: dict[str, str] = {}
 78    if summary:
 79        out["summary"] = summary
 80    if description:
 81        out["description"] = description
 82    return out
 83
 84
 85class OpenAPISchemaGenerator:
 86    def __init__(self, router: Router):
 87        # Get initial schema from the router
 88        self.schema = getattr(router, "openapi_schema", {}).copy()
 89        self.components = getattr(router, "openapi_components", {}).copy()
 90
 91        self.schema["paths"] = self.get_paths(router.urls)
 92
 93        if self.components:
 94            self.schema["components"] = self.components
 95
 96    def as_json(self, indent: int) -> str:
 97        return json.dumps(self.schema, indent=indent, sort_keys=True)
 98
 99    def as_yaml(self, indent: int) -> str:
100        import yaml
101
102        # Don't want to get anchors when we dump...
103        cleaned = json.loads(self.as_json(indent=0))
104        return yaml.safe_dump(cleaned, indent=indent, sort_keys=True)
105
106    def get_paths(
107        self,
108        urls: list[URLPattern | URLResolver],
109    ) -> dict[str, dict[str, Any]]:
110        paths = {}
111
112        for url_pattern in urls:
113            if isinstance(url_pattern, URLResolver):
114                paths.update(self.get_paths(url_pattern.url_patterns))
115            elif isinstance(url_pattern, URLPattern):
116                if operations := self.operations_for_url_pattern(url_pattern):
117                    path = self.path_from_url_pattern(url_pattern, "/")
118                    # TODO could have class level summary/description?
119                    paths[path] = operations
120            else:
121                raise ValueError(f"Unknown url pattern: {url_pattern}")
122
123        return paths
124
125    def path_from_url_pattern(self, url_pattern: URLPattern, root_path: str) -> str:
126        path = root_path + url_pattern.raw_route
127        if url_pattern.trailing_slash and url_pattern.raw_route:
128            path += "/"
129
130        for name, converter in url_pattern.converters.items():
131            # Handle both `<type:name>` and the `<name>` shorthand for the default `str` converter.
132            path = path.replace(f"<{converter.keyword}:{name}>", f"{{{name}}}")
133            path = path.replace(f"<{name}>", f"{{{name}}}")
134        return path
135
136    def extract_components(self, obj: Any) -> None:
137        """
138        Extract components from a view or router.
139        """
140        if hasattr(obj, "openapi_components"):
141            self.components = merge_data(
142                self.components,
143                getattr(obj, "openapi_components", {}),
144            )
145
146    def include_view(self, view_class: type) -> bool:
147        """Override to drop a view from the schema."""
148        return True
149
150    def _response_from_return_annotation(
151        self, class_method: Any
152    ) -> dict[str, Any] | None:
153        """Build a 200 response fragment from the method's return annotation if it points at a TypedDict."""
154        try:
155            hints = get_type_hints(class_method)
156        except Exception:
157            return None
158
159        return_type = hints.get("return")
160        if return_type is None:
161            return None
162
163        typed_dict = typed_dict_from_annotation(return_type)
164        if typed_dict is None:
165            return None
166
167        top_ref = schema_from_type(typed_dict, components=self.components)
168        return {
169            "responses": {
170                "200": {
171                    "description": "OK",
172                    "content": json_content(top_ref),
173                }
174            }
175        }
176
177    def operations_for_url_pattern(
178        self,
179        url_pattern: URLPattern,
180    ) -> dict[str, Any]:
181        operations: dict[str, Any] = {}
182
183        if not self.include_view(url_pattern.view_class):
184            return operations
185
186        # `View` defines runtime stubs for every handler, so gating on
187        # `implemented_methods` is what tells us which verbs the leaf class
188        # actually handles (vs. inheriting a stub that will 405).
189        implemented = getattr(
190            url_pattern.view_class, "implemented_methods", frozenset()
191        )
192
193        inherited_params = _parameters_for_view(url_pattern.view_class)
194        auto_params = self.parameters_from_url_patterns([url_pattern])
195        schemes = _security_schemes_for_view(url_pattern.view_class)
196
197        for vc in reversed(url_pattern.view_class.__mro__):
198            self.extract_components(vc)
199            for method in implemented:
200                class_method = vc.__dict__.get(method)
201                if not class_method:
202                    continue
203
204                self.extract_components(class_method)
205                operation = merge_data(
206                    getattr(vc, "openapi_schema", {}),
207                    getattr(class_method, "openapi_schema", {}),
208                )
209
210                already_has_2xx = any(
211                    code.startswith("2") for code in operation.get("responses", {})
212                )
213                if not already_has_2xx:
214                    inferred = self._response_from_return_annotation(class_method)
215                    if inferred is not None:
216                        operation = merge_data(operation, inferred)
217
218                for key, value in _docstring_summary_description(
219                    class_method, vc
220                ).items():
221                    operation.setdefault(key, value)
222
223                if not operation:
224                    continue
225
226                merged_params = _merge_parameters(
227                    auto_params,
228                    inherited_params,
229                    list(operation.get("parameters", [])),
230                )
231                if merged_params:
232                    operation["parameters"] = merged_params
233
234                operation.setdefault(
235                    "operationId",
236                    _build_operation_id(url_pattern.view_class, method),
237                )
238
239                if "security" not in operation and schemes:
240                    operation["security"] = [{name: []} for name in schemes]
241                    self.components = merge_data(
242                        self.components,
243                        {"securitySchemes": schemes},
244                    )
245
246                # If there are no responses in the 2XX or 3XX range, then don't return it at all.
247                # Most likely the developer didn't define any actual responses for their endpoint,
248                # and all we did was inherit the base error responses.
249                keep_operation = False
250                for status_code in operation.get("responses", {}).keys():
251                    if status_code.startswith("2") or status_code.startswith("3"):
252                        keep_operation = True
253                        break
254
255                if operation and keep_operation:
256                    if method in operations:
257                        # Merge operation with existing data
258                        operations[method] = merge_data(operations[method], operation)
259                    else:
260                        operations[method] = operation
261
262        return operations
263
264    def parameters_from_url_patterns(
265        self, url_patterns: list[URLPattern]
266    ) -> list[dict[str, Any]]:
267        """Need to process any parent/included url patterns too"""
268        parameters = []
269
270        for url_pattern in url_patterns:
271            for name, converter in url_pattern.converters.items():
272                parameters.append(
273                    {
274                        "name": name,
275                        "in": "path",
276                        "required": True,
277                        "schema": _schema_for_converter(converter),
278                    }
279                )
280
281        return parameters