Commit adb0b31a

Stainless Bot <107565488+stainless-bot@users.noreply.github.com>
2024-02-06 17:50:54
chore(internal): support serialising iterable types (#1127)
1 parent 489dadf
Changed files (5)
src/openai/_utils/__init__.py
@@ -9,6 +9,7 @@ from ._utils import (
     is_mapping as is_mapping,
     is_tuple_t as is_tuple_t,
     parse_date as parse_date,
+    is_iterable as is_iterable,
     is_sequence as is_sequence,
     coerce_float as coerce_float,
     is_mapping_t as is_mapping_t,
@@ -33,6 +34,7 @@ from ._typing import (
     is_list_type as is_list_type,
     is_union_type as is_union_type,
     extract_type_arg as extract_type_arg,
+    is_iterable_type as is_iterable_type,
     is_required_type as is_required_type,
     is_annotated_type as is_annotated_type,
     strip_annotated_type as strip_annotated_type,
src/openai/_utils/_transform.py
@@ -9,11 +9,13 @@ import pydantic
 from ._utils import (
     is_list,
     is_mapping,
+    is_iterable,
 )
 from ._typing import (
     is_list_type,
     is_union_type,
     extract_type_arg,
+    is_iterable_type,
     is_required_type,
     is_annotated_type,
     strip_annotated_type,
@@ -157,7 +159,12 @@ def _transform_recursive(
     if is_typeddict(stripped_type) and is_mapping(data):
         return _transform_typeddict(data, stripped_type)
 
-    if is_list_type(stripped_type) and is_list(data):
+    if (
+        # List[T]
+        (is_list_type(stripped_type) and is_list(data))
+        # Iterable[T]
+        or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str))
+    ):
         inner_type = extract_type_arg(stripped_type, 0)
         return [_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data]
 
src/openai/_utils/_typing.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
-from typing import Any, TypeVar, cast
+from typing import Any, TypeVar, Iterable, cast
+from collections import abc as _c_abc
 from typing_extensions import Required, Annotated, get_args, get_origin
 
 from .._types import InheritsGeneric
@@ -15,6 +16,12 @@ def is_list_type(typ: type) -> bool:
     return (get_origin(typ) or typ) == list
 
 
+def is_iterable_type(typ: type) -> bool:
+    """If the given type is `typing.Iterable[T]`"""
+    origin = get_origin(typ) or typ
+    return origin == Iterable or origin == _c_abc.Iterable
+
+
 def is_union_type(typ: type) -> bool:
     return _is_union(get_origin(typ))
 
src/openai/_utils/_utils.py
@@ -164,6 +164,10 @@ def is_list(obj: object) -> TypeGuard[list[object]]:
     return isinstance(obj, list)
 
 
+def is_iterable(obj: object) -> TypeGuard[Iterable[object]]:
+    return isinstance(obj, Iterable)
+
+
 def deepcopy_minimal(item: _T) -> _T:
     """Minimal reimplementation of copy.deepcopy() that will only copy certain object types:
 
tests/test_transform.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from typing import Any, List, Union, Optional
+from typing import Any, List, Union, Iterable, Optional, cast
 from datetime import date, datetime
 from typing_extensions import Required, Annotated, TypedDict
 
@@ -265,3 +265,35 @@ def test_pydantic_default_field() -> None:
     assert model.with_none_default == "bar"
     assert model.with_str_default == "baz"
     assert transform(model, Any) == {"with_none_default": "bar", "with_str_default": "baz"}
+
+
+class TypedDictIterableUnion(TypedDict):
+    foo: Annotated[Union[Bar8, Iterable[Baz8]], PropertyInfo(alias="FOO")]
+
+
+class Bar8(TypedDict):
+    foo_bar: Annotated[str, PropertyInfo(alias="fooBar")]
+
+
+class Baz8(TypedDict):
+    foo_baz: Annotated[str, PropertyInfo(alias="fooBaz")]
+
+
+def test_iterable_of_dictionaries() -> None:
+    assert transform({"foo": [{"foo_baz": "bar"}]}, TypedDictIterableUnion) == {"FOO": [{"fooBaz": "bar"}]}
+    assert cast(Any, transform({"foo": ({"foo_baz": "bar"},)}, TypedDictIterableUnion)) == {"FOO": [{"fooBaz": "bar"}]}
+
+    def my_iter() -> Iterable[Baz8]:
+        yield {"foo_baz": "hello"}
+        yield {"foo_baz": "world"}
+
+    assert transform({"foo": my_iter()}, TypedDictIterableUnion) == {"FOO": [{"fooBaz": "hello"}, {"fooBaz": "world"}]}
+
+
+class TypedDictIterableUnionStr(TypedDict):
+    foo: Annotated[Union[str, Iterable[Baz8]], PropertyInfo(alias="FOO")]
+
+
+def test_iterable_union_str() -> None:
+    assert transform({"foo": "bar"}, TypedDictIterableUnionStr) == {"FOO": "bar"}
+    assert cast(Any, transform(iter([{"foo_baz": "bar"}]), Union[str, Iterable[Baz8]])) == [{"fooBaz": "bar"}]