main
  1from __future__ import annotations
  2
  3import io
  4import pathlib
  5from typing import Any, Dict, List, Union, TypeVar, Iterable, Optional, cast
  6from datetime import date, datetime
  7from typing_extensions import Required, Annotated, TypedDict
  8
  9import pytest
 10
 11from openai._types import Base64FileInput, omit, not_given
 12from openai._utils import (
 13    PropertyInfo,
 14    transform as _transform,
 15    parse_datetime,
 16    async_transform as _async_transform,
 17)
 18from openai._compat import PYDANTIC_V1
 19from openai._models import BaseModel
 20
 21_T = TypeVar("_T")
 22
 23SAMPLE_FILE_PATH = pathlib.Path(__file__).parent.joinpath("sample_file.txt")
 24
 25
 26async def transform(
 27    data: _T,
 28    expected_type: object,
 29    use_async: bool,
 30) -> _T:
 31    if use_async:
 32        return await _async_transform(data, expected_type=expected_type)
 33
 34    return _transform(data, expected_type=expected_type)
 35
 36
# Shared decorator: supplies `use_async` as False (id "sync") and True (id "async").
parametrize = pytest.mark.parametrize("use_async", [False, True], ids=["sync", "async"])
 38
 39
class Foo1(TypedDict):
    """Single string field `foo_bar` annotated with the alias `fooBar`."""

    foo_bar: Annotated[str, PropertyInfo(alias="fooBar")]
 42
 43
 44@parametrize
 45@pytest.mark.asyncio
 46async def test_top_level_alias(use_async: bool) -> None:
 47    assert await transform({"foo_bar": "hello"}, expected_type=Foo1, use_async=use_async) == {"fooBar": "hello"}
 48
 49
class Foo2(TypedDict):
    """Outer dict whose un-aliased `bar` key holds a nested `Bar2`."""

    bar: Bar2
 52
 53
class Bar2(TypedDict):
    """Middle level: an aliased int plus a further nested, aliased `Baz2`."""

    this_thing: Annotated[int, PropertyInfo(alias="this__thing")]
    baz: Annotated[Baz2, PropertyInfo(alias="Baz")]
 57
 58
class Baz2(TypedDict):
    """Innermost level: string field `my_baz` aliased to `myBaz`."""

    my_baz: Annotated[str, PropertyInfo(alias="myBaz")]
 61
 62
 63@parametrize
 64@pytest.mark.asyncio
 65async def test_recursive_typeddict(use_async: bool) -> None:
 66    assert await transform({"bar": {"this_thing": 1}}, Foo2, use_async) == {"bar": {"this__thing": 1}}
 67    assert await transform({"bar": {"baz": {"my_baz": "foo"}}}, Foo2, use_async) == {"bar": {"Baz": {"myBaz": "foo"}}}
 68
 69
class Foo3(TypedDict):
    """Dict with a `things` key holding a list of `Bar3` dicts."""

    things: List[Bar3]
 72
 73
class Bar3(TypedDict):
    """List element type: string field `my_field` aliased to `myField`."""

    my_field: Annotated[str, PropertyInfo(alias="myField")]
 76
 77
 78@parametrize
 79@pytest.mark.asyncio
 80async def test_list_of_typeddict(use_async: bool) -> None:
 81    result = await transform({"things": [{"my_field": "foo"}, {"my_field": "foo2"}]}, Foo3, use_async)
 82    assert result == {"things": [{"myField": "foo"}, {"myField": "foo2"}]}
 83
 84
class Foo4(TypedDict):
    """Dict whose `foo` value may be either a `Bar4` or a `Baz4`."""

    foo: Union[Bar4, Baz4]
 87
 88
class Bar4(TypedDict):
    """Union member with `foo_bar` aliased to `fooBar`."""

    foo_bar: Annotated[str, PropertyInfo(alias="fooBar")]
 91
 92
class Baz4(TypedDict):
    """Union member with `foo_baz` aliased to `fooBaz`."""

    foo_baz: Annotated[str, PropertyInfo(alias="fooBaz")]
 95
 96
 97@parametrize
 98@pytest.mark.asyncio
 99async def test_union_of_typeddict(use_async: bool) -> None:
100    assert await transform({"foo": {"foo_bar": "bar"}}, Foo4, use_async) == {"foo": {"fooBar": "bar"}}
101    assert await transform({"foo": {"foo_baz": "baz"}}, Foo4, use_async) == {"foo": {"fooBaz": "baz"}}
102    assert await transform({"foo": {"foo_baz": "baz", "foo_bar": "bar"}}, Foo4, use_async) == {
103        "foo": {"fooBaz": "baz", "fooBar": "bar"}
104    }
105
106
class Foo5(TypedDict):
    """Dict whose `foo` key (alias `FOO`) holds one `Bar5` or a list of `Baz5`."""

    # Reference this test's own Bar5/Baz5 (declared just below in this module)
    # instead of the Bar4/Baz4 types that belong to the union-of-typeddict test;
    # the field definitions are identical, so transform output is unchanged.
    foo: Annotated[Union[Bar5, List[Baz5]], PropertyInfo(alias="FOO")]
109
110
class Bar5(TypedDict):
    """Dict with `foo_bar` aliased to `fooBar`."""

    foo_bar: Annotated[str, PropertyInfo(alias="fooBar")]
113
114
class Baz5(TypedDict):
    """Dict with `foo_baz` aliased to `fooBaz`."""

    foo_baz: Annotated[str, PropertyInfo(alias="fooBaz")]
117
118
@parametrize
@pytest.mark.asyncio
async def test_union_of_list(use_async: bool) -> None:
    """A union of dict-or-list-of-dicts is aliased for both shapes of input."""
    single = await transform({"foo": {"foo_bar": "bar"}}, Foo5, use_async)
    assert single == {"FOO": {"fooBar": "bar"}}

    entries = [
        {"foo_baz": "baz"},
        {"foo_baz": "baz"},
    ]
    many = await transform({"foo": entries}, Foo5, use_async)
    assert many == {"FOO": [{"fooBaz": "baz"}, {"fooBaz": "baz"}]}
133
134
class Foo6(TypedDict):
    """Declares only `bar` (alias `Bar`); used with input that has extra keys."""

    bar: Annotated[str, PropertyInfo(alias="Bar")]
137
138
@parametrize
@pytest.mark.asyncio
async def test_includes_unknown_keys(use_async: bool) -> None:
    """Keys the TypedDict does not declare are passed through unchanged."""
    payload = {"bar": "bar", "baz_": {"FOO": 1}}
    expected = {
        "Bar": "bar",
        "baz_": {"FOO": 1},
    }
    transformed = await transform(payload, Foo6, use_async)
    assert transformed == expected
146
147
class Foo7(TypedDict):
    """Expects a list of `Bar7` under `bar` (alias `bAr`) and a `Bar7` under `foo`."""

    bar: Annotated[List[Bar7], PropertyInfo(alias="bAr")]
    foo: Bar7
151
152
class Bar7(TypedDict):
    """Nested dict type with a single un-aliased string field."""

    foo: str
155
156
@parametrize
@pytest.mark.asyncio
async def test_ignores_invalid_input(use_async: bool) -> None:
    """Values of the wrong shape are kept as-is; the key alias still applies."""
    # a plain string where a list of dicts is declared
    wrong_list = await transform({"bar": "<foo>"}, Foo7, use_async)
    assert wrong_list == {"bAr": "<foo>"}

    # a plain string where a dict is declared
    wrong_dict = await transform({"foo": "<foo>"}, Foo7, use_async)
    assert wrong_dict == {"foo": "<foo>"}
162
163
class DatetimeDict(TypedDict, total=False):
    """Datetime fields in several wrappers, all annotated with `format="iso8601"`.

    Declared with `total=False`; `required` and `list_` are wrapped in `Required`.
    """

    # bare datetime
    foo: Annotated[datetime, PropertyInfo(format="iso8601")]

    # datetime or None
    bar: Annotated[Optional[datetime], PropertyInfo(format="iso8601")]

    # datetime or None, wrapped in Required
    required: Required[Annotated[Optional[datetime], PropertyInfo(format="iso8601")]]

    # optional list of datetimes, wrapped in Required
    list_: Required[Annotated[Optional[List[datetime]], PropertyInfo(format="iso8601")]]

    # datetime mixed with a non-datetime union member
    union: Annotated[Union[int, datetime], PropertyInfo(format="iso8601")]
174
175
class DateDict(TypedDict, total=False):
    """Single `date` field annotated with `format="iso8601"`."""

    foo: Annotated[date, PropertyInfo(format="iso8601")]
178
179
class DatetimeModel(BaseModel):
    """Model counterpart of the datetime case: one `datetime` field."""

    foo: datetime
182
183
class DateModel(BaseModel):
    """Model counterpart of the date case: one `Optional[date]` field."""

    foo: Optional[date]
186
187
@parametrize
@pytest.mark.asyncio
async def test_iso8601_format(use_async: bool) -> None:
    """Datetimes and dates come out as ISO 8601 strings, from dicts and from models."""
    aware = datetime.fromisoformat("2023-02-23T14:16:36.337692+00:00")
    # the expected UTC suffix in the model case depends on `PYDANTIC_V1`
    utc_suffix = "Z" if not PYDANTIC_V1 else "+00:00"

    aware_dict = await transform({"foo": aware}, DatetimeDict, use_async)
    assert aware_dict == {"foo": "2023-02-23T14:16:36.337692+00:00"}  # type: ignore[comparison-overlap]
    aware_model = await transform(DatetimeModel(foo=aware), Any, use_async)
    assert aware_model == {"foo": "2023-02-23T14:16:36.337692" + utc_suffix}  # type: ignore[comparison-overlap]

    # same instant with the timezone removed: no offset in the output
    naive = aware.replace(tzinfo=None)
    naive_dict = await transform({"foo": naive}, DatetimeDict, use_async)
    assert naive_dict == {"foo": "2023-02-23T14:16:36.337692"}  # type: ignore[comparison-overlap]
    naive_model = await transform(DatetimeModel(foo=naive), Any, use_async)
    assert naive_model == {"foo": "2023-02-23T14:16:36.337692"}  # type: ignore[comparison-overlap]

    # None stays None for both the dict and the model
    assert await transform({"foo": None}, DateDict, use_async) == {"foo": None}  # type: ignore[comparison-overlap]
    assert await transform(DateModel(foo=None), Any, use_async) == {"foo": None}  # type: ignore

    # plain dates are rendered as YYYY-MM-DD
    day_dict = await transform({"foo": date.fromisoformat("2023-02-23")}, DateDict, use_async)
    assert day_dict == {"foo": "2023-02-23"}  # type: ignore[comparison-overlap]
    day_model = await transform(DateModel(foo=date.fromisoformat("2023-02-23")), DateDict, use_async)
    assert day_model == {"foo": "2023-02-23"}  # type: ignore[comparison-overlap]
206
207
@parametrize
@pytest.mark.asyncio
async def test_optional_iso8601_format(use_async: bool) -> None:
    """An `Optional[datetime]` field is formatted when set and left as None otherwise."""
    stamp = datetime.fromisoformat("2023-02-23T14:16:36.337692+00:00")
    with_value = await transform({"bar": stamp}, DatetimeDict, use_async)
    assert with_value == {"bar": "2023-02-23T14:16:36.337692+00:00"}  # type: ignore[comparison-overlap]

    with_none = await transform({"bar": None}, DatetimeDict, use_async)
    assert with_none == {"bar": None}
215
216
@parametrize
@pytest.mark.asyncio
async def test_required_iso8601_format(use_async: bool) -> None:
    """The `Required[...]` wrapper does not stop the iso8601 format being applied."""
    stamp = datetime.fromisoformat("2023-02-23T14:16:36.337692+00:00")
    with_value = await transform({"required": stamp}, DatetimeDict, use_async)
    assert with_value == {"required": "2023-02-23T14:16:36.337692+00:00"}  # type: ignore[comparison-overlap]

    with_none = await transform({"required": None}, DatetimeDict, use_async)
    assert with_none == {"required": None}
226
227
@parametrize
@pytest.mark.asyncio
async def test_union_datetime(use_async: bool) -> None:
    """In an `int | datetime` union, datetimes are formatted and other values kept."""
    stamp = datetime.fromisoformat("2023-02-23T14:16:36.337692+00:00")
    formatted = await transform({"union": stamp}, DatetimeDict, use_async)
    assert formatted == {"union": "2023-02-23T14:16:36.337692+00:00"}  # type: ignore[comparison-overlap]

    untouched = await transform({"union": "foo"}, DatetimeDict, use_async)
    assert untouched == {"union": "foo"}
237
238
@parametrize
@pytest.mark.asyncio
async def test_nested_list_iso6801_format(use_async: bool) -> None:
    """Each datetime inside a list value is rendered as an ISO 8601 string."""
    first = datetime.fromisoformat("2023-02-23T14:16:36.337692+00:00")
    second = parse_datetime("2022-01-15T06:34:23Z")
    expected = ["2023-02-23T14:16:36.337692+00:00", "2022-01-15T06:34:23+00:00"]

    transformed = await transform({"list_": [first, second]}, DatetimeDict, use_async)
    assert transformed == {"list_": expected}  # type: ignore[comparison-overlap]
247
248
@parametrize
@pytest.mark.asyncio
async def test_datetime_custom_format(use_async: bool) -> None:
    """`format="custom"` with template `%H` yields the two-digit hour."""
    moment = parse_datetime("2022-01-15T06:34:23Z")
    hour_only = Annotated[datetime, PropertyInfo(format="custom", format_template="%H")]

    assert await transform(moment, hour_only, use_async) == "06"  # type: ignore[comparison-overlap]
256
257
class DateDictWithRequiredAlias(TypedDict, total=False):
    """Required `date` field carrying both an iso8601 format and the alias `prop`."""

    required_prop: Required[Annotated[date, PropertyInfo(format="iso8601", alias="prop")]]
260
261
@parametrize
@pytest.mark.asyncio
async def test_datetime_with_alias(use_async: bool) -> None:
    """A field with both a format and an alias gets renamed and formatted."""
    missing = await transform({"required_prop": None}, DateDictWithRequiredAlias, use_async)
    assert missing == {"prop": None}  # type: ignore[comparison-overlap]

    day = date.fromisoformat("2023-02-23")
    present = await transform({"required_prop": day}, DateDictWithRequiredAlias, use_async)
    assert present == {"prop": "2023-02-23"}  # type: ignore[comparison-overlap]
269
270
class MyModel(BaseModel):
    """Minimal model with one string field, shared by the pydantic tests."""

    foo: str
273
274
@parametrize
@pytest.mark.asyncio
async def test_pydantic_model_to_dictionary(use_async: bool) -> None:
    """Models become plain dicts, whether built via `__init__` or `construct()`."""
    from_init = cast(Any, await transform(MyModel(foo="hi!"), Any, use_async))
    assert from_init == {"foo": "hi!"}

    from_construct = cast(Any, await transform(MyModel.construct(foo="hi!"), Any, use_async))
    assert from_construct == {"foo": "hi!"}
280
281
@parametrize
@pytest.mark.asyncio
async def test_pydantic_empty_model(use_async: bool) -> None:
    """A model constructed with no fields transforms to an empty dict."""
    empty = MyModel.construct()
    dumped = cast(Any, await transform(empty, Any, use_async))
    assert dumped == {}
286
287
@parametrize
@pytest.mark.asyncio
async def test_pydantic_unknown_field(use_async: bool) -> None:
    """A field the model does not declare is still present in the output."""
    with_extra = MyModel.construct(my_untyped_field=True)
    dumped = cast(Any, await transform(with_extra, Any, use_async))
    assert dumped == {"my_untyped_field": True}
294
295
@parametrize
@pytest.mark.asyncio
async def test_pydantic_mismatched_types(use_async: bool) -> None:
    """A bool stored in a `str` field is emitted unchanged."""
    mismatched = MyModel.construct(foo=True)
    if not PYDANTIC_V1:
        # outside the PYDANTIC_V1 branch a UserWarning is expected during the transform
        with pytest.warns(UserWarning):
            dumped = await transform(mismatched, Any, use_async)
    else:
        dumped = await transform(mismatched, Any, use_async)
    assert cast(Any, dumped) == {"foo": True}
306
307
@parametrize
@pytest.mark.asyncio
async def test_pydantic_mismatched_object_type(use_async: bool) -> None:
    """A model stored in a `str` field is emitted as a nested dict."""
    inner = MyModel.construct(hello="world")
    outer = MyModel.construct(foo=inner)
    if not PYDANTIC_V1:
        # outside the PYDANTIC_V1 branch a UserWarning is expected during the transform
        with pytest.warns(UserWarning):
            dumped = await transform(outer, Any, use_async)
    else:
        dumped = await transform(outer, Any, use_async)
    assert cast(Any, dumped) == {"foo": {"hello": "world"}}
318
319
class ModelNestedObjects(BaseModel):
    """Model whose only field is another model (`MyModel`)."""

    nested: MyModel
322
323
@parametrize
@pytest.mark.asyncio
async def test_pydantic_nested_objects(use_async: bool) -> None:
    """A nested model is dumped as a nested dict."""
    parent = ModelNestedObjects.construct(nested={"foo": "stainless"})
    # precondition: `construct` turned the raw dict into a MyModel instance
    assert isinstance(parent.nested, MyModel)

    dumped = cast(Any, await transform(parent, Any, use_async))
    assert dumped == {"nested": {"foo": "stainless"}}
330
331
class ModelWithDefaultField(BaseModel):
    """Model mixing a field without a default with a None default and a string default."""

    foo: str
    with_none_default: Union[str, None] = None
    with_str_default: str = "foo"
336
337
@parametrize
@pytest.mark.asyncio
async def test_pydantic_default_field(use_async: bool) -> None:
    """Defaulted fields appear in the output only when they were passed explicitly."""
    # nothing passed: attributes read as their defaults but are left out of the output
    untouched = ModelWithDefaultField.construct()
    assert untouched.with_none_default is None
    assert untouched.with_str_default == "foo"
    assert cast(Any, await transform(untouched, Any, use_async)) == {}

    # defaults passed explicitly: they are included even though the values match
    explicit_defaults = ModelWithDefaultField.construct(with_none_default=None, with_str_default="foo")
    assert explicit_defaults.with_none_default is None
    assert explicit_defaults.with_str_default == "foo"
    dumped_defaults = cast(Any, await transform(explicit_defaults, Any, use_async))
    assert dumped_defaults == {"with_none_default": None, "with_str_default": "foo"}

    # non-default values passed explicitly: they are included as given
    overridden = ModelWithDefaultField.construct(with_none_default="bar", with_str_default="baz")
    assert overridden.with_none_default == "bar"
    assert overridden.with_str_default == "baz"
    dumped_overrides = cast(Any, await transform(overridden, Any, use_async))
    assert dumped_overrides == {"with_none_default": "bar", "with_str_default": "baz"}
358
359
class TypedDictIterableUnion(TypedDict):
    """Dict whose `foo` key (alias `FOO`) holds a `Bar8` or an iterable of `Baz8`."""

    foo: Annotated[Union[Bar8, Iterable[Baz8]], PropertyInfo(alias="FOO")]
362
363
class Bar8(TypedDict):
    """Dict with `foo_bar` aliased to `fooBar`."""

    foo_bar: Annotated[str, PropertyInfo(alias="fooBar")]
366
367
class Baz8(TypedDict):
    """Dict with `foo_baz` aliased to `fooBaz`; used as an iterable element type."""

    foo_baz: Annotated[str, PropertyInfo(alias="fooBaz")]
370
371
@parametrize
@pytest.mark.asyncio
async def test_iterable_of_dictionaries(use_async: bool) -> None:
    """Lists, tuples and generators of dicts all come back as lists of aliased dicts."""
    from_list = await transform({"foo": [{"foo_baz": "bar"}]}, TypedDictIterableUnion, use_async)
    assert from_list == {"FOO": [{"fooBaz": "bar"}]}

    from_tuple = cast(Any, await transform({"foo": ({"foo_baz": "bar"},)}, TypedDictIterableUnion, use_async))
    assert from_tuple == {"FOO": [{"fooBaz": "bar"}]}

    def my_iter() -> Iterable[Baz8]:
        # a generator, so the input is a one-shot iterable rather than a sequence
        for word in ("hello", "world"):
            yield {"foo_baz": word}

    from_generator = await transform({"foo": my_iter()}, TypedDictIterableUnion, use_async)
    assert from_generator == {"FOO": [{"fooBaz": "hello"}, {"fooBaz": "world"}]}
389
390
@parametrize
@pytest.mark.asyncio
async def test_dictionary_items(use_async: bool) -> None:
    """Values of a `Dict[str, TypedDict]` are aliased; the outer keys are kept."""

    class DictItems(TypedDict):
        foo_baz: Annotated[str, PropertyInfo(alias="fooBaz")]

    mapping_type = Dict[str, DictItems]
    transformed = await transform({"foo": {"foo_baz": "bar"}}, mapping_type, use_async)
    assert transformed == {"foo": {"fooBaz": "bar"}}
398
399
class TypedDictIterableUnionStr(TypedDict):
    """Dict whose `foo` key (alias `FOO`) holds a string or an iterable of `Baz8`."""

    foo: Annotated[Union[str, Iterable[Baz8]], PropertyInfo(alias="FOO")]
402
403
@parametrize
@pytest.mark.asyncio
async def test_iterable_union_str(use_async: bool) -> None:
    """In a `str | Iterable[...]` union, strings stay strings and iterators become lists."""
    as_string = await transform({"foo": "bar"}, TypedDictIterableUnionStr, use_async)
    assert as_string == {"FOO": "bar"}

    lazy_items = iter([{"foo_baz": "bar"}])
    as_list = cast(Any, await transform(lazy_items, Union[str, Iterable[Baz8]], use_async))
    assert as_list == [{"fooBaz": "bar"}]
411
412
class TypedDictBase64Input(TypedDict):
    """Dict whose `foo` is a string or a `Base64FileInput`, annotated `format="base64"`."""

    foo: Annotated[Union[str, Base64FileInput], PropertyInfo(format="base64")]
415
416
@parametrize
@pytest.mark.asyncio
async def test_base64_file_input(use_async: bool) -> None:
    """Paths and IO objects are base64-encoded; plain strings are passed through."""
    # a plain string is returned untouched
    from_str = await transform({"foo": "bar"}, TypedDictBase64Input, use_async)
    assert from_str == {"foo": "bar"}

    # a pathlib.Path is replaced by the base64 of the file it points to
    from_path = await transform({"foo": SAMPLE_FILE_PATH}, TypedDictBase64Input, use_async)
    assert from_path == {"foo": "SGVsbG8sIHdvcmxkIQo="}  # type: ignore[comparison-overlap]

    # text and binary IO objects are both replaced by the base64 of their content
    text_stream = io.StringIO("Hello, world!")
    from_text = await transform({"foo": text_stream}, TypedDictBase64Input, use_async)
    assert from_text == {"foo": "SGVsbG8sIHdvcmxkIQ=="}  # type: ignore[comparison-overlap]

    byte_stream = io.BytesIO(b"Hello, world!")
    from_bytes = await transform({"foo": byte_stream}, TypedDictBase64Input, use_async)
    assert from_bytes == {"foo": "SGVsbG8sIHdvcmxkIQ=="}  # type: ignore[comparison-overlap]
435
436
@parametrize
@pytest.mark.asyncio
async def test_transform_skipping(use_async: bool) -> None:
    """A list of ints is returned as the same object; an iterator of ints becomes a list."""
    # identity check: the very same list object must come back
    ints = [1, 2, 3]
    assert await transform(ints, List[int], use_async) is ints

    # an iterator is materialised into a list
    lazy_ints = iter([1, 2, 3])
    assert await transform(lazy_ints, Iterable[int], use_async) == [1, 2, 3]
447
448
@parametrize
@pytest.mark.asyncio
async def test_strips_notgiven(use_async: bool) -> None:
    """A key whose value is `not_given` is dropped from the output."""
    kept = await transform({"foo_bar": "bar"}, Foo1, use_async)
    assert kept == {"fooBar": "bar"}

    dropped = await transform({"foo_bar": not_given}, Foo1, use_async)
    assert dropped == {}
454
455
@parametrize
@pytest.mark.asyncio
async def test_strips_omit(use_async: bool) -> None:
    """A key whose value is `omit` is dropped from the output."""
    kept = await transform({"foo_bar": "bar"}, Foo1, use_async)
    assert kept == {"fooBar": "bar"}

    dropped = await transform({"foo_bar": omit}, Foo1, use_async)
    assert dropped == {}