1import json
2from http import HTTPMethod
3from typing import Any
4
5from plain.urls import Router, URLPattern, URLResolver
6from plain.urls.converters import get_converters
7
8from .utils import merge_data
9
10
class OpenAPISchemaGenerator:
    """Build an OpenAPI document from a ``Router``.

    The router's optional ``openapi_schema`` / ``openapi_components`` attributes
    seed the document. ``paths`` is generated by walking ``router.urls`` and
    merging the ``openapi_schema`` data found on each view class (and its bases)
    and on its HTTP-method handlers.
    """

    def __init__(self, router: Router):
        # Reverse map: converter class -> the key it is registered under.
        # (Path templating below parses the route syntax directly instead.)
        self.url_converters = {
            class_instance.__class__: key
            for key, class_instance in get_converters().items()
        }

        # Shallow copies so the top-level keys added below ("paths",
        # "components") are not written back onto the router's own dicts.
        self.schema = getattr(router, "openapi_schema", {}).copy()
        self.components = getattr(router, "openapi_components", {}).copy()

        # Walking the paths also collects components from views via
        # extract_components(), so this must run before components are attached.
        self.schema["paths"] = self.get_paths(router.urls)

        if self.components:
            self.schema["components"] = self.components

    def as_json(self, indent: int) -> str:
        """Serialize the schema as JSON with sorted keys."""
        return json.dumps(self.schema, indent=indent, sort_keys=True)

    def as_yaml(self, indent: int) -> str:
        """Serialize the schema as YAML with sorted keys (requires PyYAML)."""
        import yaml

        # Round-trip through JSON so every node is a distinct object; otherwise
        # objects referenced in several places would be dumped as YAML anchors.
        cleaned = json.loads(self.as_json(indent=0))
        return yaml.safe_dump(cleaned, indent=indent, sort_keys=True)

    def get_paths(
        self,
        urls: list[URLPattern | URLResolver],
        root_path: str = "/",
        parents: tuple[URLResolver, ...] = (),
    ) -> dict[str, dict[str, Any]]:
        """Return the OpenAPI ``paths`` mapping for ``urls``.

        Args:
            urls: Endpoint patterns and included resolvers to document.
            root_path: Already-templated path prefix the patterns are relative to.
            parents: Enclosing resolvers, outermost first, whose converters are
                documented as parameters on every nested operation.

        Raises:
            ValueError: If an entry is neither a URLPattern nor a URLResolver.
        """
        paths = {}

        for url_pattern in urls:
            if isinstance(url_pattern, URLResolver):
                # Child patterns are relative to the resolver's own route, so
                # carry its (templated) prefix and the resolver itself downward.
                prefix = self.path_from_url_pattern(url_pattern, root_path)
                paths.update(
                    self.get_paths(
                        url_pattern.url_patterns,
                        prefix,
                        (*parents, url_pattern),
                    )
                )
            elif isinstance(url_pattern, URLPattern):
                operations = self.operations_for_url_pattern(url_pattern, parents)
                if operations:
                    path = self.path_from_url_pattern(url_pattern, root_path)
                    # TODO could have class level summary/description?
                    paths[path] = operations
            else:
                raise ValueError(f"Unknown url pattern: {url_pattern}")

        return paths

    def path_from_url_pattern(
        self, url_pattern: URLPattern | URLResolver, root_path: str
    ) -> str:
        """Return ``root_path`` joined with the pattern's route, templated for OpenAPI.

        Each converter segment becomes ``{name}``. Both the typed form
        (``<int:pk>``) and the bare form (``<pk>``) are rewritten.
        """
        import re

        path = root_path + str(url_pattern.pattern)

        for name in url_pattern.pattern.converters:
            placeholder = f"{{{name}}}"
            # Optional "converter:" prefix, then the exact parameter name.
            path = re.sub(
                rf"<(?:[^>:]+:)?{re.escape(name)}>",
                lambda _match, repl=placeholder: repl,
                path,
            )
        return path

    def extract_components(self, obj: Any) -> None:
        """Merge ``obj.openapi_components`` (a view class or handler) into ``self.components``."""
        if hasattr(obj, "openapi_components"):
            self.components = merge_data(
                self.components,
                getattr(obj, "openapi_components", {}),
            )

    def operations_for_url_pattern(
        self,
        url_pattern: URLPattern,
        parents: tuple[URLResolver, ...] = (),
    ) -> dict[str, Any]:
        """Return ``{http_method: operation}`` for the view behind ``url_pattern``.

        The view's MRO is walked base-first. For every HTTP handler a class
        exposes, the class-level ``openapi_schema`` is merged first and the
        handler-level one second. Operations without any 2XX/3XX response are
        dropped.

        Args:
            url_pattern: The endpoint pattern whose view is documented.
            parents: Enclosing resolvers, outermost first. Their converters are
                included in the default path parameters.
        """
        operations = {}

        # Loop-invariant: the handler attribute names to look for.
        exclude_http_methods = [
            HTTPMethod.TRACE,
            HTTPMethod.OPTIONS,
            HTTPMethod.CONNECT,
        ]
        method_names = [
            x.lower() for x in HTTPMethod if x not in exclude_http_methods
        ]

        for vc in reversed(url_pattern.view.view_class.__mro__):
            for method in method_names:
                class_method = getattr(vc, method, None)
                if not class_method:
                    continue

                # Get anything from the view class itself,
                # then override it with the method-specific data.
                self.extract_components(vc)
                operation = merge_data({}, getattr(vc, "openapi_schema", {}))

                self.extract_components(class_method)
                operation = merge_data(
                    operation,
                    getattr(class_method, "openapi_schema", {}),
                )

                # Document URL parameters only if nothing else was defined.
                if operation and "parameters" not in operation:
                    parameters = self.parameters_from_url_patterns(
                        [*parents, url_pattern]
                    )
                    if parameters:
                        operation["parameters"] = parameters

                # If there are no responses in the 2XX or 3XX range, most likely
                # the developer defined no real responses and all we did was
                # inherit base error responses, so skip the operation. str()
                # tolerates int status-code keys as well as "200"/"2XX".
                keep_operation = any(
                    str(status_code).startswith(("2", "3"))
                    for status_code in operation.get("responses", {})
                )

                if operation and keep_operation:
                    if method in operations:
                        operations[method] = merge_data(operations[method], operation)
                    else:
                        operations[method] = operation

        return operations

    def parameters_from_url_patterns(
        self, url_patterns: list[URLPattern | URLResolver]
    ) -> list[dict[str, Any]]:
        """Describe the converters of ``url_patterns`` as OpenAPI path parameters.

        Pass the enclosing (included) resolvers as well as the endpoint's own
        pattern, outermost first, so that parameters captured by parent routes
        are documented too. Every parameter is a required string constrained
        by its converter's ``regex``.
        """
        return [
            {
                "name": name,
                "in": "path",
                "required": True,
                "schema": {
                    "type": "string",
                    "pattern": converter.regex,
                },
            }
            for url_pattern in url_patterns
            for name, converter in url_pattern.pattern.converters.items()
        ]