main
 1from __future__ import annotations
 2
 3from typing import Any, Dict, cast
 4
 5import pydantic
 6
 7from ._pydantic import to_strict_json_schema
 8from ..types.chat import ChatCompletionFunctionToolParam
 9from ..types.shared_params import FunctionDefinition
10from ..types.responses.function_tool_param import FunctionToolParam as ResponsesFunctionToolParam
11
12
class PydanticFunctionTool(Dict[str, Any]):
    """A ``FunctionDefinition`` dict that also remembers its pydantic model.

    Because an instance *is* a plain ``dict`` holding the definition, it can be
    passed through the entire request stack without special-casing; the model
    it was built from stays reachable through the ``model`` attribute.
    """

    # The pydantic model class this definition was generated from.
    model: type[pydantic.BaseModel]

    def __init__(self, defn: FunctionDefinition, model: type[pydantic.BaseModel]) -> None:
        # The model is stored as an instance attribute, not as a dict key,
        # so the dict contents are exactly the given definition.
        self.model = model
        super().__init__(defn)

    def cast(self) -> FunctionDefinition:
        """Return this object statically typed as ``FunctionDefinition``.

        ``typing.cast`` performs no runtime conversion, so the value returned
        is ``self``.
        """
        return cast(FunctionDefinition, self)
27
28
class ResponsesPydanticFunctionTool(Dict[str, Any]):
    """A responses ``FunctionToolParam`` dict that also remembers its pydantic model.

    The instance itself is the tool dict, so it can be sent wherever a plain
    tool param is expected, while ``model`` keeps a reference to the pydantic
    class the tool was built from.
    """

    # The pydantic model class this tool param was generated from.
    model: type[pydantic.BaseModel]

    def __init__(self, tool: ResponsesFunctionToolParam, model: type[pydantic.BaseModel]) -> None:
        # Kept as an attribute rather than a key, so the dict holds only `tool`.
        self.model = model
        super().__init__(tool)

    def cast(self) -> ResponsesFunctionToolParam:
        """Return this object statically typed as ``ResponsesFunctionToolParam``.

        ``typing.cast`` performs no runtime conversion, so the value returned
        is ``self``.
        """
        return cast(ResponsesFunctionToolParam, self)
38
39
def pydantic_function_tool(
    model: type[pydantic.BaseModel],
    *,
    name: str | None = None,  # inferred from class name by default
    description: str | None = None,  # inferred from class docstring by default
) -> ChatCompletionFunctionToolParam:
    """Build a strict chat-completions function tool from a pydantic model.

    Args:
        model: The pydantic model whose strict JSON schema (from
            ``to_strict_json_schema``) becomes the tool's ``parameters``.
        name: Tool name; when falsy, ``model.__name__`` is used instead.
        description: Tool description; when ``None``, the model's own
            ``__doc__`` is used. The ``description`` key is left out entirely
            if both are ``None``.

    Returns:
        ``{"type": "function", "function": ...}`` where the ``function`` entry
        is a ``PydanticFunctionTool`` that still carries ``model``.
    """
    # Deliberately read the class's own ``__doc__`` instead of calling
    # ``inspect.getdoc()``, which would fall back to pydantic's inherited
    # docstrings when the model has none of its own.
    resolved_description = model.__doc__ if description is None else description

    definition: FunctionDefinition = {
        "name": name or model.__name__,
        "strict": True,
        "parameters": to_strict_json_schema(model),
    }
    if resolved_description is not None:
        definition["description"] = resolved_description

    return {
        "type": "function",
        "function": PydanticFunctionTool(definition, model).cast(),
    }