Commit ecd6e922
Changed files (3):
  src/openai/lib/_parsing/_completions.py
  src/openai/lib/_pydantic.py
  tests/lib/chat/test_completions.py
src/openai/lib/_parsing/_completions.py
@@ -9,9 +9,9 @@ import pydantic
from .._tools import PydanticFunctionTool
from ..._types import NOT_GIVEN, NotGiven
from ..._utils import is_dict, is_given
-from ..._compat import model_parse_json
+from ..._compat import PYDANTIC_V2, model_parse_json
from ..._models import construct_type_unchecked
-from .._pydantic import to_strict_json_schema
+from .._pydantic import is_basemodel_type, to_strict_json_schema, is_dataclass_like_type
from ...types.chat import (
ParsedChoice,
ChatCompletion,
@@ -216,14 +216,16 @@ def is_parseable_tool(input_tool: ChatCompletionToolParam) -> bool:
return cast(FunctionDefinition, input_fn).get("strict") or False
-def is_basemodel_type(typ: type) -> TypeGuard[type[pydantic.BaseModel]]:
- return issubclass(typ, pydantic.BaseModel)
-
-
def _parse_content(response_format: type[ResponseFormatT], content: str) -> ResponseFormatT:
if is_basemodel_type(response_format):
return cast(ResponseFormatT, model_parse_json(response_format, content))
+ if is_dataclass_like_type(response_format):
+ if not PYDANTIC_V2:
+ raise TypeError(f"Non BaseModel types are only supported with Pydantic v2 - {response_format}")
+
+ return pydantic.TypeAdapter(response_format).validate_json(content)
+
raise TypeError(f"Unable to automatically parse response format type {response_format}")
@@ -241,14 +243,22 @@ def type_to_response_format_param(
# can only be a `type`
response_format = cast(type, response_format)
- if not is_basemodel_type(response_format):
+ json_schema_type: type[pydantic.BaseModel] | pydantic.TypeAdapter[Any] | None = None
+
+ if is_basemodel_type(response_format):
+ name = response_format.__name__
+ json_schema_type = response_format
+ elif is_dataclass_like_type(response_format):
+ name = response_format.__name__
+ json_schema_type = pydantic.TypeAdapter(response_format)
+ else:
raise TypeError(f"Unsupported response_format type - {response_format}")
return {
"type": "json_schema",
"json_schema": {
- "schema": to_strict_json_schema(response_format),
- "name": response_format.__name__,
+ "schema": to_strict_json_schema(json_schema_type),
+ "name": name,
"strict": True,
},
}
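After this hunk, `type_to_response_format_param` accepts Pydantic dataclasses as well as `BaseModel` subclasses and builds the `json_schema` request parameter from either. A rough sketch of the result; the import is from a private module, so treat the path as an implementation detail:

```python
from pydantic.dataclasses import dataclass
from openai.lib._parsing._completions import type_to_response_format_param

@dataclass
class CalendarEvent:
    name: str
    date: str

param = type_to_response_format_param(CalendarEvent)
# Roughly:
# {"type": "json_schema",
#  "json_schema": {"name": "CalendarEvent", "strict": True, "schema": {...}}}
```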
src/openai/lib/_pydantic.py
@@ -1,17 +1,26 @@
from __future__ import annotations
-from typing import Any
+import inspect
+from typing import Any, TypeVar
from typing_extensions import TypeGuard
import pydantic
from .._types import NOT_GIVEN
from .._utils import is_dict as _is_dict, is_list
-from .._compat import model_json_schema
+from .._compat import PYDANTIC_V2, model_json_schema
+_T = TypeVar("_T")
+
+
+def to_strict_json_schema(model: type[pydantic.BaseModel] | pydantic.TypeAdapter[Any]) -> dict[str, Any]:
+ if inspect.isclass(model) and is_basemodel_type(model):
+ schema = model_json_schema(model)
+ elif PYDANTIC_V2 and isinstance(model, pydantic.TypeAdapter):
+ schema = model.json_schema()
+ else:
+ raise TypeError(f"Non BaseModel types are only supported with Pydantic v2 - {model}")
-def to_strict_json_schema(model: type[pydantic.BaseModel]) -> dict[str, Any]:
- schema = model_json_schema(model)
return _ensure_strict_json_schema(schema, path=(), root=schema)
@@ -117,6 +126,15 @@ def resolve_ref(*, root: dict[str, object], ref: str) -> object:
return resolved
+def is_basemodel_type(typ: type) -> TypeGuard[type[pydantic.BaseModel]]:
+ return issubclass(typ, pydantic.BaseModel)
+
+
+def is_dataclass_like_type(typ: type) -> bool:
+ """Returns True if the given type likely used `@pydantic.dataclass`"""
+ return hasattr(typ, "__pydantic_config__")
+
+
def is_dict(obj: object) -> TypeGuard[dict[str, object]]:
# just pretend that we know there are only `str` keys
# as that check is not worth the performance cost
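With these `_pydantic.py` changes, `to_strict_json_schema` dispatches on either a `BaseModel` subclass or a `pydantic.TypeAdapter` instance, and `is_dataclass_like_type` keys off the `__pydantic_config__` attribute that Pydantic v2 dataclasses carry. A small sketch under Pydantic v2 (the `Point` dataclass and the private import path are illustrative assumptions):

```python
import pydantic
from pydantic.dataclasses import dataclass
from openai.lib._pydantic import to_strict_json_schema, is_dataclass_like_type

@dataclass
class Point:
    x: int
    y: int

assert is_dataclass_like_type(Point)  # pydantic dataclasses expose __pydantic_config__
schema = to_strict_json_schema(pydantic.TypeAdapter(Point))  # TypeAdapter branch
# `schema` is the strict variant of the generated JSON schema
# (e.g. extra object keys are disallowed), ready for structured outputs.
```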
tests/lib/chat/test_completions.py
@@ -3,7 +3,7 @@ from __future__ import annotations
import os
import json
from enum import Enum
-from typing import Any, Callable, Optional
+from typing import Any, List, Callable, Optional
from typing_extensions import Literal, TypeVar
import httpx
@@ -317,6 +317,63 @@ def test_parse_pydantic_model_multiple_choices(
)
+@pytest.mark.respx(base_url=base_url)
+@pytest.mark.skipif(not PYDANTIC_V2, reason="dataclasses only supported in v2")
+def test_parse_pydantic_dataclass(client: OpenAI, respx_mock: MockRouter, monkeypatch: pytest.MonkeyPatch) -> None:
+ from pydantic.dataclasses import dataclass
+
+ @dataclass
+ class CalendarEvent:
+ name: str
+ date: str
+ participants: List[str]
+
+ completion = _make_snapshot_request(
+ lambda c: c.beta.chat.completions.parse(
+ model="gpt-4o-2024-08-06",
+ messages=[
+ {"role": "system", "content": "Extract the event information."},
+ {"role": "user", "content": "Alice and Bob are going to a science fair on Friday."},
+ ],
+ response_format=CalendarEvent,
+ ),
+ content_snapshot=snapshot(
+ '{"id": "chatcmpl-9wdGqXkJJARAz7rOrLH5u5FBwLjF3", "object": "chat.completion", "created": 1723761008, "model": "gpt-4o-2024-08-06", "choices": [{"index": 0, "message": {"role": "assistant", "content": "{\\"name\\":\\"Science Fair\\",\\"date\\":\\"Friday\\",\\"participants\\":[\\"Alice\\",\\"Bob\\"]}", "refusal": null}, "logprobs": null, "finish_reason": "stop"}], "usage": {"prompt_tokens": 32, "completion_tokens": 17, "total_tokens": 49}, "system_fingerprint": "fp_2a322c9ffc"}'
+ ),
+ mock_client=client,
+ respx_mock=respx_mock,
+ )
+
+ assert print_obj(completion, monkeypatch) == snapshot(
+ """\
+ParsedChatCompletion[CalendarEvent](
+ choices=[
+ ParsedChoice[CalendarEvent](
+ finish_reason='stop',
+ index=0,
+ logprobs=None,
+ message=ParsedChatCompletionMessage[CalendarEvent](
+ content='{"name":"Science Fair","date":"Friday","participants":["Alice","Bob"]}',
+ function_call=None,
+ parsed=CalendarEvent(name='Science Fair', date='Friday', participants=['Alice', 'Bob']),
+ refusal=None,
+ role='assistant',
+ tool_calls=[]
+ )
+ )
+ ],
+ created=1723761008,
+ id='chatcmpl-9wdGqXkJJARAz7rOrLH5u5FBwLjF3',
+ model='gpt-4o-2024-08-06',
+ object='chat.completion',
+ service_tier=None,
+ system_fingerprint='fp_2a322c9ffc',
+ usage=CompletionUsage(completion_tokens=17, prompt_tokens=32, total_tokens=49)
+)
+"""
+ )
+
+
@pytest.mark.respx(base_url=base_url)
def test_pydantic_tool_model_all_types(client: OpenAI, respx_mock: MockRouter, monkeypatch: pytest.MonkeyPatch) -> None:
completion = _make_snapshot_request(