main
  1from __future__ import annotations
  2
  3import sys
  4import typing
  5import typing_extensions
  6from typing import Any, TypeVar, Iterable, cast
  7from collections import abc as _c_abc
  8from typing_extensions import (
  9    TypeIs,
 10    Required,
 11    Annotated,
 12    get_args,
 13    get_origin,
 14)
 15
 16from ._utils import lru_cache
 17from .._types import InheritsGeneric
 18from ._compat import is_union as _is_union
 19
 20
 21def is_annotated_type(typ: type) -> bool:
 22    return get_origin(typ) == Annotated
 23
 24
 25def is_list_type(typ: type) -> bool:
 26    return (get_origin(typ) or typ) == list
 27
 28
 29def is_sequence_type(typ: type) -> bool:
 30    origin = get_origin(typ) or typ
 31    return origin == typing_extensions.Sequence or origin == typing.Sequence or origin == _c_abc.Sequence
 32
 33
 34def is_iterable_type(typ: type) -> bool:
 35    """If the given type is `typing.Iterable[T]`"""
 36    origin = get_origin(typ) or typ
 37    return origin == Iterable or origin == _c_abc.Iterable
 38
 39
 40def is_union_type(typ: type) -> bool:
 41    return _is_union(get_origin(typ))
 42
 43
 44def is_required_type(typ: type) -> bool:
 45    return get_origin(typ) == Required
 46
 47
 48def is_typevar(typ: type) -> bool:
 49    # type ignore is required because type checkers
 50    # think this expression will always return False
 51    return type(typ) == TypeVar  # type: ignore
 52
 53
 54_TYPE_ALIAS_TYPES: tuple[type[typing_extensions.TypeAliasType], ...] = (typing_extensions.TypeAliasType,)
 55if sys.version_info >= (3, 12):
 56    _TYPE_ALIAS_TYPES = (*_TYPE_ALIAS_TYPES, typing.TypeAliasType)
 57
 58
 59def is_type_alias_type(tp: Any, /) -> TypeIs[typing_extensions.TypeAliasType]:
 60    """Return whether the provided argument is an instance of `TypeAliasType`.
 61
 62    ```python
 63    type Int = int
 64    is_type_alias_type(Int)
 65    # > True
 66    Str = TypeAliasType("Str", str)
 67    is_type_alias_type(Str)
 68    # > True
 69    ```
 70    """
 71    return isinstance(tp, _TYPE_ALIAS_TYPES)
 72
 73
 74# Extracts T from Annotated[T, ...] or from Required[Annotated[T, ...]]
 75@lru_cache(maxsize=8096)
 76def strip_annotated_type(typ: type) -> type:
 77    if is_required_type(typ) or is_annotated_type(typ):
 78        return strip_annotated_type(cast(type, get_args(typ)[0]))
 79
 80    return typ
 81
 82
 83def extract_type_arg(typ: type, index: int) -> type:
 84    args = get_args(typ)
 85    try:
 86        return cast(type, args[index])
 87    except IndexError as err:
 88        raise RuntimeError(f"Expected type {typ} to have a type argument at index {index} but it did not") from err
 89
 90
 91def extract_type_var_from_base(
 92    typ: type,
 93    *,
 94    generic_bases: tuple[type, ...],
 95    index: int,
 96    failure_message: str | None = None,
 97) -> type:
 98    """Given a type like `Foo[T]`, returns the generic type variable `T`.
 99
100    This also handles the case where a concrete subclass is given, e.g.
101    ```py
102    class MyResponse(Foo[bytes]):
103        ...
104
105    extract_type_var(MyResponse, bases=(Foo,), index=0) -> bytes
106    ```
107
108    And where a generic subclass is given:
109    ```py
110    _T = TypeVar('_T')
111    class MyResponse(Foo[_T]):
112        ...
113
114    extract_type_var(MyResponse[bytes], bases=(Foo,), index=0) -> bytes
115    ```
116    """
117    cls = cast(object, get_origin(typ) or typ)
118    if cls in generic_bases:  # pyright: ignore[reportUnnecessaryContains]
119        # we're given the class directly
120        return extract_type_arg(typ, index)
121
122    # if a subclass is given
123    # ---
124    # this is needed as __orig_bases__ is not present in the typeshed stubs
125    # because it is intended to be for internal use only, however there does
126    # not seem to be a way to resolve generic TypeVars for inherited subclasses
127    # without using it.
128    if isinstance(cls, InheritsGeneric):
129        target_base_class: Any | None = None
130        for base in cls.__orig_bases__:
131            if base.__origin__ in generic_bases:
132                target_base_class = base
133                break
134
135        if target_base_class is None:
136            raise RuntimeError(
137                "Could not find the generic base class;\n"
138                "This should never happen;\n"
139                f"Does {cls} inherit from one of {generic_bases} ?"
140            )
141
142        extracted = extract_type_arg(target_base_class, index)
143        if is_typevar(extracted):
144            # If the extracted type argument is itself a type variable
145            # then that means the subclass itself is generic, so we have
146            # to resolve the type argument from the class itself, not
147            # the base class.
148            #
149            # Note: if there is more than 1 type argument, the subclass could
150            # change the ordering of the type arguments, this is not currently
151            # supported.
152            return extract_type_arg(typ, index)
153
154        return extracted
155
156    raise RuntimeError(failure_message or f"Could not resolve inner type variable at index {index} for {typ}")