1"""Derive JSON Schema from a Tool's `__init__` signature + type hints."""
  2
  3from __future__ import annotations
  4
  5import inspect
  6import types
  7import typing
  8from collections.abc import Callable
  9from typing import Any, Literal, get_args, get_origin, get_type_hints
 10
 11JSON_SCHEMA_DIALECT = "https://json-schema.org/draft/2020-12/schema"
 12
 13_PRIMITIVE_TO_JSON_SCHEMA: dict[type, str] = {
 14    str: "string",
 15    int: "integer",
 16    float: "number",
 17    bool: "boolean",
 18}
 19
 20
 21def build_input_schema(fn: Callable[..., Any]) -> dict[str, Any]:
 22    """Derive a JSON Schema `object` from `fn`'s type hints.
 23
 24    Supports primitives (str/int/float/bool), `list[T]`, `dict`, `T | None`,
 25    `Literal[...]`, and falls back to permissive `string` for anything else.
 26    """
 27    sig = inspect.signature(fn)
 28    try:
 29        hints = get_type_hints(fn)
 30    except (NameError, TypeError):
 31        # Unresolvable forward refs or un-inspectable signatures: fall
 32        # back to no hints and let every param default to string.
 33        hints = {}
 34
 35    properties: dict[str, Any] = {}
 36    required: list[str] = []
 37
 38    for param_name, param in sig.parameters.items():
 39        if param_name in ("self", "cls"):
 40            continue
 41
 42        # Missing annotation → `Any` (permissive string), not `type(None)`.
 43        hint = hints.get(param_name, Any)
 44        prop, is_optional = _type_to_schema(hint)
 45        properties[param_name] = prop
 46
 47        has_default = param.default is not inspect.Parameter.empty
 48        if not has_default and not is_optional:
 49            required.append(param_name)
 50
 51    schema: dict[str, Any] = {
 52        "$schema": JSON_SCHEMA_DIALECT,
 53        "type": "object",
 54        "properties": properties,
 55    }
 56    if required:
 57        schema["required"] = required
 58    return schema
 59
 60
 61def _type_to_schema(hint: Any) -> tuple[dict[str, Any], bool]:
 62    if hint is type(None):
 63        return {"type": "null"}, True
 64    if hint is Any or hint is inspect.Parameter.empty:
 65        return {"type": "string"}, False
 66    if isinstance(hint, type) and hint in _PRIMITIVE_TO_JSON_SCHEMA:
 67        return {"type": _PRIMITIVE_TO_JSON_SCHEMA[hint]}, False
 68
 69    origin = get_origin(hint)
 70    args = get_args(hint)
 71
 72    if origin in (typing.Union, types.UnionType):
 73        non_none = [a for a in args if a is not type(None)]
 74        has_none = len(non_none) < len(args)
 75        branches = [_type_to_schema(a)[0] for a in non_none]
 76        # Keep the None branch in the schema — clients need it to know an
 77        # explicit `null` is accepted. `is_optional` separately tells the
 78        # outer builder not to mark the field required.
 79        if has_none:
 80            branches.append({"type": "null"})
 81        if len(branches) == 1:
 82            return branches[0], has_none
 83        return {"anyOf": branches}, has_none
 84
 85    if origin is Literal:
 86        enum_values = list(args)
 87        schema: dict[str, Any] = {"enum": enum_values}
 88        primitive_types = {
 89            p
 90            for v in enum_values
 91            if (p := _PRIMITIVE_TO_JSON_SCHEMA.get(type(v))) is not None
 92        }
 93        if len(primitive_types) == 1:
 94            schema["type"] = primitive_types.pop()
 95        return schema, False
 96
 97    if origin in (list, tuple, set, frozenset) or hint in (
 98        list,
 99        tuple,
100        set,
101        frozenset,
102    ):
103        items = _type_to_schema(args[0])[0] if args else {}
104        return {"type": "array", "items": items}, False
105
106    if origin is dict or hint is dict:
107        return {"type": "object"}, False
108
109    return {"type": "string"}, False