Commit e1aeeb0e
Changed files (5)
src
openai
tests
src/openai/_models.py
@@ -2,7 +2,7 @@ from __future__ import annotations
import os
import inspect
-from typing import TYPE_CHECKING, Any, Type, Union, Generic, TypeVar, Callable, Optional, cast
+from typing import TYPE_CHECKING, Any, Type, Tuple, Union, Generic, TypeVar, Callable, Optional, cast
from datetime import date, datetime
from typing_extensions import (
Unpack,
@@ -10,6 +10,7 @@ from typing_extensions import (
ClassVar,
Protocol,
Required,
+ Sequence,
ParamSpec,
TypedDict,
TypeGuard,
@@ -72,6 +73,8 @@ _BaseModelT = TypeVar("_BaseModelT", bound="BaseModel")
P = ParamSpec("P")
+ReprArgs = Sequence[Tuple[Optional[str], Any]]
+
@runtime_checkable
class _ConfigProtocol(Protocol):
@@ -94,6 +97,11 @@ class BaseModel(pydantic.BaseModel):
class Config(pydantic.BaseConfig): # pyright: ignore[reportDeprecated]
extra: Any = pydantic.Extra.allow # type: ignore
+ @override
+ def __repr_args__(self) -> ReprArgs:
+ # we don't want these attributes to be included when something like `rich.print` is used
+ return [arg for arg in super().__repr_args__() if arg[0] not in {"_request_id", "__exclude_fields__"}]
+
if TYPE_CHECKING:
_request_id: Optional[str] = None
"""The ID of the request, returned via the X-Request-ID header. Useful for debugging requests and reporting issues to OpenAI.
tests/lib/chat/_utils.py
@@ -1,14 +1,14 @@
from __future__ import annotations
-import io
import inspect
from typing import Any, Iterable
from typing_extensions import TypeAlias
-import rich
import pytest
import pydantic
+from ...utils import rich_print_str
+
ReprArgs: TypeAlias = "Iterable[tuple[str | None, Any]]"
@@ -26,12 +26,7 @@ def print_obj(obj: object, monkeypatch: pytest.MonkeyPatch) -> str:
with monkeypatch.context() as m:
m.setattr(pydantic.BaseModel, "__repr_args__", __repr_args__)
- buf = io.StringIO()
-
- console = rich.console.Console(file=buf, width=120)
- console.print(obj)
-
- string = buf.getvalue()
+ string = rich_print_str(obj)
# we remove all `fn_name.<locals>.` occurences
# so that we can share the same snapshots between
tests/test_legacy_response.py
@@ -11,6 +11,8 @@ from openai._streaming import Stream
from openai._base_client import FinalRequestOptions
from openai._legacy_response import LegacyAPIResponse
+from .utils import rich_print_str
+
class PydanticModel(pydantic.BaseModel): ...
@@ -85,6 +87,8 @@ def test_response_basemodel_request_id(client: OpenAI) -> None:
assert obj.foo == "hello!"
assert obj.bar == 2
assert obj.to_dict() == {"foo": "hello!", "bar": 2}
+ assert "_request_id" not in rich_print_str(obj)
+ assert "__exclude_fields__" not in rich_print_str(obj)
def test_response_parse_annotated_type(client: OpenAI) -> None:
tests/test_response.py
@@ -18,6 +18,8 @@ from openai._response import (
from openai._streaming import Stream
from openai._base_client import FinalRequestOptions
+from .utils import rich_print_str
+
class ConcreteBaseAPIResponse(APIResponse[bytes]): ...
@@ -175,6 +177,8 @@ def test_response_basemodel_request_id(client: OpenAI) -> None:
assert obj.foo == "hello!"
assert obj.bar == 2
assert obj.to_dict() == {"foo": "hello!", "bar": 2}
+ assert "_request_id" not in rich_print_str(obj)
+ assert "__exclude_fields__" not in rich_print_str(obj)
@pytest.mark.asyncio
tests/utils.py
@@ -1,5 +1,6 @@
from __future__ import annotations
+import io
import os
import inspect
import traceback
@@ -8,6 +9,8 @@ from typing import Any, TypeVar, Iterator, cast
from datetime import date, datetime
from typing_extensions import Literal, get_args, get_origin, assert_type
+import rich
+
from openai._types import Omit, NoneType
from openai._utils import (
is_dict,
@@ -138,6 +141,16 @@ def _assert_list_type(type_: type[object], value: object) -> None:
assert_type(inner_type, entry) # type: ignore
+def rich_print_str(obj: object) -> str:
+ """Like `rich.print()` but returns the string instead"""
+ buf = io.StringIO()
+
+ console = rich.console.Console(file=buf, width=120)
+ console.print(obj)
+
+ return buf.getvalue()
+
+
@contextlib.contextmanager
def update_env(**new_env: str | Omit) -> Iterator[None]:
old = os.environ.copy()