main
1from __future__ import annotations
2
3from typing import Any, Dict, cast
4
5import pydantic
6
7from ._pydantic import to_strict_json_schema
8from ..types.chat import ChatCompletionFunctionToolParam
9from ..types.shared_params import FunctionDefinition
10from ..types.responses.function_tool_param import FunctionToolParam as ResponsesFunctionToolParam
11
12
class PydanticFunctionTool(Dict[str, Any]):
    """A ``FunctionDefinition`` dict that also remembers its pydantic model.

    Because the instance *is* a plain ``dict``, it can be passed through the
    entire request stack without special-casing. Code that knows about this
    wrapper can still recover the originating model via ``model``.
    """

    # The pydantic model class this function definition was generated from.
    model: type[pydantic.BaseModel]

    def __init__(self, defn: FunctionDefinition, model: type[pydantic.BaseModel]) -> None:
        # Attach the model first; filling the dict entries does not depend on it.
        self.model = model
        super().__init__(defn)

    def cast(self) -> FunctionDefinition:
        """Return ``self`` typed as a ``FunctionDefinition``.

        This is a static-typing cast only; no copy or conversion happens at
        runtime.
        """
        as_definition: FunctionDefinition = cast(FunctionDefinition, self)
        return as_definition
27
28
class ResponsesPydanticFunctionTool(Dict[str, Any]):
    """A Responses-API ``FunctionToolParam`` dict that also remembers its pydantic model.

    The instance is a plain ``dict``, so it can be sent wherever a
    ``ResponsesFunctionToolParam`` is expected. The model class it was built
    from remains available on ``model``.
    """

    # The pydantic model class this tool definition was generated from.
    model: type[pydantic.BaseModel]

    def __init__(self, tool: ResponsesFunctionToolParam, model: type[pydantic.BaseModel]) -> None:
        # Attach the model first; filling the dict entries does not depend on it.
        self.model = model
        super().__init__(tool)

    def cast(self) -> ResponsesFunctionToolParam:
        """Return ``self`` typed as a ``ResponsesFunctionToolParam``.

        This is a static-typing cast only; no copy or conversion happens at
        runtime.
        """
        as_tool_param: ResponsesFunctionToolParam = cast(ResponsesFunctionToolParam, self)
        return as_tool_param
38
39
def pydantic_function_tool(
    model: type[pydantic.BaseModel],
    *,
    name: str | None = None,  # inferred from class name by default
    description: str | None = None,  # inferred from class docstring by default
) -> ChatCompletionFunctionToolParam:
    """Build a strict Chat Completions function tool from a pydantic model.

    Args:
        model: The pydantic model whose strict JSON schema (from
            ``to_strict_json_schema``) becomes the function's ``parameters``.
        name: Function name. Falls back to ``model.__name__`` when omitted
            or when an empty string is given.
        description: Function description. Falls back to the model class's
            own ``__doc__``. When that is also ``None``, the ``description``
            key is left out entirely.

    Returns:
        A ``{"type": "function", "function": ...}`` tool param. The
        ``function`` entry is a ``PydanticFunctionTool``, so the originating
        model stays reachable through its ``model`` attribute.
    """
    # Read ``__doc__`` directly rather than via ``inspect.getdoc()``: getdoc()
    # falls back to base-class docstrings, which would pull in pydantic's own.
    effective_description = description if description is not None else model.__doc__

    tool_name = name or model.__name__
    definition = PydanticFunctionTool(
        {
            "name": tool_name,
            "strict": True,
            "parameters": to_strict_json_schema(model),
        },
        model,
    ).cast()

    if effective_description is not None:
        definition["description"] = effective_description

    return {"type": "function", "function": definition}