main
  1from __future__ import annotations
  2
  3import io
  4import os
  5import inspect
  6import traceback
  7import contextlib
  8from typing import Any, TypeVar, Iterator, Sequence, ForwardRef, cast
  9from datetime import date, datetime
 10from typing_extensions import Literal, get_args, get_origin, assert_type
 11
 12import rich
 13
 14from openai._types import Omit, NoneType
 15from openai._utils import (
 16    is_dict,
 17    is_list,
 18    is_list_type,
 19    is_union_type,
 20    extract_type_arg,
 21    is_sequence_type,
 22    is_annotated_type,
 23    is_type_alias_type,
 24)
 25from openai._compat import PYDANTIC_V1, field_outer_type, get_model_fields
 26from openai._models import BaseModel
 27
# Type variable bound to `BaseModel`; lets helpers such as `assert_matches_model`
# tie a model class parameter (`type[BaseModelT]`) to an instance of that class.
BaseModelT = TypeVar("BaseModelT", bound=BaseModel)
 29
 30
 31def evaluate_forwardref(forwardref: ForwardRef, globalns: dict[str, Any]) -> type:
 32    return eval(str(forwardref), globalns)  # type: ignore[no-any-return]
 33
 34
 35def assert_matches_model(model: type[BaseModelT], value: BaseModelT, *, path: list[str]) -> bool:
 36    for name, field in get_model_fields(model).items():
 37        field_value = getattr(value, name)
 38        if PYDANTIC_V1:
 39            # in v1 nullability was structured differently
 40            # https://docs.pydantic.dev/2.0/migration/#required-optional-and-nullable-fields
 41            allow_none = getattr(field, "allow_none", False)
 42        else:
 43            allow_none = False
 44
 45        assert_matches_type(
 46            field_outer_type(field),
 47            field_value,
 48            path=[*path, name],
 49            allow_none=allow_none,
 50        )
 51
 52    return True
 53
 54
 55# Note: the `path` argument is only used to improve error messages when `--showlocals` is used
 56def assert_matches_type(
 57    type_: Any,
 58    value: object,
 59    *,
 60    path: list[str],
 61    allow_none: bool = False,
 62) -> None:
 63    if is_type_alias_type(type_):
 64        type_ = type_.__value__
 65
 66    # unwrap `Annotated[T, ...]` -> `T`
 67    if is_annotated_type(type_):
 68        type_ = extract_type_arg(type_, 0)
 69
 70    if allow_none and value is None:
 71        return
 72
 73    if type_ is None or type_ is NoneType:
 74        assert value is None
 75        return
 76
 77    origin = get_origin(type_) or type_
 78
 79    if is_list_type(type_):
 80        return _assert_list_type(type_, value)
 81
 82    if is_sequence_type(type_):
 83        assert isinstance(value, Sequence)
 84        inner_type = get_args(type_)[0]
 85        for entry in value:  # type: ignore
 86            assert_type(inner_type, entry)  # type: ignore
 87        return
 88
 89    if origin == str:
 90        assert isinstance(value, str)
 91    elif origin == int:
 92        assert isinstance(value, int)
 93    elif origin == bool:
 94        assert isinstance(value, bool)
 95    elif origin == float:
 96        assert isinstance(value, float)
 97    elif origin == bytes:
 98        assert isinstance(value, bytes)
 99    elif origin == datetime:
100        assert isinstance(value, datetime)
101    elif origin == date:
102        assert isinstance(value, date)
103    elif origin == object:
104        # nothing to do here, the expected type is unknown
105        pass
106    elif origin == Literal:
107        assert value in get_args(type_)
108    elif origin == dict:
109        assert is_dict(value)
110
111        args = get_args(type_)
112        key_type = args[0]
113        items_type = args[1]
114
115        for key, item in value.items():
116            assert_matches_type(key_type, key, path=[*path, "<dict key>"])
117            assert_matches_type(items_type, item, path=[*path, "<dict item>"])
118    elif is_union_type(type_):
119        variants = get_args(type_)
120
121        try:
122            none_index = variants.index(type(None))
123        except ValueError:
124            pass
125        else:
126            # special case Optional[T] for better error messages
127            if len(variants) == 2:
128                if value is None:
129                    # valid
130                    return
131
132                return assert_matches_type(type_=variants[not none_index], value=value, path=path)
133
134        for i, variant in enumerate(variants):
135            try:
136                assert_matches_type(variant, value, path=[*path, f"variant {i}"])
137                return
138            except AssertionError:
139                traceback.print_exc()
140                continue
141
142        raise AssertionError("Did not match any variants")
143    elif issubclass(origin, BaseModel):
144        assert isinstance(value, type_)
145        assert assert_matches_model(type_, cast(Any, value), path=path)
146    elif inspect.isclass(origin) and origin.__name__ == "HttpxBinaryResponseContent":
147        assert value.__class__.__name__ == "HttpxBinaryResponseContent"
148    else:
149        assert None, f"Unhandled field type: {type_}"
150
151
def _assert_list_type(type_: type[object], value: object, *, path: list[str] | None = None) -> None:
    """Assert that ``value`` is a list whose entries all match the item type of ``type_``.

    Args:
        type_: A list annotation such as ``list[T]`` / ``List[T]``.
        value: The runtime value to check.
        path: Optional breadcrumb prefix used only in error context; defaults to empty.

    Raises:
        AssertionError: If ``value`` is not a list or any entry does not match ``T``.
    """
    assert is_list(value)

    list_args = get_args(type_)
    if not list_args:
        # bare `list` / `List`: the item type is unknown, nothing more to check
        return

    inner_type = list_args[0]
    base_path = [] if path is None else path
    for index, entry in enumerate(value):
        # `assert_type` is a static-only no-op at runtime, so recurse instead
        # to actually validate each entry.
        assert_matches_type(inner_type, entry, path=[*base_path, f"<list item {index}>"])
158
159
def rich_print_str(obj: object, *, width: int = 120) -> str:
    """Like `rich.print()` but returns the rendered string instead of printing it.

    Args:
        obj: Object to render with ``Console.print``.
        width: Console width in columns used when rendering; defaults to 120.

    Returns:
        Everything ``Console.print`` wrote into an in-memory text buffer.
    """
    # `import rich` does not itself import the `rich.console` submodule, so
    # import it explicitly instead of relying on another module having done so.
    from rich.console import Console

    buf = io.StringIO()

    console = Console(file=buf, width=width)
    console.print(obj)

    return buf.getvalue()
167    return buf.getvalue()
168
169
@contextlib.contextmanager
def update_env(**new_env: str | Omit) -> Iterator[None]:
    """Temporarily set or unset environment variables inside a ``with`` block.

    Each keyword names a variable. A string value assigns it, and an ``Omit``
    instance removes it (if present). On exit, whether normal or via an
    exception, the whole environment is reset to the snapshot taken on entry.
    This also undoes any other changes made inside the block.
    """
    snapshot = dict(os.environ)

    try:
        for key, val in new_env.items():
            if not isinstance(val, Omit):
                os.environ[key] = val
                continue
            os.environ.pop(key, None)

        yield
    finally:
        # restore the entry-time environment exactly
        os.environ.clear()
        os.environ.update(snapshot)
184        os.environ.update(old)