main
  1from __future__ import annotations
  2
  3import os
  4import json
  5from typing import Any, Callable, Awaitable
  6from typing_extensions import TypeVar
  7
  8import httpx
  9from respx import MockRouter
 10from inline_snapshot import get_snapshot_value
 11
 12from openai import OpenAI, AsyncOpenAI
 13
 14_T = TypeVar("_T")
 15
 16
 17def make_snapshot_request(
 18    func: Callable[[OpenAI], _T],
 19    *,
 20    content_snapshot: Any,
 21    respx_mock: MockRouter,
 22    mock_client: OpenAI,
 23    path: str,
 24) -> _T:
 25    live = os.environ.get("OPENAI_LIVE") == "1"
 26    if live:
 27
 28        def _on_response(response: httpx.Response) -> None:
 29            # update the content snapshot
 30            assert json.dumps(json.loads(response.read())) == content_snapshot
 31
 32        respx_mock.stop()
 33
 34        client = OpenAI(
 35            http_client=httpx.Client(
 36                event_hooks={
 37                    "response": [_on_response],
 38                }
 39            )
 40        )
 41    else:
 42        respx_mock.post(path).mock(
 43            return_value=httpx.Response(
 44                200,
 45                content=get_snapshot_value(content_snapshot),
 46                headers={"content-type": "application/json"},
 47            )
 48        )
 49
 50        client = mock_client
 51
 52    result = func(client)
 53
 54    if live:
 55        client.close()
 56
 57    return result
 58
 59
 60async def make_async_snapshot_request(
 61    func: Callable[[AsyncOpenAI], Awaitable[_T]],
 62    *,
 63    content_snapshot: Any,
 64    respx_mock: MockRouter,
 65    mock_client: AsyncOpenAI,
 66    path: str,
 67) -> _T:
 68    live = os.environ.get("OPENAI_LIVE") == "1"
 69    if live:
 70
 71        async def _on_response(response: httpx.Response) -> None:
 72            # update the content snapshot
 73            assert json.dumps(json.loads(await response.aread())) == content_snapshot
 74
 75        respx_mock.stop()
 76
 77        client = AsyncOpenAI(
 78            http_client=httpx.AsyncClient(
 79                event_hooks={
 80                    "response": [_on_response],
 81                }
 82            )
 83        )
 84    else:
 85        respx_mock.post(path).mock(
 86            return_value=httpx.Response(
 87                200,
 88                content=get_snapshot_value(content_snapshot),
 89                headers={"content-type": "application/json"},
 90            )
 91        )
 92
 93        client = mock_client
 94
 95    result = await func(client)
 96
 97    if live:
 98        await client.close()
 99
100    return result