Plain is headed towards 1.0! Subscribe for development updates →

  1import json
  2from http import HTTPMethod
  3from typing import Any
  4
  5from plain.urls import Router, URLPattern, URLResolver
  6from plain.urls.converters import get_converters
  7
  8from .utils import merge_data
  9
 10
 11class OpenAPISchemaGenerator:
 12    def __init__(self, router: Router):
 13        self.url_converters = {
 14            class_instance.__class__: key
 15            for key, class_instance in get_converters().items()
 16        }
 17
 18        # Get initial schema from the router
 19        self.schema = getattr(router, "openapi_schema", {}).copy()
 20        self.components = getattr(router, "openapi_components", {}).copy()
 21
 22        self.schema["paths"] = self.get_paths(router.urls)
 23
 24        if self.components:
 25            self.schema["components"] = self.components
 26
 27    def as_json(self, indent: int) -> str:
 28        return json.dumps(self.schema, indent=indent, sort_keys=True)
 29
 30    def as_yaml(self, indent: int) -> str:
 31        import yaml
 32
 33        # Don't want to get anchors when we dump...
 34        cleaned = json.loads(self.as_json(indent=0))
 35        return yaml.safe_dump(cleaned, indent=indent, sort_keys=True)
 36
 37    def get_paths(
 38        self, urls: list[URLPattern | URLResolver]
 39    ) -> dict[str, dict[str, Any]]:
 40        paths = {}
 41
 42        for url_pattern in urls:
 43            if isinstance(url_pattern, URLResolver):
 44                paths.update(self.get_paths(url_pattern.url_patterns))
 45            elif isinstance(url_pattern, URLPattern):
 46                if operations := self.operations_for_url_pattern(url_pattern):
 47                    path = self.path_from_url_pattern(url_pattern, "/")
 48                    # TODO could have class level summary/description?
 49                    paths[path] = operations
 50            else:
 51                raise ValueError(f"Unknown url pattern: {url_pattern}")
 52
 53        return paths
 54
 55    def path_from_url_pattern(self, url_pattern: URLPattern, root_path: str) -> str:
 56        path = root_path + str(url_pattern.pattern)
 57
 58        for name, converter in url_pattern.pattern.converters.items():
 59            key = self.url_converters[converter.__class__]
 60            path = path.replace(f"<{key}:{name}>", f"{{{name}}}")
 61        return path
 62
 63    def extract_components(self, obj: Any) -> None:
 64        """
 65        Extract components from a view or router.
 66        """
 67        if hasattr(obj, "openapi_components"):
 68            self.components = merge_data(
 69                self.components,
 70                getattr(obj, "openapi_components", {}),
 71            )
 72
 73    def operations_for_url_pattern(self, url_pattern: URLPattern) -> dict[str, Any]:
 74        operations = {}
 75
 76        for vc in reversed(url_pattern.view.view_class.__mro__):
 77            exclude_http_methods = [
 78                HTTPMethod.TRACE,
 79                HTTPMethod.OPTIONS,
 80                HTTPMethod.CONNECT,
 81            ]
 82
 83            for method in [
 84                x.lower() for x in HTTPMethod if x not in exclude_http_methods
 85            ]:
 86                class_method = getattr(vc, method, None)
 87                if not class_method:
 88                    continue
 89
 90                operation = {}
 91
 92                # Get anything on from the view class itself,
 93                # then override it with the method-specific data
 94                self.extract_components(vc)
 95                operation = merge_data(
 96                    operation,
 97                    getattr(vc, "openapi_schema", {}),
 98                )
 99
100                # Get the schema that applies to the specific method
101                self.extract_components(class_method)
102                operation = merge_data(
103                    operation,
104                    getattr(class_method, "openapi_schema", {}),
105                )
106
107                # Get URL parameters if nothing else was defined
108                if operation and "parameters" not in operation:
109                    if parameters := self.parameters_from_url_patterns([url_pattern]):
110                        operation["parameters"] = parameters
111
112                # If there are no responses in the 2XX or 3XX range, then don't return it at all.
113                # Most likely the developer didn't define any actual responses for their endpoint,
114                # and all we did was inherit the base error responses.
115                keep_operation = False
116                for status_code in operation.get("responses", {}).keys():
117                    if status_code.startswith("2") or status_code.startswith("3"):
118                        keep_operation = True
119                        break
120
121                if operation and keep_operation:
122                    if method in operations:
123                        # Merge operation with existing data
124                        operations[method] = merge_data(operations[method], operation)
125                    else:
126                        operations[method] = operation
127
128        return operations
129
130    def parameters_from_url_patterns(
131        self, url_patterns: list[URLPattern]
132    ) -> list[dict[str, Any]]:
133        """Need to process any parent/included url patterns too"""
134        parameters = []
135
136        for url_pattern in url_patterns:
137            for name, converter in url_pattern.pattern.converters.items():
138                parameters.append(
139                    {
140                        "name": name,
141                        "in": "path",
142                        "required": True,
143                        "schema": {
144                            "type": "string",
145                            "pattern": converter.regex,
146                            # "format": "uuid",
147                        },
148                    }
149                )
150
151        return parameters