Commit e1aeeb0e

Robert Craigie <robert@craigie.dev>
2024-09-27 06:12:51
chore(pydantic v1): exclude specific properties when rich printing (#1751)
1 parent 70edb21
src/openai/_models.py
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 import os
 import inspect
-from typing import TYPE_CHECKING, Any, Type, Union, Generic, TypeVar, Callable, Optional, cast
+from typing import TYPE_CHECKING, Any, Type, Tuple, Union, Generic, TypeVar, Callable, Optional, cast
 from datetime import date, datetime
 from typing_extensions import (
     Unpack,
@@ -10,6 +10,7 @@ from typing_extensions import (
     ClassVar,
     Protocol,
     Required,
+    Sequence,
     ParamSpec,
     TypedDict,
     TypeGuard,
@@ -72,6 +73,8 @@ _BaseModelT = TypeVar("_BaseModelT", bound="BaseModel")
 
 P = ParamSpec("P")
 
+ReprArgs = Sequence[Tuple[Optional[str], Any]]
+
 
 @runtime_checkable
 class _ConfigProtocol(Protocol):
@@ -94,6 +97,11 @@ class BaseModel(pydantic.BaseModel):
         class Config(pydantic.BaseConfig):  # pyright: ignore[reportDeprecated]
             extra: Any = pydantic.Extra.allow  # type: ignore
 
+        @override
+        def __repr_args__(self) -> ReprArgs:
+            # we don't want these attributes to be included when something like `rich.print` is used
+            return [arg for arg in super().__repr_args__() if arg[0] not in {"_request_id", "__exclude_fields__"}]
+
     if TYPE_CHECKING:
         _request_id: Optional[str] = None
         """The ID of the request, returned via the X-Request-ID header. Useful for debugging requests and reporting issues to OpenAI.
tests/lib/chat/_utils.py
@@ -1,14 +1,14 @@
 from __future__ import annotations
 
-import io
 import inspect
 from typing import Any, Iterable
 from typing_extensions import TypeAlias
 
-import rich
 import pytest
 import pydantic
 
+from ...utils import rich_print_str
+
 ReprArgs: TypeAlias = "Iterable[tuple[str | None, Any]]"
 
 
@@ -26,12 +26,7 @@ def print_obj(obj: object, monkeypatch: pytest.MonkeyPatch) -> str:
     with monkeypatch.context() as m:
         m.setattr(pydantic.BaseModel, "__repr_args__", __repr_args__)
 
-        buf = io.StringIO()
-
-        console = rich.console.Console(file=buf, width=120)
-        console.print(obj)
-
-        string = buf.getvalue()
+        string = rich_print_str(obj)
 
         # we remove all `fn_name.<locals>.` occurences
         # so that we can share the same snapshots between
tests/test_legacy_response.py
@@ -11,6 +11,8 @@ from openai._streaming import Stream
 from openai._base_client import FinalRequestOptions
 from openai._legacy_response import LegacyAPIResponse
 
+from .utils import rich_print_str
+
 
 class PydanticModel(pydantic.BaseModel): ...
 
@@ -85,6 +87,8 @@ def test_response_basemodel_request_id(client: OpenAI) -> None:
     assert obj.foo == "hello!"
     assert obj.bar == 2
     assert obj.to_dict() == {"foo": "hello!", "bar": 2}
+    assert "_request_id" not in rich_print_str(obj)
+    assert "__exclude_fields__" not in rich_print_str(obj)
 
 
 def test_response_parse_annotated_type(client: OpenAI) -> None:
tests/test_response.py
@@ -18,6 +18,8 @@ from openai._response import (
 from openai._streaming import Stream
 from openai._base_client import FinalRequestOptions
 
+from .utils import rich_print_str
+
 
 class ConcreteBaseAPIResponse(APIResponse[bytes]): ...
 
@@ -175,6 +177,8 @@ def test_response_basemodel_request_id(client: OpenAI) -> None:
     assert obj.foo == "hello!"
     assert obj.bar == 2
     assert obj.to_dict() == {"foo": "hello!", "bar": 2}
+    assert "_request_id" not in rich_print_str(obj)
+    assert "__exclude_fields__" not in rich_print_str(obj)
 
 
 @pytest.mark.asyncio
tests/utils.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import io
 import os
 import inspect
 import traceback
@@ -8,6 +9,8 @@ from typing import Any, TypeVar, Iterator, cast
 from datetime import date, datetime
 from typing_extensions import Literal, get_args, get_origin, assert_type
 
+import rich
+
 from openai._types import Omit, NoneType
 from openai._utils import (
     is_dict,
@@ -138,6 +141,16 @@ def _assert_list_type(type_: type[object], value: object) -> None:
         assert_type(inner_type, entry)  # type: ignore
 
 
+def rich_print_str(obj: object) -> str:
+    """Like `rich.print()` but returns the string instead"""
+    buf = io.StringIO()
+
+    console = rich.console.Console(file=buf, width=120)
+    console.print(obj)
+
+    return buf.getvalue()
+
+
 @contextlib.contextmanager
 def update_env(**new_env: str | Omit) -> Iterator[None]:
     old = os.environ.copy()