Commit ecd6e922

Robert Craigie <robert@craigie.dev>
2024-08-20 23:09:41
feat(parsing): add support for pydantic dataclasses (#1655)
1 parent e8c28f2
Changed files (3)
src
openai
tests
src/openai/lib/_parsing/_completions.py
@@ -9,9 +9,9 @@ import pydantic
 from .._tools import PydanticFunctionTool
 from ..._types import NOT_GIVEN, NotGiven
 from ..._utils import is_dict, is_given
-from ..._compat import model_parse_json
+from ..._compat import PYDANTIC_V2, model_parse_json
 from ..._models import construct_type_unchecked
-from .._pydantic import to_strict_json_schema
+from .._pydantic import is_basemodel_type, to_strict_json_schema, is_dataclass_like_type
 from ...types.chat import (
     ParsedChoice,
     ChatCompletion,
@@ -216,14 +216,16 @@ def is_parseable_tool(input_tool: ChatCompletionToolParam) -> bool:
     return cast(FunctionDefinition, input_fn).get("strict") or False
 
 
-def is_basemodel_type(typ: type) -> TypeGuard[type[pydantic.BaseModel]]:
-    return issubclass(typ, pydantic.BaseModel)
-
-
 def _parse_content(response_format: type[ResponseFormatT], content: str) -> ResponseFormatT:
     if is_basemodel_type(response_format):
         return cast(ResponseFormatT, model_parse_json(response_format, content))
 
+    if is_dataclass_like_type(response_format):
+        if not PYDANTIC_V2:
+            raise TypeError(f"Non BaseModel types are only supported with Pydantic v2 - {response_format}")
+
+        return pydantic.TypeAdapter(response_format).validate_json(content)
+
     raise TypeError(f"Unable to automatically parse response format type {response_format}")
 
 
@@ -241,14 +243,22 @@ def type_to_response_format_param(
     # can only be a `type`
     response_format = cast(type, response_format)
 
-    if not is_basemodel_type(response_format):
+    json_schema_type: type[pydantic.BaseModel] | pydantic.TypeAdapter[Any] | None = None
+
+    if is_basemodel_type(response_format):
+        name = response_format.__name__
+        json_schema_type = response_format
+    elif is_dataclass_like_type(response_format):
+        name = response_format.__name__
+        json_schema_type = pydantic.TypeAdapter(response_format)
+    else:
         raise TypeError(f"Unsupported response_format type - {response_format}")
 
     return {
         "type": "json_schema",
         "json_schema": {
-            "schema": to_strict_json_schema(response_format),
-            "name": response_format.__name__,
+            "schema": to_strict_json_schema(json_schema_type),
+            "name": name,
             "strict": True,
         },
     }
src/openai/lib/_pydantic.py
@@ -1,17 +1,26 @@
 from __future__ import annotations
 
-from typing import Any
+import inspect
+from typing import Any, TypeVar
 from typing_extensions import TypeGuard
 
 import pydantic
 
 from .._types import NOT_GIVEN
 from .._utils import is_dict as _is_dict, is_list
-from .._compat import model_json_schema
+from .._compat import PYDANTIC_V2, model_json_schema
 
+_T = TypeVar("_T")
+
+
+def to_strict_json_schema(model: type[pydantic.BaseModel] | pydantic.TypeAdapter[Any]) -> dict[str, Any]:
+    if inspect.isclass(model) and is_basemodel_type(model):
+        schema = model_json_schema(model)
+    elif PYDANTIC_V2 and isinstance(model, pydantic.TypeAdapter):
+        schema = model.json_schema()
+    else:
+        raise TypeError(f"Non BaseModel types are only supported with Pydantic v2 - {model}")
 
-def to_strict_json_schema(model: type[pydantic.BaseModel]) -> dict[str, Any]:
-    schema = model_json_schema(model)
     return _ensure_strict_json_schema(schema, path=(), root=schema)
 
 
@@ -117,6 +126,15 @@ def resolve_ref(*, root: dict[str, object], ref: str) -> object:
     return resolved
 
 
+def is_basemodel_type(typ: type) -> TypeGuard[type[pydantic.BaseModel]]:
+    return issubclass(typ, pydantic.BaseModel)
+
+
+def is_dataclass_like_type(typ: type) -> bool:
+    """Returns True if the given type likely used `@pydantic.dataclass`"""
+    return hasattr(typ, "__pydantic_config__")
+
+
 def is_dict(obj: object) -> TypeGuard[dict[str, object]]:
     # just pretend that we know there are only `str` keys
     # as that check is not worth the performance cost
tests/lib/chat/test_completions.py
@@ -3,7 +3,7 @@ from __future__ import annotations
 import os
 import json
 from enum import Enum
-from typing import Any, Callable, Optional
+from typing import Any, List, Callable, Optional
 from typing_extensions import Literal, TypeVar
 
 import httpx
@@ -317,6 +317,63 @@ def test_parse_pydantic_model_multiple_choices(
     )
 
 
+@pytest.mark.respx(base_url=base_url)
+@pytest.mark.skipif(not PYDANTIC_V2, reason="dataclasses only supported in v2")
+def test_parse_pydantic_dataclass(client: OpenAI, respx_mock: MockRouter, monkeypatch: pytest.MonkeyPatch) -> None:
+    from pydantic.dataclasses import dataclass
+
+    @dataclass
+    class CalendarEvent:
+        name: str
+        date: str
+        participants: List[str]
+
+    completion = _make_snapshot_request(
+        lambda c: c.beta.chat.completions.parse(
+            model="gpt-4o-2024-08-06",
+            messages=[
+                {"role": "system", "content": "Extract the event information."},
+                {"role": "user", "content": "Alice and Bob are going to a science fair on Friday."},
+            ],
+            response_format=CalendarEvent,
+        ),
+        content_snapshot=snapshot(
+            '{"id": "chatcmpl-9wdGqXkJJARAz7rOrLH5u5FBwLjF3", "object": "chat.completion", "created": 1723761008, "model": "gpt-4o-2024-08-06", "choices": [{"index": 0, "message": {"role": "assistant", "content": "{\\"name\\":\\"Science Fair\\",\\"date\\":\\"Friday\\",\\"participants\\":[\\"Alice\\",\\"Bob\\"]}", "refusal": null}, "logprobs": null, "finish_reason": "stop"}], "usage": {"prompt_tokens": 32, "completion_tokens": 17, "total_tokens": 49}, "system_fingerprint": "fp_2a322c9ffc"}'
+        ),
+        mock_client=client,
+        respx_mock=respx_mock,
+    )
+
+    assert print_obj(completion, monkeypatch) == snapshot(
+        """\
+ParsedChatCompletion[CalendarEvent](
+    choices=[
+        ParsedChoice[CalendarEvent](
+            finish_reason='stop',
+            index=0,
+            logprobs=None,
+            message=ParsedChatCompletionMessage[CalendarEvent](
+                content='{"name":"Science Fair","date":"Friday","participants":["Alice","Bob"]}',
+                function_call=None,
+                parsed=CalendarEvent(name='Science Fair', date='Friday', participants=['Alice', 'Bob']),
+                refusal=None,
+                role='assistant',
+                tool_calls=[]
+            )
+        )
+    ],
+    created=1723761008,
+    id='chatcmpl-9wdGqXkJJARAz7rOrLH5u5FBwLjF3',
+    model='gpt-4o-2024-08-06',
+    object='chat.completion',
+    service_tier=None,
+    system_fingerprint='fp_2a322c9ffc',
+    usage=CompletionUsage(completion_tokens=17, prompt_tokens=32, total_tokens=49)
+)
+"""
+    )
+
+
 @pytest.mark.respx(base_url=base_url)
 def test_pydantic_tool_model_all_types(client: OpenAI, respx_mock: MockRouter, monkeypatch: pytest.MonkeyPatch) -> None:
     completion = _make_snapshot_request(