Commit 266edeba

Robert Craigie <robert@craigie.dev>
2025-08-11 23:20:48
refactor(tests): share snapshot utils
1 parent f03096c
tests/lib/chat/test_completions.py
@@ -1,12 +1,9 @@
 from __future__ import annotations
 
-import os
-import json
 from enum import Enum
-from typing import Any, List, Callable, Optional, Awaitable
+from typing import List, Optional
 from typing_extensions import Literal, TypeVar
 
-import httpx
 import pytest
 from respx import MockRouter
 from pydantic import Field, BaseModel
@@ -17,8 +14,9 @@ from openai import OpenAI, AsyncOpenAI
 from openai._utils import assert_signatures_in_sync
 from openai._compat import PYDANTIC_V2
 
-from ._utils import print_obj, get_snapshot_value
+from ..utils import print_obj
 from ...conftest import base_url
+from ..snapshots import make_snapshot_request, make_async_snapshot_request
 from ..schema_types.query import Query
 
 _T = TypeVar("_T")
@@ -32,7 +30,7 @@ _T = TypeVar("_T")
 
 @pytest.mark.respx(base_url=base_url)
 def test_parse_nothing(client: OpenAI, respx_mock: MockRouter, monkeypatch: pytest.MonkeyPatch) -> None:
-    completion = _make_snapshot_request(
+    completion = make_snapshot_request(
         lambda c: c.chat.completions.parse(
             model="gpt-4o-2024-08-06",
             messages=[
@@ -100,7 +98,7 @@ def test_parse_pydantic_model(client: OpenAI, respx_mock: MockRouter, monkeypatc
         temperature: float
         units: Literal["c", "f"]
 
-    completion = _make_snapshot_request(
+    completion = make_snapshot_request(
         lambda c: c.chat.completions.parse(
             model="gpt-4o-2024-08-06",
             messages=[
@@ -170,7 +168,7 @@ def test_parse_pydantic_model_optional_default(
         temperature: float
         units: Optional[Literal["c", "f"]] = None
 
-    completion = _make_snapshot_request(
+    completion = make_snapshot_request(
         lambda c: c.chat.completions.parse(
             model="gpt-4o-2024-08-06",
             messages=[
@@ -247,7 +245,7 @@ def test_parse_pydantic_model_enum(client: OpenAI, respx_mock: MockRouter, monke
     if not PYDANTIC_V2:
         ColorDetection.update_forward_refs(**locals())  # type: ignore
 
-    completion = _make_snapshot_request(
+    completion = make_snapshot_request(
         lambda c: c.chat.completions.parse(
             model="gpt-4o-2024-08-06",
             messages=[
@@ -292,7 +290,7 @@ def test_parse_pydantic_model_multiple_choices(
         temperature: float
         units: Literal["c", "f"]
 
-    completion = _make_snapshot_request(
+    completion = make_snapshot_request(
         lambda c: c.chat.completions.parse(
             model="gpt-4o-2024-08-06",
             messages=[
@@ -375,7 +373,7 @@ def test_parse_pydantic_dataclass(client: OpenAI, respx_mock: MockRouter, monkey
         date: str
         participants: List[str]
 
-    completion = _make_snapshot_request(
+    completion = make_snapshot_request(
         lambda c: c.chat.completions.parse(
             model="gpt-4o-2024-08-06",
             messages=[
@@ -436,7 +434,7 @@ ParsedChatCompletion[CalendarEvent](
 
 @pytest.mark.respx(base_url=base_url)
 def test_pydantic_tool_model_all_types(client: OpenAI, respx_mock: MockRouter, monkeypatch: pytest.MonkeyPatch) -> None:
-    completion = _make_snapshot_request(
+    completion = make_snapshot_request(
         lambda c: c.chat.completions.parse(
             model="gpt-4o-2024-08-06",
             messages=[
@@ -521,7 +519,7 @@ def test_parse_max_tokens_reached(client: OpenAI, respx_mock: MockRouter) -> Non
         units: Literal["c", "f"]
 
     with pytest.raises(openai.LengthFinishReasonError):
-        _make_snapshot_request(
+        make_snapshot_request(
             lambda c: c.chat.completions.parse(
                 model="gpt-4o-2024-08-06",
                 messages=[
@@ -548,7 +546,7 @@ def test_parse_pydantic_model_refusal(client: OpenAI, respx_mock: MockRouter, mo
         temperature: float
         units: Literal["c", "f"]
 
-    completion = _make_snapshot_request(
+    completion = make_snapshot_request(
         lambda c: c.chat.completions.parse(
             model="gpt-4o-2024-08-06",
             messages=[
@@ -596,7 +594,7 @@ def test_parse_pydantic_tool(client: OpenAI, respx_mock: MockRouter, monkeypatch
         country: str
         units: Literal["c", "f"] = "c"
 
-    completion = _make_snapshot_request(
+    completion = make_snapshot_request(
         lambda c: c.chat.completions.parse(
             model="gpt-4o-2024-08-06",
             messages=[
@@ -662,7 +660,7 @@ def test_parse_multiple_pydantic_tools(client: OpenAI, respx_mock: MockRouter, m
         ticker: str
         exchange: str
 
-    completion = _make_snapshot_request(
+    completion = make_snapshot_request(
         lambda c: c.chat.completions.parse(
             model="gpt-4o-2024-08-06",
             messages=[
@@ -733,7 +731,7 @@ def test_parse_multiple_pydantic_tools(client: OpenAI, respx_mock: MockRouter, m
 
 @pytest.mark.respx(base_url=base_url)
 def test_parse_strict_tools(client: OpenAI, respx_mock: MockRouter, monkeypatch: pytest.MonkeyPatch) -> None:
-    completion = _make_snapshot_request(
+    completion = make_snapshot_request(
         lambda c: c.chat.completions.parse(
             model="gpt-4o-2024-08-06",
             messages=[
@@ -830,7 +828,7 @@ def test_parse_pydantic_raw_response(client: OpenAI, respx_mock: MockRouter, mon
         temperature: float
         units: Literal["c", "f"]
 
-    response = _make_snapshot_request(
+    response = make_snapshot_request(
         lambda c: c.chat.completions.with_raw_response.parse(
             model="gpt-4o-2024-08-06",
             messages=[
@@ -906,7 +904,7 @@ async def test_async_parse_pydantic_raw_response(
         temperature: float
         units: Literal["c", "f"]
 
-    response = await _make_async_snapshot_request(
+    response = await make_async_snapshot_request(
         lambda c: c.chat.completions.with_raw_response.parse(
             model="gpt-4o-2024-08-06",
             messages=[
@@ -981,87 +979,3 @@ def test_parse_method_in_sync(sync: bool, client: OpenAI, async_client: AsyncOpe
         checking_client.chat.completions.parse,
         exclude_params={"response_format", "stream"},
     )
-
-
-def _make_snapshot_request(
-    func: Callable[[OpenAI], _T],
-    *,
-    content_snapshot: Any,
-    respx_mock: MockRouter,
-    mock_client: OpenAI,
-) -> _T:
-    live = os.environ.get("OPENAI_LIVE") == "1"
-    if live:
-
-        def _on_response(response: httpx.Response) -> None:
-            # update the content snapshot
-            assert json.dumps(json.loads(response.read())) == content_snapshot
-
-        respx_mock.stop()
-
-        client = OpenAI(
-            http_client=httpx.Client(
-                event_hooks={
-                    "response": [_on_response],
-                }
-            )
-        )
-    else:
-        respx_mock.post("/chat/completions").mock(
-            return_value=httpx.Response(
-                200,
-                content=get_snapshot_value(content_snapshot),
-                headers={"content-type": "application/json"},
-            )
-        )
-
-        client = mock_client
-
-    result = func(client)
-
-    if live:
-        client.close()
-
-    return result
-
-
-async def _make_async_snapshot_request(
-    func: Callable[[AsyncOpenAI], Awaitable[_T]],
-    *,
-    content_snapshot: Any,
-    respx_mock: MockRouter,
-    mock_client: AsyncOpenAI,
-) -> _T:
-    live = os.environ.get("OPENAI_LIVE") == "1"
-    if live:
-
-        async def _on_response(response: httpx.Response) -> None:
-            # update the content snapshot
-            assert json.dumps(json.loads(await response.aread())) == content_snapshot
-
-        respx_mock.stop()
-
-        client = AsyncOpenAI(
-            http_client=httpx.AsyncClient(
-                event_hooks={
-                    "response": [_on_response],
-                }
-            )
-        )
-    else:
-        respx_mock.post("/chat/completions").mock(
-            return_value=httpx.Response(
-                200,
-                content=get_snapshot_value(content_snapshot),
-                headers={"content-type": "application/json"},
-            )
-        )
-
-        client = mock_client
-
-    result = await func(client)
-
-    if live:
-        await client.close()
-
-    return result
tests/lib/chat/test_completions_streaming.py
@@ -30,7 +30,7 @@ from openai.lib.streaming.chat import (
 )
 from openai.lib._parsing._completions import ResponseFormatT
 
-from ._utils import print_obj, get_snapshot_value
+from ..utils import print_obj, get_snapshot_value
 from ...conftest import base_url
 
 _T = TypeVar("_T")
tests/lib/snapshots.py
@@ -0,0 +1,99 @@
+from __future__ import annotations
+
+import os
+import json
+from typing import Any, Callable, Awaitable
+from typing_extensions import TypeVar
+
+import httpx
+from respx import MockRouter
+
+from openai import OpenAI, AsyncOpenAI
+
+from .utils import get_snapshot_value
+
+_T = TypeVar("_T")
+
+
+def make_snapshot_request(
+    func: Callable[[OpenAI], _T],
+    *,
+    content_snapshot: Any,
+    respx_mock: MockRouter,
+    mock_client: OpenAI,
+) -> _T:
+    live = os.environ.get("OPENAI_LIVE") == "1"
+    if live:
+
+        def _on_response(response: httpx.Response) -> None:
+            # update the content snapshot
+            assert json.dumps(json.loads(response.read())) == content_snapshot
+
+        respx_mock.stop()
+
+        client = OpenAI(
+            http_client=httpx.Client(
+                event_hooks={
+                    "response": [_on_response],
+                }
+            )
+        )
+    else:
+        respx_mock.post("/chat/completions").mock(
+            return_value=httpx.Response(
+                200,
+                content=get_snapshot_value(content_snapshot),
+                headers={"content-type": "application/json"},
+            )
+        )
+
+        client = mock_client
+
+    result = func(client)
+
+    if live:
+        client.close()
+
+    return result
+
+
+async def make_async_snapshot_request(
+    func: Callable[[AsyncOpenAI], Awaitable[_T]],
+    *,
+    content_snapshot: Any,
+    respx_mock: MockRouter,
+    mock_client: AsyncOpenAI,
+) -> _T:
+    live = os.environ.get("OPENAI_LIVE") == "1"
+    if live:
+
+        async def _on_response(response: httpx.Response) -> None:
+            # update the content snapshot
+            assert json.dumps(json.loads(await response.aread())) == content_snapshot
+
+        respx_mock.stop()
+
+        client = AsyncOpenAI(
+            http_client=httpx.AsyncClient(
+                event_hooks={
+                    "response": [_on_response],
+                }
+            )
+        )
+    else:
+        respx_mock.post("/chat/completions").mock(
+            return_value=httpx.Response(
+                200,
+                content=get_snapshot_value(content_snapshot),
+                headers={"content-type": "application/json"},
+            )
+        )
+
+        client = mock_client
+
+    result = await func(client)
+
+    if live:
+        await client.close()
+
+    return result
tests/lib/chat/_utils.py → tests/lib/utils.py
@@ -7,7 +7,7 @@ from typing_extensions import TypeAlias
 import pytest
 import pydantic
 
-from ...utils import rich_print_str
+from ..utils import rich_print_str
 
 ReprArgs: TypeAlias = "Iterable[tuple[str | None, Any]]"