Commit 152a8bda

Robert Craigie <robert@craigie.dev>
2024-09-27 06:32:48
feat(structured outputs): add support for accessing raw responses (#1748)
1 parent e1aeeb0
Changed files (2)
src/openai/resources/beta/chat/completions.py
tests/lib/chat/test_completions.py
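
Before the per-file diffs, a sketch of what this commit enables. The model, message, and `Location` schema mirror the new tests below; `.headers` and the `x-request-id` header are assumed accessors on the SDK's raw-response wrapper, not something this diff adds:

from typing import Literal

from pydantic import BaseModel
from openai import OpenAI

class Location(BaseModel):
    city: str
    temperature: float
    units: Literal["c", "f"]

client = OpenAI()

# `.with_raw_response` wraps `parse()` so the HTTP response comes back
# alongside the parsed content instead of being discarded.
raw = client.beta.chat.completions.with_raw_response.parse(
    model="gpt-4o-2024-08-06",
    messages=[{"role": "user", "content": "What's the weather like in SF?"}],
    response_format=Location,
)
print(raw.headers.get("x-request-id"))       # raw HTTP metadata (assumed header)
completion = raw.parse()                     # ParsedChatCompletion[Location]
print(completion.choices[0].message.parsed)  # Location(city=..., ...)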
src/openai/resources/beta/chat/completions.py
@@ -2,16 +2,21 @@
 
 from __future__ import annotations
 
-from typing import Dict, List, Union, Iterable, Optional
+from typing import Dict, List, Type, Union, Iterable, Optional, cast
 from functools import partial
 from typing_extensions import Literal
 
 import httpx
 
+from .... import _legacy_response
 from ...._types import NOT_GIVEN, Body, Query, Headers, NotGiven
+from ...._utils import maybe_transform, async_maybe_transform
+from ...._compat import cached_property
 from ...._resource import SyncAPIResource, AsyncAPIResource
+from ...._response import to_streamed_response_wrapper, async_to_streamed_response_wrapper
 from ...._streaming import Stream
 from ....types.chat import completion_create_params
+from ...._base_client import make_request_options
 from ....lib._parsing import (
     ResponseFormatT,
     validate_input_tools as _validate_input_tools,
@@ -20,6 +25,7 @@ from ....lib._parsing import (
 )
 from ....types.chat_model import ChatModel
 from ....lib.streaming.chat import ChatCompletionStreamManager, AsyncChatCompletionStreamManager
+from ....types.chat.chat_completion import ChatCompletion
 from ....types.chat.chat_completion_chunk import ChatCompletionChunk
 from ....types.chat.parsed_chat_completion import ParsedChatCompletion
 from ....types.chat.chat_completion_tool_param import ChatCompletionToolParam
@@ -31,6 +37,25 @@ __all__ = ["Completions", "AsyncCompletions"]
 
 
 class Completions(SyncAPIResource):
+    @cached_property
+    def with_raw_response(self) -> CompletionsWithRawResponse:
+        """
+        This property can be used as a prefix for any HTTP method call to return
+        the raw response object instead of the parsed content.
+
+        For more information, see https://www.github.com/openai/openai-python#accessing-raw-response-data-eg-headers
+        """
+        return CompletionsWithRawResponse(self)
+
+    @cached_property
+    def with_streaming_response(self) -> CompletionsWithStreamingResponse:
+        """
+        An alternative to `.with_raw_response` that doesn't eagerly read the response body.
+
+        For more information, see https://www.github.com/openai/openai-python#with_streaming_response
+        """
+        return CompletionsWithStreamingResponse(self)
+
     def parse(
         self,
         *,
@@ -113,39 +138,55 @@ class Completions(SyncAPIResource):
             **(extra_headers or {}),
         }
 
-        raw_completion = self._client.chat.completions.create(
-            messages=messages,
-            model=model,
-            response_format=_type_to_response_format(response_format),
-            frequency_penalty=frequency_penalty,
-            function_call=function_call,
-            functions=functions,
-            logit_bias=logit_bias,
-            logprobs=logprobs,
-            max_completion_tokens=max_completion_tokens,
-            max_tokens=max_tokens,
-            n=n,
-            parallel_tool_calls=parallel_tool_calls,
-            presence_penalty=presence_penalty,
-            seed=seed,
-            service_tier=service_tier,
-            stop=stop,
-            stream_options=stream_options,
-            temperature=temperature,
-            tool_choice=tool_choice,
-            tools=tools,
-            top_logprobs=top_logprobs,
-            top_p=top_p,
-            user=user,
-            extra_headers=extra_headers,
-            extra_query=extra_query,
-            extra_body=extra_body,
-            timeout=timeout,
-        )
-        return _parse_chat_completion(
-            response_format=response_format,
-            chat_completion=raw_completion,
-            input_tools=tools,
+        def parser(raw_completion: ChatCompletion) -> ParsedChatCompletion[ResponseFormatT]:
+            return _parse_chat_completion(
+                response_format=response_format,
+                chat_completion=raw_completion,
+                input_tools=tools,
+            )
+
+        return self._post(
+            "/chat/completions",
+            body=maybe_transform(
+                {
+                    "messages": messages,
+                    "model": model,
+                    "frequency_penalty": frequency_penalty,
+                    "function_call": function_call,
+                    "functions": functions,
+                    "logit_bias": logit_bias,
+                    "logprobs": logprobs,
+                    "max_completion_tokens": max_completion_tokens,
+                    "max_tokens": max_tokens,
+                    "n": n,
+                    "parallel_tool_calls": parallel_tool_calls,
+                    "presence_penalty": presence_penalty,
+                    "response_format": _type_to_response_format(response_format),
+                    "seed": seed,
+                    "service_tier": service_tier,
+                    "stop": stop,
+                    "stream": False,
+                    "stream_options": stream_options,
+                    "temperature": temperature,
+                    "tool_choice": tool_choice,
+                    "tools": tools,
+                    "top_logprobs": top_logprobs,
+                    "top_p": top_p,
+                    "user": user,
+                },
+                completion_create_params.CompletionCreateParams,
+            ),
+            options=make_request_options(
+                extra_headers=extra_headers,
+                extra_query=extra_query,
+                extra_body=extra_body,
+                timeout=timeout,
+                post_parser=parser,
+            ),
+            # we turn the `ChatCompletion` instance into a `ParsedChatCompletion`
+            # in the `parser` function above
+            cast_to=cast(Type[ParsedChatCompletion[ResponseFormatT]], ChatCompletion),
+            stream=False,
         )
 
     def stream(
@@ -247,6 +288,25 @@ class Completions(SyncAPIResource):
 
 
 class AsyncCompletions(AsyncAPIResource):
+    @cached_property
+    def with_raw_response(self) -> AsyncCompletionsWithRawResponse:
+        """
+        This property can be used as a prefix for any HTTP method call to return
+        the raw response object instead of the parsed content.
+
+        For more information, see https://www.github.com/openai/openai-python#accessing-raw-response-data-eg-headers
+        """
+        return AsyncCompletionsWithRawResponse(self)
+
+    @cached_property
+    def with_streaming_response(self) -> AsyncCompletionsWithStreamingResponse:
+        """
+        An alternative to `.with_raw_response` that doesn't eagerly read the response body.
+
+        For more information, see https://www.github.com/openai/openai-python#with_streaming_response
+        """
+        return AsyncCompletionsWithStreamingResponse(self)
+
     async def parse(
         self,
         *,
@@ -329,39 +389,55 @@ class AsyncCompletions(AsyncAPIResource):
             **(extra_headers or {}),
         }
 
-        raw_completion = await self._client.chat.completions.create(
-            messages=messages,
-            model=model,
-            response_format=_type_to_response_format(response_format),
-            frequency_penalty=frequency_penalty,
-            function_call=function_call,
-            functions=functions,
-            logit_bias=logit_bias,
-            logprobs=logprobs,
-            max_completion_tokens=max_completion_tokens,
-            max_tokens=max_tokens,
-            n=n,
-            parallel_tool_calls=parallel_tool_calls,
-            presence_penalty=presence_penalty,
-            seed=seed,
-            service_tier=service_tier,
-            stop=stop,
-            stream_options=stream_options,
-            temperature=temperature,
-            tool_choice=tool_choice,
-            tools=tools,
-            top_logprobs=top_logprobs,
-            top_p=top_p,
-            user=user,
-            extra_headers=extra_headers,
-            extra_query=extra_query,
-            extra_body=extra_body,
-            timeout=timeout,
-        )
-        return _parse_chat_completion(
-            response_format=response_format,
-            chat_completion=raw_completion,
-            input_tools=tools,
+        def parser(raw_completion: ChatCompletion) -> ParsedChatCompletion[ResponseFormatT]:
+            return _parse_chat_completion(
+                response_format=response_format,
+                chat_completion=raw_completion,
+                input_tools=tools,
+            )
+
+        return await self._post(
+            "/chat/completions",
+            body=await async_maybe_transform(
+                {
+                    "messages": messages,
+                    "model": model,
+                    "frequency_penalty": frequency_penalty,
+                    "function_call": function_call,
+                    "functions": functions,
+                    "logit_bias": logit_bias,
+                    "logprobs": logprobs,
+                    "max_completion_tokens": max_completion_tokens,
+                    "max_tokens": max_tokens,
+                    "n": n,
+                    "parallel_tool_calls": parallel_tool_calls,
+                    "presence_penalty": presence_penalty,
+                    "response_format": _type_to_response_format(response_format),
+                    "seed": seed,
+                    "service_tier": service_tier,
+                    "stop": stop,
+                    "stream": False,
+                    "stream_options": stream_options,
+                    "temperature": temperature,
+                    "tool_choice": tool_choice,
+                    "tools": tools,
+                    "top_logprobs": top_logprobs,
+                    "top_p": top_p,
+                    "user": user,
+                },
+                completion_create_params.CompletionCreateParams,
+            ),
+            options=make_request_options(
+                extra_headers=extra_headers,
+                extra_query=extra_query,
+                extra_body=extra_body,
+                timeout=timeout,
+                post_parser=parser,
+            ),
+            # we turn the `ChatCompletion` instance into a `ParsedChatCompletion`
+            # in the `parser` function above
+            cast_to=cast(Type[ParsedChatCompletion[ResponseFormatT]], ChatCompletion),
+            stream=False,
         )
 
     def stream(
@@ -461,3 +537,39 @@ class AsyncCompletions(AsyncAPIResource):
             response_format=response_format,
             input_tools=tools,
         )
+
+
+class CompletionsWithRawResponse:
+    def __init__(self, completions: Completions) -> None:
+        self._completions = completions
+
+        self.parse = _legacy_response.to_raw_response_wrapper(
+            completions.parse,
+        )
+
+
+class AsyncCompletionsWithRawResponse:
+    def __init__(self, completions: AsyncCompletions) -> None:
+        self._completions = completions
+
+        self.parse = _legacy_response.async_to_raw_response_wrapper(
+            completions.parse,
+        )
+
+
+class CompletionsWithStreamingResponse:
+    def __init__(self, completions: Completions) -> None:
+        self._completions = completions
+
+        self.parse = to_streamed_response_wrapper(
+            completions.parse,
+        )
+
+
+class AsyncCompletionsWithStreamingResponse:
+    def __init__(self, completions: AsyncCompletions) -> None:
+        self._completions = completions
+
+        self.parse = async_to_streamed_response_wrapper(
+            completions.parse,
+        )
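
The four wrapper classes above are what the new `with_raw_response` / `with_streaming_response` properties return. As a companion to the raw-response sketch earlier, a hedged sketch of the streaming variant, reusing the same illustrative client and `Location` model; the context-manager shape follows the SDK's documented `with_streaming_response` pattern:

# The streaming wrapper defers reading the body until it is requested.
with client.beta.chat.completions.with_streaming_response.parse(
    model="gpt-4o-2024-08-06",
    messages=[{"role": "user", "content": "What's the weather like in SF?"}],
    response_format=Location,
) as response:
    print(response.headers.get("content-type"))
    completion = response.parse()  # reads and parses the body on demand

Routing `parse()` through `self._post` with a `post_parser`, instead of delegating to `chat.completions.create()` as before, is what makes both wrappers possible: they intercept at the transport layer, and the parser promotes the deserialized `ChatCompletion` to a `ParsedChatCompletion` on the way out.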
tests/lib/chat/test_completions.py
@@ -3,7 +3,7 @@ from __future__ import annotations
 import os
 import json
 from enum import Enum
-from typing import Any, List, Callable, Optional
+from typing import Any, List, Callable, Optional, Awaitable
 from typing_extensions import Literal, TypeVar
 
 import httpx
@@ -773,6 +773,139 @@ def test_parse_non_strict_tools(client: OpenAI) -> None:
         )
 
 
+@pytest.mark.respx(base_url=base_url)
+def test_parse_pydantic_raw_response(client: OpenAI, respx_mock: MockRouter, monkeypatch: pytest.MonkeyPatch) -> None:
+    class Location(BaseModel):
+        city: str
+        temperature: float
+        units: Literal["c", "f"]
+
+    response = _make_snapshot_request(
+        lambda c: c.beta.chat.completions.with_raw_response.parse(
+            model="gpt-4o-2024-08-06",
+            messages=[
+                {
+                    "role": "user",
+                    "content": "What's the weather like in SF?",
+                },
+            ],
+            response_format=Location,
+        ),
+        content_snapshot=snapshot(
+            '{"id": "chatcmpl-ABrDYCa8W1w66eUxKDO8TQF1m6trT", "object": "chat.completion", "created": 1727389540, "model": "gpt-4o-2024-08-06", "choices": [{"index": 0, "message": {"role": "assistant", "content": "{\\"city\\":\\"San Francisco\\",\\"temperature\\":58,\\"units\\":\\"f\\"}", "refusal": null}, "logprobs": null, "finish_reason": "stop"}], "usage": {"prompt_tokens": 79, "completion_tokens": 14, "total_tokens": 93, "completion_tokens_details": {"reasoning_tokens": 0}}, "system_fingerprint": "fp_5050236cbd"}'
+        ),
+        mock_client=client,
+        respx_mock=respx_mock,
+    )
+    assert response.http_request.headers.get("x-stainless-helper-method") == "beta.chat.completions.parse"
+
+    completion = response.parse()
+    message = completion.choices[0].message
+    assert message.parsed is not None
+    assert isinstance(message.parsed.city, str)
+    assert print_obj(completion, monkeypatch) == snapshot(
+        """\
+ParsedChatCompletion[Location](
+    choices=[
+        ParsedChoice[Location](
+            finish_reason='stop',
+            index=0,
+            logprobs=None,
+            message=ParsedChatCompletionMessage[Location](
+                content='{"city":"San Francisco","temperature":58,"units":"f"}',
+                function_call=None,
+                parsed=Location(city='San Francisco', temperature=58.0, units='f'),
+                refusal=None,
+                role='assistant',
+                tool_calls=[]
+            )
+        )
+    ],
+    created=1727389540,
+    id='chatcmpl-ABrDYCa8W1w66eUxKDO8TQF1m6trT',
+    model='gpt-4o-2024-08-06',
+    object='chat.completion',
+    service_tier=None,
+    system_fingerprint='fp_5050236cbd',
+    usage=CompletionUsage(
+        completion_tokens=14,
+        completion_tokens_details=CompletionTokensDetails(reasoning_tokens=0),
+        prompt_tokens=79,
+        total_tokens=93
+    )
+)
+"""
+    )
+
+
+@pytest.mark.respx(base_url=base_url)
+@pytest.mark.asyncio
+async def test_async_parse_pydantic_raw_response(
+    async_client: AsyncOpenAI, respx_mock: MockRouter, monkeypatch: pytest.MonkeyPatch
+) -> None:
+    class Location(BaseModel):
+        city: str
+        temperature: float
+        units: Literal["c", "f"]
+
+    response = await _make_async_snapshot_request(
+        lambda c: c.beta.chat.completions.with_raw_response.parse(
+            model="gpt-4o-2024-08-06",
+            messages=[
+                {
+                    "role": "user",
+                    "content": "What's the weather like in SF?",
+                },
+            ],
+            response_format=Location,
+        ),
+        content_snapshot=snapshot(
+            '{"id": "chatcmpl-ABrDQWOiw0PK5JOsxl1D9ooeQgznq", "object": "chat.completion", "created": 1727389532, "model": "gpt-4o-2024-08-06", "choices": [{"index": 0, "message": {"role": "assistant", "content": "{\\"city\\":\\"San Francisco\\",\\"temperature\\":65,\\"units\\":\\"f\\"}", "refusal": null}, "logprobs": null, "finish_reason": "stop"}], "usage": {"prompt_tokens": 79, "completion_tokens": 14, "total_tokens": 93, "completion_tokens_details": {"reasoning_tokens": 0}}, "system_fingerprint": "fp_5050236cbd"}'
+        ),
+        mock_client=async_client,
+        respx_mock=respx_mock,
+    )
+    assert response.http_request.headers.get("x-stainless-helper-method") == "beta.chat.completions.parse"
+
+    completion = response.parse()
+    message = completion.choices[0].message
+    assert message.parsed is not None
+    assert isinstance(message.parsed.city, str)
+    assert print_obj(completion, monkeypatch) == snapshot(
+        """\
+ParsedChatCompletion[Location](
+    choices=[
+        ParsedChoice[Location](
+            finish_reason='stop',
+            index=0,
+            logprobs=None,
+            message=ParsedChatCompletionMessage[Location](
+                content='{"city":"San Francisco","temperature":65,"units":"f"}',
+                function_call=None,
+                parsed=Location(city='San Francisco', temperature=65.0, units='f'),
+                refusal=None,
+                role='assistant',
+                tool_calls=[]
+            )
+        )
+    ],
+    created=1727389532,
+    id='chatcmpl-ABrDQWOiw0PK5JOsxl1D9ooeQgznq',
+    model='gpt-4o-2024-08-06',
+    object='chat.completion',
+    service_tier=None,
+    system_fingerprint='fp_5050236cbd',
+    usage=CompletionUsage(
+        completion_tokens=14,
+        completion_tokens_details=CompletionTokensDetails(reasoning_tokens=0),
+        prompt_tokens=79,
+        total_tokens=93
+    )
+)
+"""
+    )
+
+
 @pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
 def test_parse_method_in_sync(sync: bool, client: OpenAI, async_client: AsyncOpenAI) -> None:
     checking_client: OpenAI | AsyncOpenAI = client if sync else async_client
@@ -824,3 +957,45 @@ def _make_snapshot_request(
         client.close()
 
     return result
+
+
+async def _make_async_snapshot_request(
+    func: Callable[[AsyncOpenAI], Awaitable[_T]],
+    *,
+    content_snapshot: Any,
+    respx_mock: MockRouter,
+    mock_client: AsyncOpenAI,
+) -> _T:
+    live = os.environ.get("OPENAI_LIVE") == "1"
+    if live:
+
+        async def _on_response(response: httpx.Response) -> None:
+            # update the content snapshot
+            assert json.dumps(json.loads(await response.aread())) == content_snapshot
+
+        respx_mock.stop()
+
+        client = AsyncOpenAI(
+            http_client=httpx.AsyncClient(
+                event_hooks={
+                    "response": [_on_response],
+                }
+            )
+        )
+    else:
+        respx_mock.post("/chat/completions").mock(
+            return_value=httpx.Response(
+                200,
+                content=content_snapshot._old_value,
+                headers={"content-type": "application/json"},
+            )
+        )
+
+        client = mock_client
+
+    result = await func(client)
+
+    if live:
+        await client.close()
+
+    return result
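
Worth noting: the live path of the snapshot helpers relies on httpx's `event_hooks` to observe each response before the assertion runs. A self-contained sketch of that pattern outside the test suite (the URL is illustrative):

import asyncio

import httpx

async def log_response(response: httpx.Response) -> None:
    # Async hooks receive the response before the caller consumes it,
    # so they must read the body explicitly if they need the content.
    body = await response.aread()
    print(response.status_code, len(body), "bytes")

async def main() -> None:
    async with httpx.AsyncClient(event_hooks={"response": [log_response]}) as client:
        await client.get("https://example.com")

asyncio.run(main())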