Plain is headed towards 1.0! Subscribe for development updates →

plain.api

  1from typing import Any
  2from uuid import UUID
  3
  4from plain.forms import fields
  5from plain.urls import get_resolver
  6from plain.urls.converters import get_converters
  7from plain.views import View
  8
  9from .responses import JsonResponse, JsonResponseCreated, JsonResponseList
 10from .views import APIBaseView
 11
 12
 13class OpenAPISchemaView(View):
 14    openapi_title: str
 15    openapi_version: str
 16    # openapi_urlrouter: str
 17
 18    def get(self):
 19        # TODO can heaviliy cache this - browser headers? or cache the schema obj?
 20        return JsonResponse(
 21            OpenAPISchema(
 22                title=self.openapi_title,
 23                version=self.openapi_version,
 24                # urlrouter=self.openapi_urlrouter,
 25            ),
 26            json_dumps_params={"sort_keys": True},
 27        )
 28
 29
 30class OpenAPISchema(dict):
 31    def __init__(self, *, title: str, version: str):
 32        self.url_converters = {
 33            class_instance.__class__: key
 34            for key, class_instance in get_converters().items()
 35        }
 36        paths = self.get_paths()
 37        super().__init__(
 38            openapi="3.0.0",
 39            info={
 40                "title": title,
 41                "version": version,
 42                # **moreinfo, or info is a dict?
 43            },
 44            paths=paths,
 45            #               "404": {
 46            #     "$ref": "#/components/responses/not_found"
 47            #   },
 48            #   "422": {
 49            #     "$ref": "#/components/responses/validation_failed_simple"
 50            #   }
 51        )
 52
 53    # def extract_components(self, paths):
 54    #     """Look through the paths and find and repeated definitions
 55    #     that can be pulled out as components."""
 56    #     from collections import Counter
 57    #     components = Counter()
 58    #     for path in paths.values():
 59
 60    def get_paths(self) -> dict[str, dict[str, Any]]:
 61        resolver = get_resolver()  # (self.urlrouter)
 62        paths = {}
 63
 64        for url_pattern in resolver.url_patterns:
 65            for pattern, root_path in self.url_patterns_from_url_pattern(
 66                url_pattern, "/"
 67            ):
 68                path = self.path_from_url_pattern(pattern, root_path)
 69                if operations := self.operations_from_url_pattern(pattern):
 70                    paths[path] = operations
 71                    if parameters := self.parameters_from_url_patterns(
 72                        [url_pattern, pattern]
 73                    ):
 74                        # Assume all methods have the same parameters for now (path params)
 75                        for method in operations:
 76                            operations[method]["parameters"] = parameters
 77
 78        return paths
 79
 80    def url_patterns_from_url_pattern(self, url_pattern, root_path) -> list:
 81        if hasattr(url_pattern, "url_patterns"):
 82            include_path = self.path_from_url_pattern(url_pattern, root_path)
 83            url_patterns = []
 84            for u in url_pattern.url_patterns:
 85                url_patterns.extend(self.url_patterns_from_url_pattern(u, include_path))
 86            return url_patterns
 87        else:
 88            return [(url_pattern, root_path)]
 89
 90    def path_from_url_pattern(self, url_pattern, root_path) -> str:
 91        path = root_path + str(url_pattern.pattern)
 92
 93        for name, converter in url_pattern.pattern.converters.items():
 94            key = self.url_converters[converter.__class__]
 95            path = path.replace(f"<{key}:{name}>", f"{{{name}}}")
 96        return path
 97
 98    def parameters_from_url_patterns(self, url_patterns) -> list[dict[str, Any]]:
 99        """Need to process any parent/included url patterns too"""
100        parameters = []
101
102        for url_pattern in url_patterns:
103            for name, converter in url_pattern.pattern.converters.items():
104                parameters.append(
105                    {
106                        "name": name,
107                        "in": "path",
108                        "required": True,
109                        "schema": {
110                            "type": "string",
111                            "pattern": converter.regex,
112                            # "format": "uuid",
113                        },
114                    }
115                )
116
117        return parameters
118
119    def operations_from_url_pattern(self, url_pattern) -> dict[str, Any]:
120        view_class = url_pattern.callback.view_class
121
122        if not issubclass(view_class, APIBaseView):
123            return {}
124
125        operations = {}
126
127        for method in view_class.allowed_http_methods:
128            if responses := self.responses_from_class_method(view_class, method):
129                operations[method] = {
130                    "responses": responses,
131                }
132
133            if parameters := self.request_body_from_class_method(view_class, method):
134                operations[method]["requestBody"] = parameters
135
136        return operations
137
138    def request_body_from_class_method(self, view_class, method: str) -> dict:
139        """Gets parameters from the form_class on a view"""
140
141        if method not in ("post", "put", "patch"):
142            return {}
143
144        form_class = view_class.form_class
145        if not form_class:
146            return {}
147
148        parameters = []
149        # Any args or kwargs in form.__init__ need to be optional
150        # for this to work...
151        for name, field in form_class().fields.items():
152            parameters.append(
153                {
154                    "name": name,
155                    # "in": "query",
156                    # "required": field.required,
157                    "schema": self.type_to_schema_obj(field),
158                }
159            )
160
161        return {
162            "content": {
163                "application/json": {
164                    "schema": {
165                        "type": "object",
166                        "properties": {p["name"]: p["schema"] for p in parameters},
167                    }
168                },
169            },
170        }
171
172    def responses_from_class_method(
173        self, view_class, method: str
174    ) -> dict[str, dict[str, Any]]:
175        class_method = getattr(view_class, method)
176        return_type = class_method.__annotations__["return"]
177
178        if hasattr(return_type, "status_code"):
179            return_types = [return_type]
180        else:
181            # Assume union...
182            return_types = return_type.__args__
183
184        responses: dict[str, dict[str, Any]] = {}
185
186        for return_type in return_types:
187            if return_type is JsonResponse or return_type is JsonResponseCreated:
188                schema = self.type_to_schema_obj(
189                    view_class.object_to_dict.__annotations__["return"]
190                )
191
192                content = {"application/json": {"schema": schema}}
193            elif return_type is JsonResponseList:
194                schema = self.type_to_schema_obj(
195                    view_class.object_to_dict.__annotations__["return"]
196                )
197
198                content = {
199                    "application/json": {
200                        "schema": {
201                            "type": "array",
202                            "items": schema,
203                        }
204                    }
205                }
206            else:
207                content = None
208
209            response_key = str(return_type.status_code)
210            responses[response_key] = {}
211
212            if description := getattr(return_type, "openapi_description", ""):
213                responses[response_key]["description"] = description
214
215            responses["5XX"] = {
216                "description": "Server error",
217            }
218
219            if content:
220                responses[response_key]["content"] = content
221
222        return responses
223
224    def type_to_schema_obj(self, t) -> dict[str, Any]:
225        # if it's a union with None, add nullable: true
226
227        # if t has a comment, add description
228        # import inspect
229        # if description := inspect.getdoc(t):
230        #     extra_fields = {"description": description}
231        # else:
232        extra_fields: dict[str, Any] = {}
233
234        if hasattr(t, "__annotations__") and t.__annotations__:
235            # It's a TypedDict...
236            return {
237                "type": "object",
238                "properties": {
239                    k: self.type_to_schema_obj(v) for k, v in t.__annotations__.items()
240                },
241                **extra_fields,
242            }
243
244        if hasattr(t, "__origin__"):
245            if t.__origin__ is list:
246                return {
247                    "type": "array",
248                    "items": self.type_to_schema_obj(t.__args__[0]),
249                    **extra_fields,
250                }
251            elif t.__origin__ is dict:
252                return {
253                    "type": "object",
254                    "properties": {
255                        k: self.type_to_schema_obj(v)
256                        for k, v in t.__args__[1].__annotations__.items()
257                    },
258                    **extra_fields,
259                }
260            else:
261                raise ValueError(f"Unknown type: {t}")
262
263        if hasattr(t, "__args__") and len(t.__args__) == 2 and type(None) in t.__args__:
264            return {
265                **self.type_to_schema_obj(t.__args__[0]),
266                "nullable": True,
267                **extra_fields,
268            }
269
270        type_mappings: dict[Any, dict] = {
271            str: {
272                "type": "string",
273            },
274            int: {
275                "type": "integer",
276            },
277            float: {
278                "type": "number",
279            },
280            bool: {
281                "type": "boolean",
282            },
283            dict: {
284                "type": "object",
285            },
286            list: {
287                "type": "array",
288            },
289            UUID: {
290                "type": "string",
291                "format": "uuid",
292            },
293            fields.IntegerField: {
294                "type": "integer",
295            },
296            fields.FloatField: {
297                "type": "number",
298            },
299            fields.DateTimeField: {
300                "type": "string",
301                "format": "date-time",
302            },
303            fields.DateField: {
304                "type": "string",
305                "format": "date",
306            },
307            fields.TimeField: {
308                "type": "string",
309                "format": "time",
310            },
311            fields.EmailField: {
312                "type": "string",
313                "format": "email",
314            },
315            fields.URLField: {
316                "type": "string",
317                "format": "uri",
318            },
319            fields.UUIDField: {
320                "type": "string",
321                "format": "uuid",
322            },
323            fields.DecimalField: {
324                "type": "number",
325            },
326            # fields.FileField: {
327            #     "type": "string",
328            #     "format": "binary",
329            # },
330            fields.ImageField: {
331                "type": "string",
332                "format": "binary",
333            },
334            fields.BooleanField: {
335                "type": "boolean",
336            },
337            fields.NullBooleanField: {
338                "type": "boolean",
339                "nullable": True,
340            },
341            fields.CharField: {
342                "type": "string",
343            },
344            fields.EmailField: {
345                "type": "string",
346                "format": "email",
347            },
348        }
349
350        schema = type_mappings.get(t, {})
351        if not schema:
352            schema = type_mappings.get(t.__class__, {})
353            if not schema:
354                raise ValueError(f"Unknown type: {t}")
355
356        return {**schema, **extra_fields}