main
1from __future__ import annotations
2
3import json
4import logging
5from typing import TYPE_CHECKING, Any, Iterable, cast
6from typing_extensions import TypeVar, TypeGuard, assert_never
7
8import pydantic
9
10from .._tools import PydanticFunctionTool
11from ..._types import Omit, omit
12from ..._utils import is_dict, is_given
13from ..._compat import PYDANTIC_V1, model_parse_json
14from ..._models import construct_type_unchecked
15from .._pydantic import is_basemodel_type, to_strict_json_schema, is_dataclass_like_type
16from ...types.chat import (
17 ParsedChoice,
18 ChatCompletion,
19 ParsedFunction,
20 ParsedChatCompletion,
21 ChatCompletionMessage,
22 ParsedFunctionToolCall,
23 ParsedChatCompletionMessage,
24 ChatCompletionToolUnionParam,
25 ChatCompletionFunctionToolParam,
26 completion_create_params,
27)
28from ..._exceptions import LengthFinishReasonError, ContentFilterFinishReasonError
29from ...types.shared_params import FunctionDefinition
30from ...types.chat.completion_create_params import ResponseFormat as ResponseFormatParam
31from ...types.chat.chat_completion_message_function_tool_call import Function
32
# Type variable for the caller's structured-output type. It defaults to `None`
# so that, when no response format is given, no parsing is performed.
ResponseFormatT = TypeVar("ResponseFormatT", default=None)

# Runtime placeholder returned by `solve_response_format_t` when no rich
# response format applies.
_default_response_format: None = None

log: logging.Logger = logging.getLogger("openai.lib.parsing")
41
42
def is_strict_chat_completion_tool_param(
    tool: ChatCompletionToolUnionParam,
) -> TypeGuard[ChatCompletionFunctionToolParam]:
    """Return True if `tool` is a `function` tool whose definition sets `strict=True`.

    Only the literal `True` counts as strict; a missing `strict` key or any
    other value is treated as non-strict.
    """
    return tool["type"] == "function" and tool["function"].get("strict") is True
53
54
def select_strict_chat_completion_tools(
    tools: Iterable[ChatCompletionToolUnionParam] | Omit = omit,
) -> Iterable[ChatCompletionFunctionToolParam] | Omit:
    """Filter `tools` down to the strict function tools.

    Returns `omit` when no tools were given; otherwise a new list holding, in
    their original order, only the tools accepted by
    `is_strict_chat_completion_tool_param`.
    """
    if is_given(tools):
        return list(filter(is_strict_chat_completion_tool_param, tools))
    return omit
63
64
def validate_input_tools(
    tools: Iterable[ChatCompletionToolUnionParam] | Omit = omit,
) -> Iterable[ChatCompletionFunctionToolParam] | Omit:
    """Ensure every given tool can be auto-parsed, and return the tools.

    Returns `omit` when no tools were given. Otherwise every tool must be a
    `function` tool whose definition has `strict` set to exactly `True`.

    The tools are collected into a list before validation. Validation has to
    iterate them, so a one-shot iterable such as a generator would otherwise
    come back exhausted. The returned list can always be iterated again; a
    list argument is returned as the same object.

    Raises:
        ValueError: if a tool is not a `function` tool, or is not strict.
    """
    if not is_given(tools):
        return omit

    # Materialize once so the validation pass below does not consume the
    # only iteration of a generator/iterator argument.
    tool_list = tools if isinstance(tools, list) else list(tools)

    for tool in tool_list:
        if tool["type"] != "function":
            raise ValueError(
                f"Currently only `function` tool types support auto-parsing; Received `{tool['type']}`",
            )

        strict = tool["function"].get("strict")
        if strict is not True:
            raise ValueError(
                f"`{tool['function']['name']}` is not strict. Only `strict` function tools can be auto-parsed"
            )

    return cast(Iterable[ChatCompletionFunctionToolParam], tool_list)
84
85
def parse_chat_completion(
    *,
    response_format: type[ResponseFormatT] | completion_create_params.ResponseFormat | Omit,
    input_tools: Iterable[ChatCompletionToolUnionParam] | Omit,
    chat_completion: ChatCompletion | ParsedChatCompletion[object],
) -> ParsedChatCompletion[ResponseFormatT]:
    """Build a `ParsedChatCompletion` from a raw chat completion.

    For every choice:
    - the message content is parsed via `maybe_parse_content`;
    - each `function` tool call gets `parsed_arguments` computed from the
      matching entry in `input_tools`;
    - `custom` tool calls are logged with a warning and dropped;
    - tool calls of any other type are kept unchanged.

    Raises:
        LengthFinishReasonError: if a choice finished with reason `length`.
        ContentFilterFinishReasonError: if a choice finished with reason
            `content_filter`.
    """
    tools_list: list[ChatCompletionToolUnionParam] = list(input_tools) if is_given(input_tools) else []

    # The parsed runtime type is the same for every choice, so resolve it and
    # the parametrized `ParsedChoice` model once instead of per choice.
    response_format_t = solve_response_format_t(response_format)
    choice_type = cast(Any, ParsedChoice)[response_format_t]

    choices: list[ParsedChoice[ResponseFormatT]] = []
    for choice in chat_completion.choices:
        if choice.finish_reason == "length":
            raise LengthFinishReasonError(completion=chat_completion)

        if choice.finish_reason == "content_filter":
            raise ContentFilterFinishReasonError()

        message = choice.message

        tool_calls: list[ParsedFunctionToolCall] = []
        if message.tool_calls:
            for tool_call in message.tool_calls:
                if tool_call.type == "function":
                    tool_call_dict = tool_call.to_dict()
                    tool_calls.append(
                        construct_type_unchecked(
                            value={
                                **tool_call_dict,
                                "function": {
                                    **cast(Any, tool_call_dict["function"]),
                                    "parsed_arguments": parse_function_tool_arguments(
                                        input_tools=tools_list, function=tool_call.function
                                    ),
                                },
                            },
                            type_=ParsedFunctionToolCall,
                        )
                    )
                elif tool_call.type == "custom":
                    # Custom tool calls cannot be auto-parsed here; warn and skip them.
                    log.warning(
                        "Custom tool calls are not callable. Ignoring tool call: %s - %s",
                        tool_call.id,
                        tool_call.custom.name,
                        stacklevel=2,
                    )
                elif TYPE_CHECKING:  # type: ignore[unreachable]
                    assert_never(tool_call)
                else:
                    # Unknown tool call types are passed through unparsed.
                    tool_calls.append(tool_call)

        choices.append(
            construct_type_unchecked(
                type_=choice_type,
                value={
                    **choice.to_dict(),
                    "message": {
                        **message.to_dict(),
                        "parsed": maybe_parse_content(
                            response_format=response_format,
                            message=message,
                        ),
                        "tool_calls": tool_calls if tool_calls else None,
                    },
                },
            )
        )

    return cast(
        ParsedChatCompletion[ResponseFormatT],
        construct_type_unchecked(
            type_=cast(Any, ParsedChatCompletion)[response_format_t],
            value={
                **chat_completion.to_dict(),
                "choices": choices,
            },
        ),
    )
166
167
def get_input_tool_by_name(
    *, input_tools: list[ChatCompletionToolUnionParam], name: str
) -> ChatCompletionFunctionToolParam | None:
    """Return the first `function` tool in `input_tools` whose function name is `name`.

    Tools of any other type are skipped. Returns `None` when nothing matches.
    """
    for candidate in input_tools:
        if candidate["type"] != "function":
            continue
        if candidate.get("function", {}).get("name") == name:
            return candidate
    return None
172
173
def parse_function_tool_arguments(
    *, input_tools: list[ChatCompletionToolUnionParam], function: Function | ParsedFunction
) -> object | None:
    """Parse a tool call's JSON `arguments` using the matching input tool definition.

    - No `function` tool in `input_tools` with the same name: returns `None`.
    - A `PydanticFunctionTool` definition: parses into its `model`.
    - A plain definition with a truthy `strict`: returns `json.loads(arguments)`.
    - Any other definition: returns `None`.
    """
    matched = get_input_tool_by_name(input_tools=input_tools, name=function.name)
    if not matched:
        return None

    definition = cast(object, matched.get("function"))
    if isinstance(definition, PydanticFunctionTool):
        return model_parse_json(definition.model, function.arguments)

    if cast(FunctionDefinition, definition).get("strict"):
        return json.loads(function.arguments)  # type: ignore[no-any-return]

    return None
191
192
def maybe_parse_content(
    *,
    response_format: type[ResponseFormatT] | ResponseFormatParam | Omit,
    message: ChatCompletionMessage | ParsedChatCompletionMessage[object],
) -> ResponseFormatT | None:
    """Parse `message.content` into `response_format` when applicable.

    Content is parsed only if `response_format` is a rich (non-dict) type, the
    message has non-empty content, and the message carries no refusal.
    Otherwise returns `None`.
    """
    if not has_rich_response_format(response_format):
        return None
    if not message.content or message.refusal:
        return None
    return _parse_content(response_format, message.content)
202
203
def solve_response_format_t(
    response_format: type[ResponseFormatT] | ResponseFormatParam | Omit,
) -> type[ResponseFormatT]:
    """Resolve the runtime type used to parametrize the parsed result models.

    A rich (non-dict) response format type is returned unchanged. When the
    format is omitted or is a raw dict-style param, nothing will be parsed and
    `_default_response_format` (i.e. `None`) is returned instead.
    """
    if not has_rich_response_format(response_format):
        return cast("type[ResponseFormatT]", _default_response_format)
    return response_format
216
217
def has_parseable_input(
    *,
    response_format: type | ResponseFormatParam | Omit,
    input_tools: Iterable[ChatCompletionToolUnionParam] | Omit = omit,
) -> bool:
    """Return True if anything in the request will be auto-parsed.

    This is the case when `response_format` is a rich type, or when at least
    one entry of `input_tools` satisfies `is_parseable_tool`.
    """
    if has_rich_response_format(response_format):
        return True
    # A falsy `input_tools` is treated as an empty list of tools.
    return any(is_parseable_tool(tool) for tool in input_tools or [])
231
232
def has_rich_response_format(
    response_format: type[ResponseFormatT] | ResponseFormatParam | Omit,
) -> TypeGuard[type[ResponseFormatT]]:
    """Return True when `response_format` is a type the response should be parsed into.

    An omitted value, or a raw dict-style `response_format` param, is not rich.
    """
    if is_given(response_format) and not is_response_format_param(response_format):
        return True
    return False
243
244
def is_response_format_param(response_format: object) -> TypeGuard[ResponseFormatParam]:
    """Return True when `response_format` is a raw dict-style `response_format` param.

    Such dict params are not auto-parsed; `type_to_response_format_param`
    returns them unchanged.
    """
    return is_dict(response_format)
247
248
def is_parseable_tool(input_tool: ChatCompletionToolUnionParam) -> bool:
    """Return True if calls to `input_tool` can have their arguments auto-parsed.

    Only `function` tools qualify:
    - a `PydanticFunctionTool` definition always qualifies;
    - a plain function definition qualifies when its `strict` flag is truthy.

    The result is always a real `bool`, matching the annotation.
    """
    if input_tool["type"] != "function":
        return False

    input_fn = cast(object, input_tool.get("function"))
    if isinstance(input_fn, PydanticFunctionTool):
        return True

    # `.get("strict")` may be missing (None) or some other value; coerce so the
    # `-> bool` contract holds instead of leaking the raw value.
    return bool(cast(FunctionDefinition, input_fn).get("strict"))
258
259
def _parse_content(response_format: type[ResponseFormatT], content: str) -> ResponseFormatT:
    """Validate the JSON string `content` into an instance of `response_format`.

    Pydantic `BaseModel` subclasses are parsed with `model_parse_json`. Other
    dataclass-like types go through `pydantic.TypeAdapter`, which is only
    supported with Pydantic v2.

    Raises:
        TypeError: under Pydantic v1 for a non-`BaseModel` type, or when the
            type is neither a `BaseModel` nor dataclass-like.
    """
    if is_basemodel_type(response_format):
        return cast(ResponseFormatT, model_parse_json(response_format, content))

    if not is_dataclass_like_type(response_format):
        raise TypeError(f"Unable to automatically parse response format type {response_format}")

    if PYDANTIC_V1:
        raise TypeError(f"Non BaseModel types are only supported with Pydantic v2 - {response_format}")

    adapter = pydantic.TypeAdapter(response_format)
    return adapter.validate_json(content)
271
272
def type_to_response_format_param(
    response_format: type | completion_create_params.ResponseFormat | Omit,
) -> ResponseFormatParam | Omit:
    """Convert a `response_format` argument into the request param sent to the API.

    - omitted: returns `omit`
    - a dict-style `response_format` param: returned unchanged
    - a Pydantic `BaseModel` subclass or a dataclass-like type: returns a strict
      `json_schema` response format named after the type's `__name__`

    Raises:
        TypeError: for any other kind of type.
    """
    if not is_given(response_format):
        return omit

    if is_response_format_param(response_format):
        return response_format

    # Type checkers don't narrow the negation of a `TypeGuard` (it isn't a safe
    # default), but at this point `response_format` can only be a `type`.
    response_format = cast(type, response_format)

    schema_source: type[pydantic.BaseModel] | pydantic.TypeAdapter[Any]
    if is_basemodel_type(response_format):
        schema_source = response_format
    elif is_dataclass_like_type(response_format):
        schema_source = pydantic.TypeAdapter(response_format)
    else:
        raise TypeError(f"Unsupported response_format type - {response_format}")

    return {
        "type": "json_schema",
        "json_schema": {
            "schema": to_strict_json_schema(schema_source),
            "name": response_format.__name__,
            "strict": True,
        },
    }