1"""Derive JSON Schema from a Tool's `__init__` signature + type hints."""
2
3from __future__ import annotations
4
5import inspect
6import types
7import typing
8from collections.abc import Callable
9from typing import Any, Literal, get_args, get_origin, get_type_hints
10
11JSON_SCHEMA_DIALECT = "https://json-schema.org/draft/2020-12/schema"
12
13_PRIMITIVE_TO_JSON_SCHEMA: dict[type, str] = {
14 str: "string",
15 int: "integer",
16 float: "number",
17 bool: "boolean",
18}
19
20
def build_input_schema(fn: Callable[..., Any]) -> dict[str, Any]:
    """Derive a JSON Schema `object` from `fn`'s signature and type hints.

    Every named parameter of `fn` becomes one property. ``self``/``cls`` and
    variadic ``*args``/``**kwargs`` parameters are skipped: the former are
    bound implicitly, and the latter have no single name a JSON object key
    could map onto.

    Supports primitives (str/int/float/bool), sequence types
    (list/tuple/set/frozenset), `dict`, `T | None`, `Literal[...]`, and
    falls back to permissive `string` for anything else.

    Args:
        fn: The callable to introspect (e.g. a Tool's ``__init__``).

    Returns:
        A schema dict with ``$schema``, ``type: "object"`` and
        ``properties``, plus ``required`` when at least one parameter is
        required. A parameter is required when it has no default and its
        annotation does not admit ``None``.
    """
    sig = inspect.signature(fn)
    try:
        hints = get_type_hints(fn)
    except (NameError, TypeError):
        # Unresolvable forward refs or un-inspectable signatures: fall
        # back to no hints and let every param default to string.
        hints = {}

    properties: dict[str, Any] = {}
    required: list[str] = []

    for param_name, param in sig.parameters.items():
        if param_name in ("self", "cls"):
            continue
        if param.kind in (
            inspect.Parameter.VAR_POSITIONAL,
            inspect.Parameter.VAR_KEYWORD,
        ):
            # `*args` / `**kwargs` never have a default, so without this
            # skip they would leak into the schema as *required* properties
            # named e.g. "kwargs" that no caller can meaningfully supply.
            continue

        # Missing annotation → `Any` (permissive string), not `type(None)`.
        hint = hints.get(param_name, Any)
        prop, is_optional = _type_to_schema(hint)
        properties[param_name] = prop

        has_default = param.default is not inspect.Parameter.empty
        if not has_default and not is_optional:
            required.append(param_name)

    schema: dict[str, Any] = {
        "$schema": JSON_SCHEMA_DIALECT,
        "type": "object",
        "properties": properties,
    }
    if required:
        schema["required"] = required
    return schema
59
60
def _type_to_schema(hint: Any) -> tuple[dict[str, Any], bool]:
    """Translate one resolved type hint into a JSON Schema fragment.

    Args:
        hint: A single annotation as resolved by `typing.get_type_hints`:
            a class, a `typing` special form, or a parameterized generic.

    Returns:
        A ``(schema, is_optional)`` pair. ``is_optional`` is True when the
        hint admits ``None`` (bare ``None``, ``T | None``, or a ``Literal``
        that lists ``None``); the caller uses it to decide whether the
        field is required.
    """
    if hint is type(None):
        return {"type": "null"}, True
    if hint is Any or hint is inspect.Parameter.empty:
        return {"type": "string"}, False
    if isinstance(hint, type) and hint in _PRIMITIVE_TO_JSON_SCHEMA:
        return {"type": _PRIMITIVE_TO_JSON_SCHEMA[hint]}, False

    # `typing.NewType("UserId", int)` is a thin alias for a real type;
    # describe the wrapped type instead of falling back to plain string.
    supertype = getattr(hint, "__supertype__", None)
    if supertype is not None:
        return _type_to_schema(supertype)

    origin = get_origin(hint)
    args = get_args(hint)

    if origin in (typing.Union, types.UnionType):
        non_none = [a for a in args if a is not type(None)]
        has_none = len(non_none) < len(args)
        converted = [_type_to_schema(a) for a in non_none]
        branches = [branch for branch, _ in converted]
        # A member can admit None on its own (e.g. `Literal["a", None]`),
        # which makes the whole union optional as well.
        is_optional = has_none or any(opt for _, opt in converted)
        # Keep the None branch in the schema — clients need it to know an
        # explicit `null` is accepted. `is_optional` separately tells the
        # outer builder not to mark the field required.
        if has_none:
            branches.append({"type": "null"})
        if len(branches) == 1:
            return branches[0], is_optional
        return {"anyOf": branches}, is_optional

    if origin is Literal:
        enum_values = list(args)
        schema: dict[str, Any] = {"enum": enum_values}
        json_types = {
            "null" if v is None else _PRIMITIVE_TO_JSON_SCHEMA.get(type(v))
            for v in enum_values
        }
        # Pin `type` only when every value maps to one and the same JSON
        # type. If any value is unmapped (None ends up in the set) or the
        # types differ, `enum` alone must carry the constraint: a `type`
        # taken from a subset of the values would reject the other members
        # (e.g. `Literal["a", None]` must still accept null).
        if len(json_types) == 1 and None not in json_types:
            schema["type"] = next(iter(json_types))
        return schema, any(v is None for v in enum_values)

    if origin is tuple and args and args[-1] is not Ellipsis:
        # Fixed-length tuple such as `tuple[int, str]`: each position has
        # its own type, so use draft 2020-12 `prefixItems` (the dialect in
        # JSON_SCHEMA_DIALECT) and pin the length. Using only `args[0]` as a
        # homogeneous `items` would reject valid values at later positions.
        # NOTE: some Python versions report `tuple[()]` as args `((),)`.
        members = () if args == ((),) else args
        positions = [_type_to_schema(a)[0] for a in members]
        return {
            "type": "array",
            "prefixItems": positions,
            "minItems": len(positions),
            "maxItems": len(positions),
        }, False

    # Homogeneous collections: `list[T]`, `set[T]`, `tuple[T, ...]`, or the
    # bare (unparameterized) builtins, which accept items of any type.
    if origin in (list, tuple, set, frozenset) or hint in (
        list,
        tuple,
        set,
        frozenset,
    ):
        items = _type_to_schema(args[0])[0] if args else {}
        return {"type": "array", "items": items}, False

    if origin is dict or hint is dict:
        return {"type": "object"}, False

    return {"type": "string"}, False