Commit d21cd6c0
Changed files (7)
src
src/openai/_utils/__init__.py
@@ -6,6 +6,7 @@ from ._utils import (
is_list as is_list,
is_given as is_given,
is_tuple as is_tuple,
+ json_safe as json_safe,
lru_cache as lru_cache,
is_mapping as is_mapping,
is_tuple_t as is_tuple_t,
src/openai/_utils/_transform.py
@@ -191,7 +191,7 @@ def _transform_recursive(
return data
if isinstance(data, pydantic.BaseModel):
- return model_dump(data, exclude_unset=True)
+ return model_dump(data, exclude_unset=True, mode="json")
annotated_type = _get_annotated_type(annotation)
if annotated_type is None:
@@ -329,7 +329,7 @@ async def _async_transform_recursive(
return data
if isinstance(data, pydantic.BaseModel):
- return model_dump(data, exclude_unset=True)
+ return model_dump(data, exclude_unset=True, mode="json")
annotated_type = _get_annotated_type(annotation)
if annotated_type is None:
src/openai/_utils/_utils.py
@@ -16,6 +16,7 @@ from typing import (
overload,
)
from pathlib import Path
+from datetime import date, datetime
from typing_extensions import TypeGuard
import sniffio
@@ -395,3 +396,19 @@ def lru_cache(*, maxsize: int | None = 128) -> Callable[[CallableT], CallableT]:
maxsize=maxsize,
)
return cast(Any, wrapper) # type: ignore[no-any-return]
+
+
+def json_safe(data: object) -> object:
+ """Translates a mapping / sequence recursively in the same fashion
+ as `pydantic` v2's `model_dump(mode="json")`.
+ """
+ if is_mapping(data):
+ return {json_safe(key): json_safe(value) for key, value in data.items()}
+
+ if is_iterable(data) and not isinstance(data, (str, bytes, bytearray)):
+ return [json_safe(item) for item in data]
+
+ if isinstance(data, (datetime, date)):
+ return data.isoformat()
+
+ return data
src/openai/_compat.py
@@ -2,7 +2,7 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Any, Union, Generic, TypeVar, Callable, cast, overload
from datetime import date, datetime
-from typing_extensions import Self
+from typing_extensions import Self, Literal
import pydantic
from pydantic.fields import FieldInfo
@@ -137,9 +137,11 @@ def model_dump(
exclude_unset: bool = False,
exclude_defaults: bool = False,
warnings: bool = True,
+ mode: Literal["json", "python"] = "python",
) -> dict[str, Any]:
- if PYDANTIC_V2:
+ if PYDANTIC_V2 or hasattr(model, "model_dump"):
return model.model_dump(
+ mode=mode,
exclude=exclude,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
src/openai/_models.py
@@ -38,6 +38,7 @@ from ._utils import (
PropertyInfo,
is_list,
is_given,
+ json_safe,
lru_cache,
is_mapping,
parse_date,
@@ -304,8 +305,8 @@ class BaseModel(pydantic.BaseModel):
Returns:
A dictionary representation of the model.
"""
- if mode != "python":
- raise ValueError("mode is only supported in Pydantic v2")
+ if mode not in {"json", "python"}:
+ raise ValueError("mode must be either 'json' or 'python'")
if round_trip != False:
raise ValueError("round_trip is only supported in Pydantic v2")
if warnings != True:
@@ -314,7 +315,7 @@ class BaseModel(pydantic.BaseModel):
raise ValueError("context is only supported in Pydantic v2")
if serialize_as_any != False:
raise ValueError("serialize_as_any is only supported in Pydantic v2")
- return super().dict( # pyright: ignore[reportDeprecated]
+ dumped = super().dict( # pyright: ignore[reportDeprecated]
include=include,
exclude=exclude,
by_alias=by_alias,
@@ -323,6 +324,8 @@ class BaseModel(pydantic.BaseModel):
exclude_none=exclude_none,
)
+ return cast(dict[str, Any], json_safe(dumped)) if mode == "json" else dumped
+
@override
def model_dump_json(
self,
tests/test_models.py
@@ -520,19 +520,15 @@ def test_to_dict() -> None:
assert m3.to_dict(exclude_none=True) == {}
assert m3.to_dict(exclude_defaults=True) == {}
- if PYDANTIC_V2:
-
- class Model2(BaseModel):
- created_at: datetime
+ class Model2(BaseModel):
+ created_at: datetime
- time_str = "2024-03-21T11:39:01.275859"
- m4 = Model2.construct(created_at=time_str)
- assert m4.to_dict(mode="python") == {"created_at": datetime.fromisoformat(time_str)}
- assert m4.to_dict(mode="json") == {"created_at": time_str}
- else:
- with pytest.raises(ValueError, match="mode is only supported in Pydantic v2"):
- m.to_dict(mode="json")
+ time_str = "2024-03-21T11:39:01.275859"
+ m4 = Model2.construct(created_at=time_str)
+ assert m4.to_dict(mode="python") == {"created_at": datetime.fromisoformat(time_str)}
+ assert m4.to_dict(mode="json") == {"created_at": time_str}
+ if not PYDANTIC_V2:
with pytest.raises(ValueError, match="warnings is only supported in Pydantic v2"):
m.to_dict(warnings=False)
@@ -558,9 +554,6 @@ def test_forwards_compat_model_dump_method() -> None:
assert m3.model_dump(exclude_none=True) == {}
if not PYDANTIC_V2:
- with pytest.raises(ValueError, match="mode is only supported in Pydantic v2"):
- m.model_dump(mode="json")
-
with pytest.raises(ValueError, match="round_trip is only supported in Pydantic v2"):
m.model_dump(round_trip=True)
tests/test_transform.py
@@ -177,17 +177,32 @@ class DateDict(TypedDict, total=False):
foo: Annotated[date, PropertyInfo(format="iso8601")]
+class DatetimeModel(BaseModel):
+ foo: datetime
+
+
+class DateModel(BaseModel):
+ foo: Optional[date]
+
+
@parametrize
@pytest.mark.asyncio
async def test_iso8601_format(use_async: bool) -> None:
dt = datetime.fromisoformat("2023-02-23T14:16:36.337692+00:00")
+ tz = "Z" if PYDANTIC_V2 else "+00:00"
assert await transform({"foo": dt}, DatetimeDict, use_async) == {"foo": "2023-02-23T14:16:36.337692+00:00"} # type: ignore[comparison-overlap]
+ assert await transform(DatetimeModel(foo=dt), Any, use_async) == {"foo": "2023-02-23T14:16:36.337692" + tz} # type: ignore[comparison-overlap]
dt = dt.replace(tzinfo=None)
assert await transform({"foo": dt}, DatetimeDict, use_async) == {"foo": "2023-02-23T14:16:36.337692"} # type: ignore[comparison-overlap]
+ assert await transform(DatetimeModel(foo=dt), Any, use_async) == {"foo": "2023-02-23T14:16:36.337692"} # type: ignore[comparison-overlap]
assert await transform({"foo": None}, DateDict, use_async) == {"foo": None} # type: ignore[comparison-overlap]
+ assert await transform(DateModel(foo=None), Any, use_async) == {"foo": None} # type: ignore
assert await transform({"foo": date.fromisoformat("2023-02-23")}, DateDict, use_async) == {"foo": "2023-02-23"} # type: ignore[comparison-overlap]
+ assert await transform(DateModel(foo=date.fromisoformat("2023-02-23")), DateDict, use_async) == {
+ "foo": "2023-02-23"
+ } # type: ignore[comparison-overlap]
@parametrize