main
  1from __future__ import annotations
  2
  3import json
  4import logging
  5from typing import TYPE_CHECKING, Any, Iterable, cast
  6from typing_extensions import TypeVar, TypeGuard, assert_never
  7
  8import pydantic
  9
 10from .._tools import PydanticFunctionTool
 11from ..._types import Omit, omit
 12from ..._utils import is_dict, is_given
 13from ..._compat import PYDANTIC_V1, model_parse_json
 14from ..._models import construct_type_unchecked
 15from .._pydantic import is_basemodel_type, to_strict_json_schema, is_dataclass_like_type
 16from ...types.chat import (
 17    ParsedChoice,
 18    ChatCompletion,
 19    ParsedFunction,
 20    ParsedChatCompletion,
 21    ChatCompletionMessage,
 22    ParsedFunctionToolCall,
 23    ParsedChatCompletionMessage,
 24    ChatCompletionToolUnionParam,
 25    ChatCompletionFunctionToolParam,
 26    completion_create_params,
 27)
 28from ..._exceptions import LengthFinishReasonError, ContentFilterFinishReasonError
 29from ...types.shared_params import FunctionDefinition
 30from ...types.chat.completion_create_params import ResponseFormat as ResponseFormatParam
 31from ...types.chat.chat_completion_message_function_tool_call import Function
 32
# Type parameter for the parsed `response_format` type carried by the
# `Parsed*` chat completion models (e.g. `ParsedChatCompletion[ResponseFormatT]`).
ResponseFormatT = TypeVar(
    "ResponseFormatT",
    # if it isn't given then we don't do any parsing
    default=None,
)
# Runtime stand-in returned by `solve_response_format_t` when there is no
# rich (type-based) response format to parse into.
_default_response_format: None = None

# Logger used by this module (e.g. to warn about ignored custom tool calls).
log: logging.Logger = logging.getLogger("openai.lib.parsing")
 41
 42
 43def is_strict_chat_completion_tool_param(
 44    tool: ChatCompletionToolUnionParam,
 45) -> TypeGuard[ChatCompletionFunctionToolParam]:
 46    """Check if the given tool is a strict ChatCompletionFunctionToolParam."""
 47    if not tool["type"] == "function":
 48        return False
 49    if tool["function"].get("strict") is not True:
 50        return False
 51
 52    return True
 53
 54
 55def select_strict_chat_completion_tools(
 56    tools: Iterable[ChatCompletionToolUnionParam] | Omit = omit,
 57) -> Iterable[ChatCompletionFunctionToolParam] | Omit:
 58    """Select only the strict ChatCompletionFunctionToolParams from the given tools."""
 59    if not is_given(tools):
 60        return omit
 61
 62    return [t for t in tools if is_strict_chat_completion_tool_param(t)]
 63
 64
 65def validate_input_tools(
 66    tools: Iterable[ChatCompletionToolUnionParam] | Omit = omit,
 67) -> Iterable[ChatCompletionFunctionToolParam] | Omit:
 68    if not is_given(tools):
 69        return omit
 70
 71    for tool in tools:
 72        if tool["type"] != "function":
 73            raise ValueError(
 74                f"Currently only `function` tool types support auto-parsing; Received `{tool['type']}`",
 75            )
 76
 77        strict = tool["function"].get("strict")
 78        if strict is not True:
 79            raise ValueError(
 80                f"`{tool['function']['name']}` is not strict. Only `strict` function tools can be auto-parsed"
 81            )
 82
 83    return cast(Iterable[ChatCompletionFunctionToolParam], tools)
 84
 85
 86def parse_chat_completion(
 87    *,
 88    response_format: type[ResponseFormatT] | completion_create_params.ResponseFormat | Omit,
 89    input_tools: Iterable[ChatCompletionToolUnionParam] | Omit,
 90    chat_completion: ChatCompletion | ParsedChatCompletion[object],
 91) -> ParsedChatCompletion[ResponseFormatT]:
 92    if is_given(input_tools):
 93        input_tools = [t for t in input_tools]
 94    else:
 95        input_tools = []
 96
 97    choices: list[ParsedChoice[ResponseFormatT]] = []
 98    for choice in chat_completion.choices:
 99        if choice.finish_reason == "length":
100            raise LengthFinishReasonError(completion=chat_completion)
101
102        if choice.finish_reason == "content_filter":
103            raise ContentFilterFinishReasonError()
104
105        message = choice.message
106
107        tool_calls: list[ParsedFunctionToolCall] = []
108        if message.tool_calls:
109            for tool_call in message.tool_calls:
110                if tool_call.type == "function":
111                    tool_call_dict = tool_call.to_dict()
112                    tool_calls.append(
113                        construct_type_unchecked(
114                            value={
115                                **tool_call_dict,
116                                "function": {
117                                    **cast(Any, tool_call_dict["function"]),
118                                    "parsed_arguments": parse_function_tool_arguments(
119                                        input_tools=input_tools, function=tool_call.function
120                                    ),
121                                },
122                            },
123                            type_=ParsedFunctionToolCall,
124                        )
125                    )
126                elif tool_call.type == "custom":
127                    # warn user that custom tool calls are not callable here
128                    log.warning(
129                        "Custom tool calls are not callable. Ignoring tool call: %s - %s",
130                        tool_call.id,
131                        tool_call.custom.name,
132                        stacklevel=2,
133                    )
134                elif TYPE_CHECKING:  # type: ignore[unreachable]
135                    assert_never(tool_call)
136                else:
137                    tool_calls.append(tool_call)
138
139        choices.append(
140            construct_type_unchecked(
141                type_=cast(Any, ParsedChoice)[solve_response_format_t(response_format)],
142                value={
143                    **choice.to_dict(),
144                    "message": {
145                        **message.to_dict(),
146                        "parsed": maybe_parse_content(
147                            response_format=response_format,
148                            message=message,
149                        ),
150                        "tool_calls": tool_calls if tool_calls else None,
151                    },
152                },
153            )
154        )
155
156    return cast(
157        ParsedChatCompletion[ResponseFormatT],
158        construct_type_unchecked(
159            type_=cast(Any, ParsedChatCompletion)[solve_response_format_t(response_format)],
160            value={
161                **chat_completion.to_dict(),
162                "choices": choices,
163            },
164        ),
165    )
166
167
def get_input_tool_by_name(
    *, input_tools: list[ChatCompletionToolUnionParam], name: str
) -> ChatCompletionFunctionToolParam | None:
    """Return the first `function` tool whose function name equals `name`.

    Non-function tools are skipped, and so are function tools without a
    `function` entry. Returns `None` when nothing matches.
    """
    for candidate in input_tools:
        if candidate["type"] != "function":
            continue
        if candidate.get("function", {}).get("name") == name:
            return candidate
    return None
172
173
def parse_function_tool_arguments(
    *, input_tools: list[ChatCompletionToolUnionParam], function: Function | ParsedFunction
) -> object | None:
    """Parse a tool call's JSON `arguments` using the matching input tool.

    - No input tool with this function name: `None`.
    - The tool's `function` entry is a `PydanticFunctionTool`: the arguments
      are parsed into its `model` via `model_parse_json`.
    - A plain function definition with a truthy `strict` flag: the arguments
      are parsed with `json.loads`.
    - Anything else: `None`.
    """
    tool = get_input_tool_by_name(input_tools=input_tools, name=function.name)
    if not tool:
        return None

    definition = cast(object, tool.get("function"))
    if isinstance(definition, PydanticFunctionTool):
        return model_parse_json(definition.model, function.arguments)

    # Plain definitions are only auto-parsed when they opted into strict mode.
    if cast(FunctionDefinition, definition).get("strict"):
        return json.loads(function.arguments)  # type: ignore[no-any-return]
    return None
191
192
def maybe_parse_content(
    *,
    response_format: type[ResponseFormatT] | ResponseFormatParam | Omit,
    message: ChatCompletionMessage | ParsedChatCompletionMessage[object],
) -> ResponseFormatT | None:
    """Parse `message.content` into `response_format` when that is possible.

    Returns `None` if:
    - `response_format` is not a rich (type-based) format;
    - the message has no content; or
    - `message.refusal` is set.
    """
    if not has_rich_response_format(response_format):
        return None
    if not message.content or message.refusal:
        return None
    return _parse_content(response_format, message.content)
202
203
def solve_response_format_t(
    response_format: type[ResponseFormatT] | ResponseFormatParam | Omit,
) -> type[ResponseFormatT]:
    """Pick the runtime type argument for the generic `Parsed*` models.

    A rich (type-based) response format is returned as-is. Anything else
    (not given, or a raw response-format dict that won't be auto-parsed)
    falls back to `None`.
    """
    if not has_rich_response_format(response_format):
        return cast("type[ResponseFormatT]", _default_response_format)
    return response_format
216
217
def has_parseable_input(
    *,
    response_format: type | ResponseFormatParam | Omit,
    input_tools: Iterable[ChatCompletionToolUnionParam] | Omit = omit,
) -> bool:
    """Report whether the request has anything that will be auto-parsed.

    That means either a rich (type-based) `response_format`, or at least one
    tool accepted by `is_parseable_tool`. `input_tools` is only iterated when
    it is truthy. Presumably `omit` is falsy — confirm against `Omit`.
    """
    if has_rich_response_format(response_format):
        return True
    return any(is_parseable_tool(input_tool) for input_tool in (input_tools or []))
231
232
def has_rich_response_format(
    response_format: type[ResponseFormatT] | ResponseFormatParam | Omit,
) -> TypeGuard[type[ResponseFormatT]]:
    """Tell whether `response_format` is a type the response can be parsed into.

    Returns False when it was not given at all, or when it is a raw
    response-format dict param. Returns True otherwise.
    """
    if not is_given(response_format) or is_response_format_param(response_format):
        return False
    return True
243
244
def is_response_format_param(response_format: object) -> TypeGuard[ResponseFormatParam]:
    """Tell whether `response_format` is a ready-made response-format param.

    Any value accepted by `is_dict` counts as such a param. Such values are
    passed through unchanged rather than being converted from a type.
    """
    looks_like_param = is_dict(response_format)
    return looks_like_param
247
248
def is_parseable_tool(input_tool: ChatCompletionToolUnionParam) -> bool:
    """Return whether calls to `input_tool` will have their arguments auto-parsed.

    True for:
    - `function` tools whose `function` entry is a `PydanticFunctionTool`;
    - `function` tools whose plain function definition has a truthy `strict` flag.

    Any other tool type (e.g. `custom`) is not parseable.
    """
    if input_tool["type"] != "function":
        return False

    input_fn = cast(object, input_tool.get("function"))
    if isinstance(input_fn, PydanticFunctionTool):
        return True

    # Coerce to `bool` so the declared return type holds. `strict` may be
    # missing (`None`) or hold some other truthy/falsy value.
    return bool(cast(FunctionDefinition, input_fn).get("strict"))
258
259
def _parse_content(response_format: type[ResponseFormatT], content: str) -> ResponseFormatT:
    """Validate the JSON string `content` into an instance of `response_format`.

    - Pydantic `BaseModel` subclasses go through `model_parse_json`.
    - Other dataclass-like types go through `pydantic.TypeAdapter`, and this
      path raises under Pydantic v1.

    Raises:
        TypeError: If `response_format` is of an unsupported kind, or is a
            non-BaseModel type while running Pydantic v1.
    """
    if is_basemodel_type(response_format):
        parsed_model = model_parse_json(response_format, content)
        return cast(ResponseFormatT, parsed_model)

    if not is_dataclass_like_type(response_format):
        raise TypeError(f"Unable to automatically parse response format type {response_format}")

    if PYDANTIC_V1:
        raise TypeError(f"Non BaseModel types are only supported with Pydantic v2 - {response_format}")

    adapter = pydantic.TypeAdapter(response_format)
    return adapter.validate_json(content)
271
272
def type_to_response_format_param(
    response_format: type | completion_create_params.ResponseFormat | Omit,
) -> ResponseFormatParam | Omit:
    """Build the `response_format` request param from a user-supplied value.

    - Not given: `omit`.
    - Already a response-format dict: returned unchanged.
    - A Pydantic `BaseModel` subclass or a dataclass-like type: a strict
      `json_schema` response format. It is named after the type, and its
      schema comes from `to_strict_json_schema`.

    Raises:
        TypeError: If `response_format` is any other kind of value.
    """
    if not is_given(response_format):
        return omit
    if is_response_format_param(response_format):
        return response_format

    # A negated TypeGuard doesn't narrow for type checkers, but by this point
    # only a `type` can remain.
    response_format = cast(type, response_format)

    schema_source: type[pydantic.BaseModel] | pydantic.TypeAdapter[Any]
    if is_basemodel_type(response_format):
        schema_name = response_format.__name__
        schema_source = response_format
    elif is_dataclass_like_type(response_format):
        schema_name = response_format.__name__
        # Non-BaseModel types get their JSON schema through a `TypeAdapter`.
        schema_source = pydantic.TypeAdapter(response_format)
    else:
        raise TypeError(f"Unsupported response_format type - {response_format}")

    return {
        "type": "json_schema",
        "json_schema": {
            "schema": to_strict_json_schema(schema_source),
            "name": schema_name,
            "strict": True,
        },
    }
305    }