Commit d21cd6c0

stainless-app[bot] <142633134+stainless-app[bot]@users.noreply.github.com>
2024-11-04 22:59:58
fix: support json safe serialization for basemodel subclasses (#1844)
1 parent 643e500
src/openai/_utils/__init__.py
@@ -6,6 +6,7 @@ from ._utils import (
     is_list as is_list,
     is_given as is_given,
     is_tuple as is_tuple,
+    json_safe as json_safe,
     lru_cache as lru_cache,
     is_mapping as is_mapping,
     is_tuple_t as is_tuple_t,
src/openai/_utils/_transform.py
@@ -191,7 +191,7 @@ def _transform_recursive(
         return data
 
     if isinstance(data, pydantic.BaseModel):
-        return model_dump(data, exclude_unset=True)
+        return model_dump(data, exclude_unset=True, mode="json")
 
     annotated_type = _get_annotated_type(annotation)
     if annotated_type is None:
@@ -329,7 +329,7 @@ async def _async_transform_recursive(
         return data
 
     if isinstance(data, pydantic.BaseModel):
-        return model_dump(data, exclude_unset=True)
+        return model_dump(data, exclude_unset=True, mode="json")
 
     annotated_type = _get_annotated_type(annotation)
     if annotated_type is None:
src/openai/_utils/_utils.py
@@ -16,6 +16,7 @@ from typing import (
     overload,
 )
 from pathlib import Path
+from datetime import date, datetime
 from typing_extensions import TypeGuard
 
 import sniffio
@@ -395,3 +396,19 @@ def lru_cache(*, maxsize: int | None = 128) -> Callable[[CallableT], CallableT]:
         maxsize=maxsize,
     )
     return cast(Any, wrapper)  # type: ignore[no-any-return]
+
+
+def json_safe(data: object) -> object:
+    """Translates a mapping / sequence recursively in the same fashion
+    as `pydantic` v2's `model_dump(mode="json")`.
+    """
+    if is_mapping(data):
+        return {json_safe(key): json_safe(value) for key, value in data.items()}
+
+    if is_iterable(data) and not isinstance(data, (str, bytes, bytearray)):
+        return [json_safe(item) for item in data]
+
+    if isinstance(data, (datetime, date)):
+        return data.isoformat()
+
+    return data
src/openai/_compat.py
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 from typing import TYPE_CHECKING, Any, Union, Generic, TypeVar, Callable, cast, overload
 from datetime import date, datetime
-from typing_extensions import Self
+from typing_extensions import Self, Literal
 
 import pydantic
 from pydantic.fields import FieldInfo
@@ -137,9 +137,11 @@ def model_dump(
     exclude_unset: bool = False,
     exclude_defaults: bool = False,
     warnings: bool = True,
+    mode: Literal["json", "python"] = "python",
 ) -> dict[str, Any]:
-    if PYDANTIC_V2:
+    if PYDANTIC_V2 or hasattr(model, "model_dump"):
         return model.model_dump(
+            mode=mode,
             exclude=exclude,
             exclude_unset=exclude_unset,
             exclude_defaults=exclude_defaults,
src/openai/_models.py
@@ -38,6 +38,7 @@ from ._utils import (
     PropertyInfo,
     is_list,
     is_given,
+    json_safe,
     lru_cache,
     is_mapping,
     parse_date,
@@ -304,8 +305,8 @@ class BaseModel(pydantic.BaseModel):
             Returns:
                 A dictionary representation of the model.
             """
-            if mode != "python":
-                raise ValueError("mode is only supported in Pydantic v2")
+            if mode not in {"json", "python"}:
+                raise ValueError("mode must be either 'json' or 'python'")
             if round_trip != False:
                 raise ValueError("round_trip is only supported in Pydantic v2")
             if warnings != True:
@@ -314,7 +315,7 @@ class BaseModel(pydantic.BaseModel):
                 raise ValueError("context is only supported in Pydantic v2")
             if serialize_as_any != False:
                 raise ValueError("serialize_as_any is only supported in Pydantic v2")
-            return super().dict(  # pyright: ignore[reportDeprecated]
+            dumped = super().dict(  # pyright: ignore[reportDeprecated]
                 include=include,
                 exclude=exclude,
                 by_alias=by_alias,
@@ -323,6 +324,8 @@ class BaseModel(pydantic.BaseModel):
                 exclude_none=exclude_none,
             )
 
+            return cast(dict[str, Any], json_safe(dumped)) if mode == "json" else dumped
+
         @override
         def model_dump_json(
             self,
tests/test_models.py
@@ -520,19 +520,15 @@ def test_to_dict() -> None:
     assert m3.to_dict(exclude_none=True) == {}
     assert m3.to_dict(exclude_defaults=True) == {}
 
-    if PYDANTIC_V2:
-
-        class Model2(BaseModel):
-            created_at: datetime
+    class Model2(BaseModel):
+        created_at: datetime
 
-        time_str = "2024-03-21T11:39:01.275859"
-        m4 = Model2.construct(created_at=time_str)
-        assert m4.to_dict(mode="python") == {"created_at": datetime.fromisoformat(time_str)}
-        assert m4.to_dict(mode="json") == {"created_at": time_str}
-    else:
-        with pytest.raises(ValueError, match="mode is only supported in Pydantic v2"):
-            m.to_dict(mode="json")
+    time_str = "2024-03-21T11:39:01.275859"
+    m4 = Model2.construct(created_at=time_str)
+    assert m4.to_dict(mode="python") == {"created_at": datetime.fromisoformat(time_str)}
+    assert m4.to_dict(mode="json") == {"created_at": time_str}
 
+    if not PYDANTIC_V2:
         with pytest.raises(ValueError, match="warnings is only supported in Pydantic v2"):
             m.to_dict(warnings=False)
 
@@ -558,9 +554,6 @@ def test_forwards_compat_model_dump_method() -> None:
     assert m3.model_dump(exclude_none=True) == {}
 
     if not PYDANTIC_V2:
-        with pytest.raises(ValueError, match="mode is only supported in Pydantic v2"):
-            m.model_dump(mode="json")
-
         with pytest.raises(ValueError, match="round_trip is only supported in Pydantic v2"):
             m.model_dump(round_trip=True)
 
tests/test_transform.py
@@ -177,17 +177,32 @@ class DateDict(TypedDict, total=False):
     foo: Annotated[date, PropertyInfo(format="iso8601")]
 
 
+class DatetimeModel(BaseModel):
+    foo: datetime
+
+
+class DateModel(BaseModel):
+    foo: Optional[date]
+
+
 @parametrize
 @pytest.mark.asyncio
 async def test_iso8601_format(use_async: bool) -> None:
     dt = datetime.fromisoformat("2023-02-23T14:16:36.337692+00:00")
+    tz = "Z" if PYDANTIC_V2 else "+00:00"
     assert await transform({"foo": dt}, DatetimeDict, use_async) == {"foo": "2023-02-23T14:16:36.337692+00:00"}  # type: ignore[comparison-overlap]
+    assert await transform(DatetimeModel(foo=dt), Any, use_async) == {"foo": "2023-02-23T14:16:36.337692" + tz}  # type: ignore[comparison-overlap]
 
     dt = dt.replace(tzinfo=None)
     assert await transform({"foo": dt}, DatetimeDict, use_async) == {"foo": "2023-02-23T14:16:36.337692"}  # type: ignore[comparison-overlap]
+    assert await transform(DatetimeModel(foo=dt), Any, use_async) == {"foo": "2023-02-23T14:16:36.337692"}  # type: ignore[comparison-overlap]
 
     assert await transform({"foo": None}, DateDict, use_async) == {"foo": None}  # type: ignore[comparison-overlap]
+    assert await transform(DateModel(foo=None), Any, use_async) == {"foo": None}  # type: ignore
     assert await transform({"foo": date.fromisoformat("2023-02-23")}, DateDict, use_async) == {"foo": "2023-02-23"}  # type: ignore[comparison-overlap]
+    assert await transform(DateModel(foo=date.fromisoformat("2023-02-23")), DateDict, use_async) == {
+        "foo": "2023-02-23"
+    }  # type: ignore[comparison-overlap]
 
 
 @parametrize