Plain is headed towards 1.0! Subscribe for development updates →

plain.api

  1from typing import Any
  2from uuid import UUID
  3
  4from plain.forms import fields
  5from plain.urls import get_resolver
  6from plain.urls.converters import get_converters
  7from plain.views import View
  8
  9from .responses import JsonResponse, JsonResponseCreated, JsonResponseList
 10from .views import APIBaseView
 11
 12
 13class OpenAPISchemaView(View):
 14    openapi_title: str
 15    openapi_version: str
 16    openapi_urlconf: str
 17
 18    def get(self):
 19        # TODO can heaviliy cache this - browser headers? or cache the schema obj?
 20        return JsonResponse(
 21            OpenAPISchema(
 22                title=self.openapi_title,
 23                version=self.openapi_version,
 24                urlconf=self.openapi_urlconf,
 25            ),
 26            json_dumps_params={"sort_keys": True},
 27        )
 28
 29
 30class OpenAPISchema(dict):
 31    def __init__(self, *, title: str, version: str, urlconf="app.api.urls"):
 32        self.urlconf = urlconf
 33        self.url_converters = {
 34            class_instance.__class__: key
 35            for key, class_instance in get_converters().items()
 36        }
 37        paths = self.get_paths()
 38        super().__init__(
 39            openapi="3.0.0",
 40            info={
 41                "title": title,
 42                "version": version,
 43                # **moreinfo, or info is a dict?
 44            },
 45            paths=paths,
 46            #               "404": {
 47            #     "$ref": "#/components/responses/not_found"
 48            #   },
 49            #   "422": {
 50            #     "$ref": "#/components/responses/validation_failed_simple"
 51            #   }
 52        )
 53
 54    # def extract_components(self, paths):
 55    #     """Look through the paths and find and repeated definitions
 56    #     that can be pulled out as components."""
 57    #     from collections import Counter
 58    #     components = Counter()
 59    #     for path in paths.values():
 60
 61    def get_paths(self) -> dict[str, dict[str, Any]]:
 62        resolver = get_resolver(self.urlconf)
 63        paths = {}
 64
 65        for url_pattern in resolver.url_patterns:
 66            for pattern, root_path in self.url_patterns_from_url_pattern(
 67                url_pattern, "/"
 68            ):
 69                path = self.path_from_url_pattern(pattern, root_path)
 70                if operations := self.operations_from_url_pattern(pattern):
 71                    paths[path] = operations
 72                    if parameters := self.parameters_from_url_patterns(
 73                        [url_pattern, pattern]
 74                    ):
 75                        # Assume all methods have the same parameters for now (path params)
 76                        for method in operations:
 77                            operations[method]["parameters"] = parameters
 78
 79        return paths
 80
 81    def url_patterns_from_url_pattern(self, url_pattern, root_path) -> list:
 82        if hasattr(url_pattern, "url_patterns"):
 83            include_path = self.path_from_url_pattern(url_pattern, root_path)
 84            url_patterns = []
 85            for u in url_pattern.url_patterns:
 86                url_patterns.extend(self.url_patterns_from_url_pattern(u, include_path))
 87            return url_patterns
 88        else:
 89            return [(url_pattern, root_path)]
 90
 91    def path_from_url_pattern(self, url_pattern, root_path) -> str:
 92        path = root_path + str(url_pattern.pattern)
 93
 94        for name, converter in url_pattern.pattern.converters.items():
 95            key = self.url_converters[converter.__class__]
 96            path = path.replace(f"<{key}:{name}>", f"{{{name}}}")
 97        return path
 98
 99    def parameters_from_url_patterns(self, url_patterns) -> list[dict[str, Any]]:
100        """Need to process any parent/included url patterns too"""
101        parameters = []
102
103        for url_pattern in url_patterns:
104            for name, converter in url_pattern.pattern.converters.items():
105                parameters.append(
106                    {
107                        "name": name,
108                        "in": "path",
109                        "required": True,
110                        "schema": {
111                            "type": "string",
112                            "pattern": converter.regex,
113                            # "format": "uuid",
114                        },
115                    }
116                )
117
118        return parameters
119
120    def operations_from_url_pattern(self, url_pattern) -> dict[str, Any]:
121        view_class = url_pattern.callback.view_class
122
123        if not issubclass(view_class, APIBaseView):
124            return {}
125
126        operations = {}
127
128        known_http_method_names = [
129            "get",
130            "post",
131            "put",
132            "patch",
133            "delete",
134            # Don't care about these ones...
135            # "head",
136            # "options",
137            # "trace",
138        ]
139
140        for method in known_http_method_names:
141            if not hasattr(view_class, method):
142                continue
143
144            if responses := self.responses_from_class_method(view_class, method):
145                operations[method] = {
146                    "responses": responses,
147                }
148
149            if parameters := self.request_body_from_class_method(view_class, method):
150                operations[method]["requestBody"] = parameters
151
152        return operations
153
154    def request_body_from_class_method(self, view_class, method: str) -> dict:
155        """Gets parameters from the form_class on a view"""
156
157        if method not in ("post", "put", "patch"):
158            return {}
159
160        form_class = view_class.form_class
161        if not form_class:
162            return {}
163
164        parameters = []
165        # Any args or kwargs in form.__init__ need to be optional
166        # for this to work...
167        for name, field in form_class().fields.items():
168            parameters.append(
169                {
170                    "name": name,
171                    # "in": "query",
172                    # "required": field.required,
173                    "schema": self.type_to_schema_obj(field),
174                }
175            )
176
177        return {
178            "content": {
179                "application/json": {
180                    "schema": {
181                        "type": "object",
182                        "properties": {p["name"]: p["schema"] for p in parameters},
183                    }
184                },
185            },
186        }
187
188    def responses_from_class_method(
189        self, view_class, method: str
190    ) -> dict[str, dict[str, Any]]:
191        class_method = getattr(view_class, method)
192        return_type = class_method.__annotations__["return"]
193
194        if hasattr(return_type, "status_code"):
195            return_types = [return_type]
196        else:
197            # Assume union...
198            return_types = return_type.__args__
199
200        responses: dict[str, dict[str, Any]] = {}
201
202        for return_type in return_types:
203            if return_type is JsonResponse or return_type is JsonResponseCreated:
204                schema = self.type_to_schema_obj(
205                    view_class.object_to_dict.__annotations__["return"]
206                )
207
208                content = {"application/json": {"schema": schema}}
209            elif return_type is JsonResponseList:
210                schema = self.type_to_schema_obj(
211                    view_class.object_to_dict.__annotations__["return"]
212                )
213
214                content = {
215                    "application/json": {
216                        "schema": {
217                            "type": "array",
218                            "items": schema,
219                        }
220                    }
221                }
222            else:
223                content = None
224
225            response_key = str(return_type.status_code)
226            responses[response_key] = {}
227
228            if description := getattr(return_type, "openapi_description", ""):
229                responses[response_key]["description"] = description
230
231            responses["5XX"] = {
232                "description": "Server error",
233            }
234
235            if content:
236                responses[response_key]["content"] = content
237
238        return responses
239
240    def type_to_schema_obj(self, t) -> dict[str, Any]:
241        # if it's a union with None, add nullable: true
242
243        # if t has a comment, add description
244        # import inspect
245        # if description := inspect.getdoc(t):
246        #     extra_fields = {"description": description}
247        # else:
248        extra_fields: dict[str, Any] = {}
249
250        if hasattr(t, "__annotations__") and t.__annotations__:
251            # It's a TypedDict...
252            return {
253                "type": "object",
254                "properties": {
255                    k: self.type_to_schema_obj(v) for k, v in t.__annotations__.items()
256                },
257                **extra_fields,
258            }
259
260        if hasattr(t, "__origin__"):
261            if t.__origin__ is list:
262                return {
263                    "type": "array",
264                    "items": self.type_to_schema_obj(t.__args__[0]),
265                    **extra_fields,
266                }
267            elif t.__origin__ is dict:
268                return {
269                    "type": "object",
270                    "properties": {
271                        k: self.type_to_schema_obj(v)
272                        for k, v in t.__args__[1].__annotations__.items()
273                    },
274                    **extra_fields,
275                }
276            else:
277                raise ValueError(f"Unknown type: {t}")
278
279        if hasattr(t, "__args__") and len(t.__args__) == 2 and type(None) in t.__args__:
280            return {
281                **self.type_to_schema_obj(t.__args__[0]),
282                "nullable": True,
283                **extra_fields,
284            }
285
286        type_mappings: dict[Any, dict] = {
287            str: {
288                "type": "string",
289            },
290            int: {
291                "type": "integer",
292            },
293            float: {
294                "type": "number",
295            },
296            bool: {
297                "type": "boolean",
298            },
299            dict: {
300                "type": "object",
301            },
302            list: {
303                "type": "array",
304            },
305            UUID: {
306                "type": "string",
307                "format": "uuid",
308            },
309            fields.IntegerField: {
310                "type": "integer",
311            },
312            fields.FloatField: {
313                "type": "number",
314            },
315            fields.DateTimeField: {
316                "type": "string",
317                "format": "date-time",
318            },
319            fields.DateField: {
320                "type": "string",
321                "format": "date",
322            },
323            fields.TimeField: {
324                "type": "string",
325                "format": "time",
326            },
327            fields.EmailField: {
328                "type": "string",
329                "format": "email",
330            },
331            fields.URLField: {
332                "type": "string",
333                "format": "uri",
334            },
335            fields.UUIDField: {
336                "type": "string",
337                "format": "uuid",
338            },
339            fields.DecimalField: {
340                "type": "number",
341            },
342            # fields.FileField: {
343            #     "type": "string",
344            #     "format": "binary",
345            # },
346            fields.ImageField: {
347                "type": "string",
348                "format": "binary",
349            },
350            fields.BooleanField: {
351                "type": "boolean",
352            },
353            fields.NullBooleanField: {
354                "type": "boolean",
355                "nullable": True,
356            },
357            fields.CharField: {
358                "type": "string",
359            },
360            fields.SlugField: {
361                "type": "string",
362            },
363            fields.EmailField: {
364                "type": "string",
365                "format": "email",
366            },
367        }
368
369        schema = type_mappings.get(t, {})
370        if not schema:
371            schema = type_mappings.get(t.__class__, {})
372            if not schema:
373                raise ValueError(f"Unknown type: {t}")
374
375        return {**schema, **extra_fields}