main
1from __future__ import annotations
2
3import json
4from typing import TYPE_CHECKING, Any, List, Iterable, cast
5from typing_extensions import TypeVar, assert_never
6
7import pydantic
8
9from .._tools import ResponsesPydanticFunctionTool
10from ..._types import Omit
11from ..._utils import is_given
12from ..._compat import PYDANTIC_V1, model_parse_json
13from ..._models import construct_type_unchecked
14from .._pydantic import is_basemodel_type, is_dataclass_like_type
15from ._completions import solve_response_format_t, type_to_response_format_param
16from ...types.responses import (
17 Response,
18 ToolParam,
19 ParsedContent,
20 ParsedResponse,
21 FunctionToolParam,
22 ParsedResponseOutputItem,
23 ParsedResponseOutputText,
24 ResponseFunctionToolCall,
25 ParsedResponseOutputMessage,
26 ResponseFormatTextConfigParam,
27 ParsedResponseFunctionToolCall,
28)
29from ...types.chat.completion_create_params import ResponseFormat
30
# Type of the structured value placed in the `parsed` fields of a response.
# It defaults to `None`: if no `text_format` type is given, no parsing is done.
TextFormatT = TypeVar("TextFormatT", default=None)
36
37
def type_to_text_format_param(type_: type) -> ResponseFormatTextConfigParam:
    """Convert a Python type into a Responses API `json_schema` text format param.

    The type is first converted with the shared `type_to_response_format_param`
    helper (from `._completions`). That helper returns a nested
    `{"type": "json_schema", "json_schema": {...}}` shape. This function
    flattens it into the `{"type", "strict", "name", "schema"}` dict returned
    below.

    Raises:
        AssertionError: when assertions are enabled, if the helper's result is
            not given, is not a `json_schema` format, or has no `schema`.
    """
    converted = type_to_response_format_param(type_)
    assert is_given(converted)

    format_param = cast(ResponseFormat, converted)  # pyright: ignore[reportUnnecessaryCast]
    assert format_param["type"] == "json_schema"

    json_schema = format_param["json_schema"]
    assert "schema" in json_schema

    return {
        "type": "json_schema",
        "strict": True,
        "name": json_schema["name"],
        "schema": json_schema["schema"],
    }
51
52
def parse_response(
    *,
    text_format: type[TextFormatT] | Omit,
    input_tools: Iterable[ToolParam] | Omit | None,
    response: Response | ParsedResponse[object],
) -> ParsedResponse[TextFormatT]:
    """Build a `ParsedResponse` from a response, parsing structured outputs.

    Each item in `response.output` is handled by type:

    - `message`: every `output_text` content part gets a `parsed` field
      computed by `parse_text(...)` with `text_format`. Other content parts
      are kept unchanged.
    - `function_call`: the item gets a `parsed_arguments` field computed by
      `parse_function_tool_arguments(...)` with `input_tools`.
    - Any other item type is passed through unchanged.

    Args:
        text_format: Type to parse `output_text` content into. If omitted,
            `parsed` is `None`.
        input_tools: The tool definitions used to look up each function call
            by name. May be `None` or omitted, in which case
            `parsed_arguments` is `None`.
        response: The response whose output items are parsed.

    Returns:
        A `ParsedResponse` built with `construct_type_unchecked` from the
        response's fields, with `output` replaced by the parsed items.
    """
    solved_t = solve_response_format_t(text_format)
    output_list: List[ParsedResponseOutputItem[TextFormatT]] = []

    # `input_tools` is searched once per `function_call` item below. Build a
    # list up front: a one-shot iterator (e.g. a generator) would be used up
    # by the first lookup, and every later function call would silently get
    # `parsed_arguments=None`.
    if input_tools is not None and is_given(input_tools):
        input_tools = list(input_tools)

    for output in response.output:
        if output.type == "message":
            content_list: List[ParsedContent[TextFormatT]] = []
            for item in output.content:
                # Only `output_text` parts carry text to parse; keep the rest as-is.
                if item.type != "output_text":
                    content_list.append(item)
                    continue

                content_list.append(
                    construct_type_unchecked(
                        type_=cast(Any, ParsedResponseOutputText)[solved_t],
                        value={
                            **item.to_dict(),
                            "parsed": parse_text(item.text, text_format=text_format),
                        },
                    )
                )

            output_list.append(
                construct_type_unchecked(
                    type_=cast(Any, ParsedResponseOutputMessage)[solved_t],
                    value={
                        **output.to_dict(),
                        "content": content_list,
                    },
                )
            )
        elif output.type == "function_call":
            output_list.append(
                construct_type_unchecked(
                    type_=ParsedResponseFunctionToolCall,
                    value={
                        **output.to_dict(),
                        "parsed_arguments": parse_function_tool_arguments(
                            input_tools=input_tools, function_call=output
                        ),
                    },
                )
            )
        elif (
            output.type == "computer_call"
            or output.type == "file_search_call"
            or output.type == "web_search_call"
            or output.type == "reasoning"
            or output.type == "compaction"
            or output.type == "mcp_call"
            or output.type == "mcp_approval_request"
            or output.type == "image_generation_call"
            or output.type == "code_interpreter_call"
            or output.type == "local_shell_call"
            or output.type == "shell_call"
            or output.type == "shell_call_output"
            or output.type == "apply_patch_call"
            or output.type == "apply_patch_call_output"
            or output.type == "mcp_list_tools"
            or output.type == "exec"
            or output.type == "custom_tool_call"
        ):
            # Known output item types with nothing to parse are passed through.
            output_list.append(output)
        elif TYPE_CHECKING:  # type: ignore
            # Static exhaustiveness check: type checking fails here if the output
            # item union gains a member that is not handled above.
            assert_never(output)
        else:
            # At runtime, unrecognized output item types are passed through as-is.
            output_list.append(output)

    return cast(
        ParsedResponse[TextFormatT],
        construct_type_unchecked(
            type_=cast(Any, ParsedResponse)[solved_t],
            value={
                **response.to_dict(),
                "output": output_list,
            },
        ),
    )
136
137
def parse_text(text: str, text_format: type[TextFormatT] | Omit) -> TextFormatT | None:
    """Parse an `output_text` string into an instance of `text_format`.

    Returns `None` when no `text_format` is given. `BaseModel` types are parsed
    with `model_parse_json`. Other dataclass-like types are validated through
    a `pydantic.TypeAdapter`, which is only supported on Pydantic v2.

    Raises:
        TypeError: if `text_format` is dataclass-like but Pydantic v1 is in
            use, or if it is neither a `BaseModel` nor dataclass-like.
    """
    if not is_given(text_format):
        return None

    if is_basemodel_type(text_format):
        parsed_model = model_parse_json(text_format, text)
        return cast(TextFormatT, parsed_model)

    # Anything beyond BaseModel must at least be dataclass-like to be parsed.
    if not is_dataclass_like_type(text_format):
        raise TypeError(f"Unable to automatically parse response format type {text_format}")

    if PYDANTIC_V1:
        raise TypeError(f"Non BaseModel types are only supported with Pydantic v2 - {text_format}")

    adapter = pydantic.TypeAdapter(text_format)
    return adapter.validate_json(text)
152
153
def get_input_tool_by_name(*, input_tools: Iterable[ToolParam], name: str) -> FunctionToolParam | None:
    """Return the first `function` tool in `input_tools` whose `name` matches.

    Tools of any other `type` are skipped, even if they have a matching
    `name`. Returns `None` if no tool matches.
    """
    return next(
        (
            candidate
            for candidate in input_tools
            if candidate["type"] == "function" and candidate.get("name") == name
        ),
        None,
    )
160
161
def parse_function_tool_arguments(
    *,
    input_tools: Iterable[ToolParam] | Omit | None,
    function_call: ParsedResponseFunctionToolCall | ResponseFunctionToolCall,
) -> object:
    """Parse a function call's JSON `arguments` using its matching tool definition.

    Returns:
        - `None` if `input_tools` is `None` or omitted, or if it has no
          `function` tool named `function_call.name`;
        - the result of `model_parse_json(tool.model, ...)` if the matching
          tool is a `ResponsesPydanticFunctionTool`;
        - the result of `json.loads` if the matching tool has a truthy `strict`;
        - `None` otherwise.
    """
    if input_tools is not None and is_given(input_tools):
        matched = get_input_tool_by_name(input_tools=input_tools, name=function_call.name)
    else:
        matched = None

    if not matched:
        return None

    # Widen to `object` so the isinstance check below narrows cleanly for the
    # type checker; Pydantic-backed tools carry the model to parse into (`.model`).
    as_object: object = matched
    if isinstance(as_object, ResponsesPydanticFunctionTool):
        return model_parse_json(as_object.model, function_call.arguments)

    # Plain function tools are only parsed when they are marked strict.
    return json.loads(function_call.arguments) if matched.get("strict") else None
181 return json.loads(function_call.arguments)