main
1from __future__ import annotations
2
3import io
4import os
5import inspect
6import traceback
7import contextlib
8from typing import Any, TypeVar, Iterator, Sequence, ForwardRef, cast
9from datetime import date, datetime
10from typing_extensions import Literal, get_args, get_origin, assert_type
11
12import rich
13
14from openai._types import Omit, NoneType
15from openai._utils import (
16 is_dict,
17 is_list,
18 is_list_type,
19 is_union_type,
20 extract_type_arg,
21 is_sequence_type,
22 is_annotated_type,
23 is_type_alias_type,
24)
25from openai._compat import PYDANTIC_V1, field_outer_type, get_model_fields
26from openai._models import BaseModel
27
# Type variable bounded by the SDK's pydantic ``BaseModel``; ties the ``model``
# class argument of ``assert_matches_model`` to the type of its ``value``.
BaseModelT = TypeVar("BaseModelT", bound=BaseModel)
29
30
def evaluate_forwardref(forwardref: ForwardRef, globalns: dict[str, Any]) -> type:
    """Resolve ``forwardref`` by evaluating its annotation text in ``globalns``.

    ``str(forwardref)`` is the *repr* of the ``ForwardRef`` object (e.g.
    ``"ForwardRef('int')"``), not the annotation itself, so the source text
    stored in ``__forward_arg__`` is what gets evaluated.

    Args:
        forwardref: The forward reference to resolve.
        globalns: Namespace the annotation's names are looked up in.

    Returns:
        The object the annotation text evaluates to.

    Note:
        This uses ``eval``; only pass trusted, in-repo annotations - never
        external input.
    """
    return eval(forwardref.__forward_arg__, globalns)  # type: ignore[no-any-return]
33
34
def assert_matches_model(model: type[BaseModelT], value: BaseModelT, *, path: list[str]) -> bool:
    """Assert that every field of ``value`` matches the type declared on ``model``.

    Each field is checked with ``assert_matches_type``. The field name is
    appended to ``path`` so a nested mismatch can be located.

    Returns:
        ``True`` once every field has been checked. A mismatch surfaces as an
        ``AssertionError`` raised by ``assert_matches_type``, not as ``False``.
    """
    for field_name, model_field in get_model_fields(model).items():
        actual = getattr(value, field_name)

        # in v1 nullability was structured differently (an ``allow_none`` flag
        # on the field rather than part of the annotation), see
        # https://docs.pydantic.dev/2.0/migration/#required-optional-and-nullable-fields
        nullable = getattr(model_field, "allow_none", False) if PYDANTIC_V1 else False

        assert_matches_type(
            field_outer_type(model_field),
            actual,
            path=[*path, field_name],
            allow_none=nullable,
        )

    return True
53
54
# Note: the `path` argument is only used to improve error messages when `--showlocals` is used
def assert_matches_type(
    type_: Any,
    value: object,
    *,
    path: list[str],
    allow_none: bool = False,
) -> None:
    """Recursively assert that ``value`` conforms to the annotation ``type_``.

    Type aliases and ``Annotated[T, ...]`` wrappers are unwrapped first. Lists
    and sequences are checked entry by entry, and dicts key by key and value
    by value. Unions must match at least one variant. ``BaseModel`` subclasses
    are checked field by field via ``assert_matches_model``.

    Args:
        type_: The annotation to check against.
        value: The runtime value under test.
        path: Breadcrumbs to the current value (only visible via ``--showlocals``).
        allow_none: Accept ``None`` regardless of ``type_``. Used for pydantic v1
            nullable fields.

    Raises:
        AssertionError: If ``value`` does not match ``type_``, or if ``type_`` is
            a kind of annotation this helper does not handle.
    """
    if is_type_alias_type(type_):
        type_ = type_.__value__

    # unwrap `Annotated[T, ...]` -> `T`
    if is_annotated_type(type_):
        type_ = extract_type_arg(type_, 0)

    if allow_none and value is None:
        return

    if type_ is None or type_ is NoneType:
        assert value is None
        return

    origin = get_origin(type_) or type_

    # Container entries are checked by recursing into `assert_matches_type`.
    # `typing_extensions.assert_type` is a runtime no-op (it only informs static
    # type checkers), so using it here would silently accept any entry.
    if is_list_type(type_):
        assert is_list(value)
        inner_type = get_args(type_)[0]
        for entry in value:
            assert_matches_type(inner_type, entry, path=[*path, "<list item>"])
        return

    if is_sequence_type(type_):
        assert isinstance(value, Sequence)
        inner_type = get_args(type_)[0]
        for entry in value:  # type: ignore
            assert_matches_type(inner_type, entry, path=[*path, "<sequence item>"])
        return

    if origin == str:
        assert isinstance(value, str)
    elif origin == int:
        assert isinstance(value, int)
    elif origin == bool:
        assert isinstance(value, bool)
    elif origin == float:
        assert isinstance(value, float)
    elif origin == bytes:
        assert isinstance(value, bytes)
    elif origin == datetime:
        assert isinstance(value, datetime)
    elif origin == date:
        assert isinstance(value, date)
    elif origin == object:
        # nothing to do here, the expected type is unknown
        pass
    elif origin == Literal:
        assert value in get_args(type_)
    elif origin == dict:
        assert is_dict(value)

        args = get_args(type_)
        key_type = args[0]
        items_type = args[1]

        for key, item in value.items():
            assert_matches_type(key_type, key, path=[*path, "<dict key>"])
            assert_matches_type(items_type, item, path=[*path, "<dict item>"])
    elif is_union_type(type_):
        variants = get_args(type_)

        try:
            none_index = variants.index(type(None))
        except ValueError:
            pass
        else:
            # special case Optional[T] for better error messages
            if len(variants) == 2:
                if value is None:
                    # valid
                    return

                # the non-None variant sits in whichever slot `None` does not occupy
                return assert_matches_type(type_=variants[1 - none_index], value=value, path=path)

        for i, variant in enumerate(variants):
            try:
                assert_matches_type(variant, value, path=[*path, f"variant {i}"])
                return
            except AssertionError:
                # show why this variant failed before trying the next one
                traceback.print_exc()
                continue

        raise AssertionError("Did not match any variants")
    elif inspect.isclass(origin) and issubclass(origin, BaseModel):
        # `isclass` guard: `issubclass` raises TypeError for non-class origins,
        # which should instead fall through to the "Unhandled" error below
        assert isinstance(value, type_)
        assert assert_matches_model(type_, cast(Any, value), path=path)
    elif inspect.isclass(origin) and origin.__name__ == "HttpxBinaryResponseContent":
        assert value.__class__.__name__ == "HttpxBinaryResponseContent"
    else:
        # explicit raise rather than `assert None, ...` to make the intent obvious
        raise AssertionError(f"Unhandled field type: {type_}")
150
151
def _assert_list_type(type_: type[object], value: object, *, path: list[str] | None = None) -> None:
    """Assert that ``value`` is a ``list`` whose every entry matches the item type of ``type_``.

    ``typing_extensions.assert_type`` only informs static type checkers and does
    nothing at runtime. Each entry is therefore checked with
    ``assert_matches_type`` instead.

    Args:
        type_: A parametrized list annotation such as ``list[int]``.
        value: The runtime value under test.
        path: Optional breadcrumbs for error context (``--showlocals``). Defaults
            to the root.

    Raises:
        AssertionError: If ``value`` is not a list or an entry does not match.
    """
    assert is_list(value)

    inner_type = get_args(type_)[0]
    for entry in value:
        assert_matches_type(inner_type, entry, path=[*(path or []), "<list item>"])
158
159
def rich_print_str(obj: object) -> str:
    """Like `rich.print()` but returns the string instead.

    ``obj`` is rendered by a ``rich`` console that is 120 columns wide and
    writes into an in-memory buffer instead of the terminal.
    """
    # `import rich` alone does not guarantee the `rich.console` submodule is
    # loaded, so import it explicitly rather than relying on another module
    # having done so.
    from rich.console import Console

    buf = io.StringIO()

    console = Console(file=buf, width=120)
    console.print(obj)

    return buf.getvalue()
168
169
@contextlib.contextmanager
def update_env(**new_env: str | Omit) -> Iterator[None]:
    """Temporarily apply ``new_env`` to ``os.environ`` for the body of a ``with`` block.

    A ``str`` value sets the variable. An ``Omit`` value removes it if present.
    On exit, whether normal or via an exception, the environment is restored
    to the exact snapshot taken on entry.
    """
    snapshot = dict(os.environ)

    try:
        for var, setting in new_env.items():
            if not isinstance(setting, Omit):
                os.environ[var] = setting
                continue
            os.environ.pop(var, None)

        yield
    finally:
        # wipe and refill so that variables added inside the block are dropped too
        os.environ.clear()
        os.environ.update(snapshot)