Commit 463e870d
Changed files (4)
src
openai
tests
src/openai/_utils/__init__.py
@@ -41,6 +41,7 @@ from ._typing import (
extract_type_arg as extract_type_arg,
is_iterable_type as is_iterable_type,
is_required_type as is_required_type,
+ is_sequence_type as is_sequence_type,
is_annotated_type as is_annotated_type,
is_type_alias_type as is_type_alias_type,
strip_annotated_type as strip_annotated_type,
src/openai/_utils/_typing.py
@@ -26,6 +26,11 @@ def is_list_type(typ: type) -> bool:
return (get_origin(typ) or typ) == list
+def is_sequence_type(typ: type) -> bool:
+ origin = get_origin(typ) or typ
+ return origin == typing_extensions.Sequence or origin == typing.Sequence or origin == _c_abc.Sequence
+
+
def is_iterable_type(typ: type) -> bool:
"""If the given type is `typing.Iterable[T]`"""
origin = get_origin(typ) or typ
src/openai/_types.py
@@ -13,10 +13,21 @@ from typing import (
Mapping,
TypeVar,
Callable,
+ Iterator,
Optional,
Sequence,
)
-from typing_extensions import Set, Literal, Protocol, TypeAlias, TypedDict, override, runtime_checkable
+from typing_extensions import (
+ Set,
+ Literal,
+ Protocol,
+ TypeAlias,
+ TypedDict,
+ SupportsIndex,
+ overload,
+ override,
+ runtime_checkable,
+)
import httpx
import pydantic
@@ -219,3 +230,26 @@ class _GenericAlias(Protocol):
class HttpxSendArgs(TypedDict, total=False):
auth: httpx.Auth
follow_redirects: bool
+
+
+_T_co = TypeVar("_T_co", covariant=True)
+
+
+if TYPE_CHECKING:
+ # This works because str.__contains__ does not accept object (either in typeshed or at runtime)
+ # https://github.com/hauntsaninja/useful_types/blob/5e9710f3875107d068e7679fd7fec9cfab0eff3b/useful_types/__init__.py#L285
+ class SequenceNotStr(Protocol[_T_co]):
+ @overload
+ def __getitem__(self, index: SupportsIndex, /) -> _T_co: ...
+ @overload
+ def __getitem__(self, index: slice, /) -> Sequence[_T_co]: ...
+ def __contains__(self, value: object, /) -> bool: ...
+ def __len__(self) -> int: ...
+ def __iter__(self) -> Iterator[_T_co]: ...
+ def index(self, value: Any, start: int = 0, stop: int = ..., /) -> int: ...
+ def count(self, value: Any, /) -> int: ...
+ def __reversed__(self) -> Iterator[_T_co]: ...
+else:
+ # just point this to a normal `Sequence` at runtime to avoid having to special case
+ # deserializing our custom sequence type
+ SequenceNotStr = Sequence
tests/utils.py
@@ -5,7 +5,7 @@ import os
import inspect
import traceback
import contextlib
-from typing import Any, TypeVar, Iterator, ForwardRef, cast
+from typing import Any, TypeVar, Iterator, ForwardRef, Sequence, cast
from datetime import date, datetime
from typing_extensions import Literal, get_args, get_origin, assert_type
@@ -18,6 +18,7 @@ from openai._utils import (
is_list_type,
is_union_type,
extract_type_arg,
+ is_sequence_type,
is_annotated_type,
is_type_alias_type,
)
@@ -78,6 +79,13 @@ def assert_matches_type(
if is_list_type(type_):
return _assert_list_type(type_, value)
+ if is_sequence_type(type_):
+ assert isinstance(value, Sequence)
+ inner_type = get_args(type_)[0]
+ for entry in value: # type: ignore
+ assert_type(inner_type, entry) # type: ignore
+ return
+
if origin == str:
assert isinstance(value, str)
elif origin == int: