Commit 5f52e473

Robert Craigie <robert@craigie.dev>
2024-08-07 17:13:08
fix(client): raise helpful error message for response_format misuse
1 parent 631a2a7
Changed files (2)
src
openai
resources
tests
api_resources
src/openai/resources/chat/completions.py
@@ -2,10 +2,12 @@
 
 from __future__ import annotations
 
+import inspect
 from typing import Dict, List, Union, Iterable, Optional, overload
 from typing_extensions import Literal
 
 import httpx
+import pydantic
 
 from ... import _legacy_response
 from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven
@@ -647,6 +649,7 @@ class Completions(SyncAPIResource):
         extra_body: Body | None = None,
         timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
     ) -> ChatCompletion | Stream[ChatCompletionChunk]:
+        validate_response_format(response_format)
         return self._post(
             "/chat/completions",
             body=maybe_transform(
@@ -1302,6 +1305,7 @@ class AsyncCompletions(AsyncAPIResource):
         extra_body: Body | None = None,
         timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
     ) -> ChatCompletion | AsyncStream[ChatCompletionChunk]:
+        validate_response_format(response_format)
         return await self._post(
             "/chat/completions",
             body=await async_maybe_transform(
@@ -1375,3 +1379,10 @@ class AsyncCompletionsWithStreamingResponse:
         self.create = async_to_streamed_response_wrapper(
             completions.create,
         )
+
+
+def validate_response_format(response_format: object) -> None:
+    if inspect.isclass(response_format) and issubclass(response_format, pydantic.BaseModel):
+        raise TypeError(
+            "You tried to pass a `BaseModel` class to `chat.completions.create()`; You must use `beta.chat.completions.parse()` instead"
+        )
tests/api_resources/chat/test_completions.py
@@ -6,6 +6,7 @@ import os
 from typing import Any, cast
 
 import pytest
+import pydantic
 
 from openai import OpenAI, AsyncOpenAI
 from tests.utils import assert_matches_type
@@ -257,6 +258,23 @@ class TestCompletions:
 
         assert cast(Any, response.is_closed) is True
 
+    @parametrize
+    def test_method_create_disallows_pydantic(self, client: OpenAI) -> None:
+        class MyModel(pydantic.BaseModel):
+            a: str
+
+        with pytest.raises(TypeError, match=r"You tried to pass a `BaseModel` class"):
+            client.chat.completions.create(
+                messages=[
+                    {
+                        "content": "string",
+                        "role": "system",
+                    }
+                ],
+                model="gpt-4o",
+                response_format=cast(Any, MyModel),
+            )
+
 
 class TestAsyncCompletions:
     parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"])
@@ -498,3 +516,20 @@ class TestAsyncCompletions:
             await stream.close()
 
         assert cast(Any, response.is_closed) is True
+
+    @parametrize
+    async def test_method_create_disallows_pydantic(self, async_client: AsyncOpenAI) -> None:
+        class MyModel(pydantic.BaseModel):
+            a: str
+
+        with pytest.raises(TypeError, match=r"You tried to pass a `BaseModel` class"):
+            await async_client.chat.completions.create(
+                messages=[
+                    {
+                        "content": "string",
+                        "role": "system",
+                    }
+                ],
+                model="gpt-4o",
+                response_format=cast(Any, MyModel),
+            )