Plain is headed towards 1.0! Subscribe for development updates →

  1import json
  2from http import HTTPMethod
  3from typing import Any
  4
  5from plain.urls import Router, URLPattern, URLResolver
  6from plain.urls.converters import get_converters
  7
  8from .utils import merge_data
  9
 10
 11class OpenAPISchemaGenerator:
 12    def __init__(self, router: Router):
 13        self.url_converters = {
 14            class_instance.__class__: key
 15            for key, class_instance in get_converters().items()
 16        }
 17
 18        # Get initial schema from the router
 19        self.schema = getattr(router, "openapi_schema", {}).copy()
 20        self.components = getattr(router, "openapi_components", {}).copy()
 21
 22        self.schema["paths"] = self.get_paths(router.urls)
 23
 24        if self.components:
 25            self.schema["components"] = self.components
 26
 27    def as_json(self, indent):
 28        return json.dumps(self.schema, indent=indent, sort_keys=True)
 29
 30    def as_yaml(self, indent):
 31        import yaml
 32
 33        # Don't want to get anchors when we dump...
 34        cleaned = json.loads(self.as_json(indent=0))
 35        return yaml.safe_dump(cleaned, indent=indent, sort_keys=True)
 36
 37    def get_paths(self, urls) -> dict[str, dict[str, Any]]:
 38        paths = {}
 39
 40        for url_pattern in urls:
 41            if isinstance(url_pattern, URLResolver):
 42                paths.update(self.get_paths(url_pattern.url_patterns))
 43            elif isinstance(url_pattern, URLPattern):
 44                if operations := self.operations_for_url_pattern(url_pattern):
 45                    path = self.path_from_url_pattern(url_pattern, "/")
 46                    # TODO could have class level summary/description?
 47                    paths[path] = operations
 48            else:
 49                raise ValueError(f"Unknown url pattern: {url_pattern}")
 50
 51        return paths
 52
 53    def path_from_url_pattern(self, url_pattern, root_path) -> str:
 54        path = root_path + str(url_pattern.pattern)
 55
 56        for name, converter in url_pattern.pattern.converters.items():
 57            key = self.url_converters[converter.__class__]
 58            path = path.replace(f"<{key}:{name}>", f"{{{name}}}")
 59        return path
 60
 61    def extract_components(self, obj):
 62        """
 63        Extract components from a view or router.
 64        """
 65        if hasattr(obj, "openapi_components"):
 66            self.components = merge_data(
 67                self.components,
 68                getattr(obj, "openapi_components", {}),
 69            )
 70
 71    def operations_for_url_pattern(self, url_pattern) -> dict[str, Any]:
 72        operations = {}
 73
 74        for vc in reversed(url_pattern.view.view_class.__mro__):
 75            exclude_http_methods = [
 76                HTTPMethod.TRACE,
 77                HTTPMethod.OPTIONS,
 78                HTTPMethod.CONNECT,
 79            ]
 80
 81            for method in [
 82                x.lower() for x in HTTPMethod if x not in exclude_http_methods
 83            ]:
 84                class_method = getattr(vc, method, None)
 85                if not class_method:
 86                    continue
 87
 88                operation = {}
 89
 90                # Get anything on from the view class itself,
 91                # then override it with the method-specific data
 92                self.extract_components(vc)
 93                operation = merge_data(
 94                    operation,
 95                    getattr(vc, "openapi_schema", {}),
 96                )
 97
 98                # Get the schema that applies to the specific method
 99                self.extract_components(class_method)
100                operation = merge_data(
101                    operation,
102                    getattr(class_method, "openapi_schema", {}),
103                )
104
105                # Get URL parameters if nothing else was defined
106                if operation and "parameters" not in operation:
107                    if parameters := self.parameters_from_url_patterns([url_pattern]):
108                        operation["parameters"] = parameters
109
110                # If there are no responses in the 2XX or 3XX range, then don't return it at all.
111                # Most likely the developer didn't define any actual responses for their endpoint,
112                # and all we did was inherit the base error responses.
113                keep_operation = False
114                for status_code in operation.get("responses", {}).keys():
115                    if status_code.startswith("2") or status_code.startswith("3"):
116                        keep_operation = True
117                        break
118
119                if operation and keep_operation:
120                    if method in operations:
121                        # Merge operation with existing data
122                        operations[method] = merge_data(operations[method], operation)
123                    else:
124                        operations[method] = operation
125
126        return operations
127
128    def parameters_from_url_patterns(self, url_patterns) -> list[dict[str, Any]]:
129        """Need to process any parent/included url patterns too"""
130        parameters = []
131
132        for url_pattern in url_patterns:
133            for name, converter in url_pattern.pattern.converters.items():
134                parameters.append(
135                    {
136                        "name": name,
137                        "in": "path",
138                        "required": True,
139                        "schema": {
140                            "type": "string",
141                            "pattern": converter.regex,
142                            # "format": "uuid",
143                        },
144                    }
145                )
146
147        return parameters