Commit af6a9437

stainless-app[bot] <142633134+stainless-app[bot]@users.noreply.github.com>
2025-02-06 21:00:22
chore(internal): fix type traversing dictionary params (#2097)
1 parent 2c20ea7
Changed files (2)
src
openai
tests
src/openai/_utils/_transform.py
@@ -25,7 +25,7 @@ from ._typing import (
     is_annotated_type,
     strip_annotated_type,
 )
-from .._compat import model_dump, is_typeddict
+from .._compat import get_origin, model_dump, is_typeddict
 
 _T = TypeVar("_T")
 
@@ -164,9 +164,14 @@ def _transform_recursive(
         inner_type = annotation
 
     stripped_type = strip_annotated_type(inner_type)
+    origin = get_origin(stripped_type) or stripped_type
     if is_typeddict(stripped_type) and is_mapping(data):
         return _transform_typeddict(data, stripped_type)
 
+    if origin == dict and is_mapping(data):
+        items_type = get_args(stripped_type)[1]
+        return {key: _transform_recursive(value, annotation=items_type) for key, value in data.items()}
+
     if (
         # List[T]
         (is_list_type(stripped_type) and is_list(data))
@@ -307,9 +312,14 @@ async def _async_transform_recursive(
         inner_type = annotation
 
     stripped_type = strip_annotated_type(inner_type)
+    origin = get_origin(stripped_type) or stripped_type
     if is_typeddict(stripped_type) and is_mapping(data):
         return await _async_transform_typeddict(data, stripped_type)
 
+    if origin == dict and is_mapping(data):
+        items_type = get_args(stripped_type)[1]
+        return {key: _transform_recursive(value, annotation=items_type) for key, value in data.items()}
+
     if (
         # List[T]
         (is_list_type(stripped_type) and is_list(data))
tests/test_transform.py
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 import io
 import pathlib
-from typing import Any, List, Union, TypeVar, Iterable, Optional, cast
+from typing import Any, Dict, List, Union, TypeVar, Iterable, Optional, cast
 from datetime import date, datetime
 from typing_extensions import Required, Annotated, TypedDict
 
@@ -388,6 +388,15 @@ async def test_iterable_of_dictionaries(use_async: bool) -> None:
     }
 
 
+@parametrize
+@pytest.mark.asyncio
+async def test_dictionary_items(use_async: bool) -> None:
+    class DictItems(TypedDict):
+        foo_baz: Annotated[str, PropertyInfo(alias="fooBaz")]
+
+    assert await transform({"foo": {"foo_baz": "bar"}}, Dict[str, DictItems], use_async) == {"foo": {"fooBaz": "bar"}}
+
+
 class TypedDictIterableUnionStr(TypedDict):
     foo: Annotated[Union[str, Iterable[Baz8]], PropertyInfo(alias="FOO")]