main
  1from __future__ import annotations
  2
  3import json
  4from typing import TYPE_CHECKING, Any, List, Iterable, cast
  5from typing_extensions import TypeVar, assert_never
  6
  7import pydantic
  8
  9from .._tools import ResponsesPydanticFunctionTool
 10from ..._types import Omit
 11from ..._utils import is_given
 12from ..._compat import PYDANTIC_V1, model_parse_json
 13from ..._models import construct_type_unchecked
 14from .._pydantic import is_basemodel_type, is_dataclass_like_type
 15from ._completions import solve_response_format_t, type_to_response_format_param
 16from ...types.responses import (
 17    Response,
 18    ToolParam,
 19    ParsedContent,
 20    ParsedResponse,
 21    FunctionToolParam,
 22    ParsedResponseOutputItem,
 23    ParsedResponseOutputText,
 24    ResponseFunctionToolCall,
 25    ParsedResponseOutputMessage,
 26    ResponseFormatTextConfigParam,
 27    ParsedResponseFunctionToolCall,
 28)
 29from ...types.chat.completion_create_params import ResponseFormat
 30
# Type variable for the structured type that response text is parsed into.
# It parameterizes `parse_response` / `parse_text` below.
TextFormatT = TypeVar(
    "TextFormatT",
    # if it isn't given then we don't do any parsing
    default=None,
)
 36
 37
 38def type_to_text_format_param(type_: type) -> ResponseFormatTextConfigParam:
 39    response_format_dict = type_to_response_format_param(type_)
 40    assert is_given(response_format_dict)
 41    response_format_dict = cast(ResponseFormat, response_format_dict)  # pyright: ignore[reportUnnecessaryCast]
 42    assert response_format_dict["type"] == "json_schema"
 43    assert "schema" in response_format_dict["json_schema"]
 44
 45    return {
 46        "type": "json_schema",
 47        "strict": True,
 48        "name": response_format_dict["json_schema"]["name"],
 49        "schema": response_format_dict["json_schema"]["schema"],
 50    }
 51
 52
 53def parse_response(
 54    *,
 55    text_format: type[TextFormatT] | Omit,
 56    input_tools: Iterable[ToolParam] | Omit | None,
 57    response: Response | ParsedResponse[object],
 58) -> ParsedResponse[TextFormatT]:
 59    solved_t = solve_response_format_t(text_format)
 60    output_list: List[ParsedResponseOutputItem[TextFormatT]] = []
 61
 62    for output in response.output:
 63        if output.type == "message":
 64            content_list: List[ParsedContent[TextFormatT]] = []
 65            for item in output.content:
 66                if item.type != "output_text":
 67                    content_list.append(item)
 68                    continue
 69
 70                content_list.append(
 71                    construct_type_unchecked(
 72                        type_=cast(Any, ParsedResponseOutputText)[solved_t],
 73                        value={
 74                            **item.to_dict(),
 75                            "parsed": parse_text(item.text, text_format=text_format),
 76                        },
 77                    )
 78                )
 79
 80            output_list.append(
 81                construct_type_unchecked(
 82                    type_=cast(Any, ParsedResponseOutputMessage)[solved_t],
 83                    value={
 84                        **output.to_dict(),
 85                        "content": content_list,
 86                    },
 87                )
 88            )
 89        elif output.type == "function_call":
 90            output_list.append(
 91                construct_type_unchecked(
 92                    type_=ParsedResponseFunctionToolCall,
 93                    value={
 94                        **output.to_dict(),
 95                        "parsed_arguments": parse_function_tool_arguments(
 96                            input_tools=input_tools, function_call=output
 97                        ),
 98                    },
 99                )
100            )
101        elif (
102            output.type == "computer_call"
103            or output.type == "file_search_call"
104            or output.type == "web_search_call"
105            or output.type == "reasoning"
106            or output.type == "compaction"
107            or output.type == "mcp_call"
108            or output.type == "mcp_approval_request"
109            or output.type == "image_generation_call"
110            or output.type == "code_interpreter_call"
111            or output.type == "local_shell_call"
112            or output.type == "shell_call"
113            or output.type == "shell_call_output"
114            or output.type == "apply_patch_call"
115            or output.type == "apply_patch_call_output"
116            or output.type == "mcp_list_tools"
117            or output.type == "exec"
118            or output.type == "custom_tool_call"
119        ):
120            output_list.append(output)
121        elif TYPE_CHECKING:  # type: ignore
122            assert_never(output)
123        else:
124            output_list.append(output)
125
126    return cast(
127        ParsedResponse[TextFormatT],
128        construct_type_unchecked(
129            type_=cast(Any, ParsedResponse)[solved_t],
130            value={
131                **response.to_dict(),
132                "output": output_list,
133            },
134        ),
135    )
136
137
def parse_text(text: str, text_format: type[TextFormatT] | Omit) -> TextFormatT | None:
    """Parse ``text`` as JSON into ``text_format``.

    Returns:
        ``None`` when ``text_format`` is not given; otherwise the value
        produced by ``model_parse_json`` (BaseModel types) or by a
        ``pydantic.TypeAdapter`` (dataclass-like types).

    Raises:
        TypeError: if ``text_format`` is neither a BaseModel type nor a
            dataclass-like type, or if it is dataclass-like while running
            under Pydantic v1.
    """
    if not is_given(text_format):
        return None

    if is_basemodel_type(text_format):
        parsed_model = model_parse_json(text_format, text)
        return cast(TextFormatT, parsed_model)

    if not is_dataclass_like_type(text_format):
        raise TypeError(f"Unable to automatically parse response format type {text_format}")

    # The TypeAdapter path below is only taken on Pydantic v2.
    if PYDANTIC_V1:
        raise TypeError(f"Non BaseModel types are only supported with Pydantic v2 - {text_format}")

    adapter = pydantic.TypeAdapter(text_format)
    return adapter.validate_json(text)
151    raise TypeError(f"Unable to automatically parse response format type {text_format}")
152
153
def get_input_tool_by_name(*, input_tools: Iterable[ToolParam], name: str) -> FunctionToolParam | None:
    """Return the first ``"function"`` tool in ``input_tools`` named ``name``.

    Tools are examined in iteration order. ``None`` is returned when no tool
    has ``type == "function"`` together with a matching ``name``.
    """
    matches = (
        candidate
        for candidate in input_tools
        if candidate["type"] == "function" and candidate.get("name") == name
    )
    return next(matches, None)
160
161
def parse_function_tool_arguments(
    *,
    input_tools: Iterable[ToolParam] | Omit | None,
    function_call: ParsedResponseFunctionToolCall | ResponseFunctionToolCall,
) -> object:
    """Parse the JSON arguments of ``function_call`` using its tool definition.

    Returns:
        ``None`` when ``input_tools`` is ``None`` or not given, when no function
        tool named ``function_call.name`` is found, or when the matched tool is
        not a ``ResponsesPydanticFunctionTool`` and has no truthy ``"strict"``
        entry. Otherwise the result of ``model_parse_json`` against the tool's
        model (``ResponsesPydanticFunctionTool``), or of ``json.loads`` (other
        strict tools).
    """
    if input_tools is None or not is_given(input_tools):
        return None

    matched_tool = get_input_tool_by_name(input_tools=input_tools, name=function_call.name)
    if not matched_tool:
        return None

    # Widen to `object` so the isinstance check below is accepted by type checkers.
    tool_obj = cast(object, matched_tool)
    if isinstance(tool_obj, ResponsesPydanticFunctionTool):
        return model_parse_json(tool_obj.model, function_call.arguments)

    return json.loads(function_call.arguments) if matched_tool.get("strict") else None
181    return json.loads(function_call.arguments)