main
  1from __future__ import annotations
  2
  3import os
  4import re
  5import inspect
  6import functools
  7from typing import (
  8    TYPE_CHECKING,
  9    Any,
 10    Tuple,
 11    Mapping,
 12    TypeVar,
 13    Callable,
 14    Iterable,
 15    Sequence,
 16    cast,
 17    overload,
 18)
 19from pathlib import Path
 20from datetime import date, datetime
 21from typing_extensions import TypeGuard
 22
 23import sniffio
 24
 25from .._types import Omit, NotGiven, FileTypes, HeadersLike
 26
# Module-wide type variables. Each TypeVar's name string must match the name it is
# assigned to, so these declarations are kept exactly as written.
_T = TypeVar("_T")
# Bounds used by the `is_*_t` narrowing helpers further down in this module.
_TupleT = TypeVar("_TupleT", bound=Tuple[object, ...])
_MappingT = TypeVar("_MappingT", bound=Mapping[str, object])
_SequenceT = TypeVar("_SequenceT", bound=Sequence[object])
# Used by the decorator factories (`required_args`, `lru_cache`) so the decorated
# function keeps its own type for static checkers.
CallableT = TypeVar("CallableT", bound=Callable[..., Any])
 32
 33if TYPE_CHECKING:
 34    from ..lib.azure import AzureOpenAI, AsyncAzureOpenAI
 35
 36
 37def flatten(t: Iterable[Iterable[_T]]) -> list[_T]:
 38    return [item for sublist in t for item in sublist]
 39
 40
 41def extract_files(
 42    # TODO: this needs to take Dict but variance issues.....
 43    # create protocol type ?
 44    query: Mapping[str, object],
 45    *,
 46    paths: Sequence[Sequence[str]],
 47) -> list[tuple[str, FileTypes]]:
 48    """Recursively extract files from the given dictionary based on specified paths.
 49
 50    A path may look like this ['foo', 'files', '<array>', 'data'].
 51
 52    Note: this mutates the given dictionary.
 53    """
 54    files: list[tuple[str, FileTypes]] = []
 55    for path in paths:
 56        files.extend(_extract_items(query, path, index=0, flattened_key=None))
 57    return files
 58
 59
 60def _extract_items(
 61    obj: object,
 62    path: Sequence[str],
 63    *,
 64    index: int,
 65    flattened_key: str | None,
 66) -> list[tuple[str, FileTypes]]:
 67    try:
 68        key = path[index]
 69    except IndexError:
 70        if not is_given(obj):
 71            # no value was provided - we can safely ignore
 72            return []
 73
 74        # cyclical import
 75        from .._files import assert_is_file_content
 76
 77        # We have exhausted the path, return the entry we found.
 78        assert flattened_key is not None
 79
 80        if is_list(obj):
 81            files: list[tuple[str, FileTypes]] = []
 82            for entry in obj:
 83                assert_is_file_content(entry, key=flattened_key + "[]" if flattened_key else "")
 84                files.append((flattened_key + "[]", cast(FileTypes, entry)))
 85            return files
 86
 87        assert_is_file_content(obj, key=flattened_key)
 88        return [(flattened_key, cast(FileTypes, obj))]
 89
 90    index += 1
 91    if is_dict(obj):
 92        try:
 93            # We are at the last entry in the path so we must remove the field
 94            if (len(path)) == index:
 95                item = obj.pop(key)
 96            else:
 97                item = obj[key]
 98        except KeyError:
 99            # Key was not present in the dictionary, this is not indicative of an error
100            # as the given path may not point to a required field. We also do not want
101            # to enforce required fields as the API may differ from the spec in some cases.
102            return []
103        if flattened_key is None:
104            flattened_key = key
105        else:
106            flattened_key += f"[{key}]"
107        return _extract_items(
108            item,
109            path,
110            index=index,
111            flattened_key=flattened_key,
112        )
113    elif is_list(obj):
114        if key != "<array>":
115            return []
116
117        return flatten(
118            [
119                _extract_items(
120                    item,
121                    path,
122                    index=index,
123                    flattened_key=flattened_key + "[]" if flattened_key is not None else "[]",
124                )
125                for item in obj
126            ]
127        )
128
129    # Something unexpected was passed, just ignore it.
130    return []
131
132
def is_given(obj: _T | NotGiven | Omit) -> TypeGuard[_T]:
    """Return True unless *obj* is one of the `NotGiven` / `Omit` sentinels.

    Narrows *obj* to the real value type for static checkers.
    """
    return not isinstance(obj, (NotGiven, Omit))
135
136
# Type-safe methods for narrowing types with TypeVars.
# The default narrowing for isinstance(obj, dict) is dict[unknown, unknown],
# however this causes Pyright to (rightfully) report errors. Since we don't
# care about the contained types, we can safely use `object` in their place.
141#
142# There are two separate functions defined, `is_*` and `is_*_t` for different use cases.
143# `is_*` is for when you're dealing with an unknown input
144# `is_*_t` is for when you're narrowing a known union type to a specific subset
145
146
def is_tuple(obj: object) -> TypeGuard[tuple[object, ...]]:
    """Narrow an unknown value to ``tuple[object, ...]`` if it is a tuple (or subclass)."""
    return isinstance(obj, tuple)
149
150
def is_tuple_t(obj: _TupleT | object) -> TypeGuard[_TupleT]:
    """Narrow a value of a known union type to its tuple member ``_TupleT``."""
    return isinstance(obj, tuple)
153
154
def is_sequence(obj: object) -> TypeGuard[Sequence[object]]:
    """Narrow an unknown value to ``Sequence[object]`` (lists, tuples, strings, ...)."""
    return isinstance(obj, Sequence)
157
158
def is_sequence_t(obj: _SequenceT | object) -> TypeGuard[_SequenceT]:
    """Narrow a value of a known union type to its sequence member ``_SequenceT``."""
    return isinstance(obj, Sequence)
161
162
def is_mapping(obj: object) -> TypeGuard[Mapping[str, object]]:
    """Narrow an unknown value to ``Mapping[str, object]`` if it is any mapping.

    Only the mapping protocol is checked at runtime; the ``str`` key type is a
    static-typing convenience, not verified.
    """
    return isinstance(obj, Mapping)
165
166
def is_mapping_t(obj: _MappingT | object) -> TypeGuard[_MappingT]:
    """Narrow a value of a known union type to its mapping member ``_MappingT``."""
    return isinstance(obj, Mapping)
169
170
def is_dict(obj: object) -> TypeGuard[dict[object, object]]:
    """Narrow an unknown value to ``dict[object, object]``.

    Only real `dict` instances (and subclasses) match; other mappings do not.
    """
    return isinstance(obj, dict)
173
174
def is_list(obj: object) -> TypeGuard[list[object]]:
    """Narrow an unknown value to ``list[object]`` if it is a `list` (or subclass)."""
    return isinstance(obj, list)
177
178
def is_iterable(obj: object) -> TypeGuard[Iterable[object]]:
    """Narrow an unknown value to ``Iterable[object]``.

    Note that `str` and `bytes` count as iterable too; callers that only want
    containers must exclude them explicitly.
    """
    return isinstance(obj, Iterable)
181
182
def deepcopy_minimal(item: _T) -> _T:
    """Recursively copy the mappings and lists inside *item*, sharing everything else.

    Unlike `copy.deepcopy()`, only two kinds of container are duplicated:

    - mappings, e.g. `dict` (rebuilt as plain `dict` objects)
    - list

    Any other value (tuples, sets, strings, custom objects, ...) is returned as the
    very same object. This is done for performance reasons.
    """
    if is_mapping(item):
        mapping_copy: dict[object, object] = {}
        for key, value in item.items():
            mapping_copy[key] = deepcopy_minimal(value)
        return cast(_T, mapping_copy)

    if is_list(item):
        return cast(_T, list(map(deepcopy_minimal, item)))

    return item
196
197
# copied from https://github.com/Rapptz/RoboDanny
def human_join(seq: Sequence[str], *, delim: str = ", ", final: str = "or") -> str:
    """Join *seq* into an English-style list, e.g. ``"a, b or c"``.

    *delim* separates all but the last two items and *final* joins the last two.
    No serial (Oxford) comma is inserted. An empty sequence yields ``""``.
    """
    count = len(seq)
    if count == 0:
        return ""
    if count == 1:
        return seq[0]

    # Two items are joined directly; longer inputs join everything but the last with `delim`.
    leading = f"{seq[0]}" if count == 2 else delim.join(seq[:-1])
    return f"{leading} {final} {seq[-1]}"
211
212
def quote(string: str) -> str:
    """Wrap *string* in single quotation marks.

    Does *not* do any escaping: quotes already inside *string* are left untouched.
    """
    return "'{}'".format(string)
216
217
def required_args(*variants: Sequence[str]) -> Callable[[CallableT], CallableT]:
    """Decorator to enforce a given set of arguments or variants of arguments are passed to the decorated function.

    Useful for enforcing runtime validation of overloaded functions.

    Each variant is a sequence of parameter names. A call is accepted when every name
    of at least one variant was passed, positionally or by keyword. Positional
    arguments are mapped onto the function's positional parameters in declaration
    order. If the function accepts ``*args``, surplus positional arguments are allowed
    and count towards no named parameter. At least one variant must be given.

    Raises (when the decorated function is called):
        TypeError: if more positional arguments are passed than the function has
            positional parameters (and it has no ``*args``), or if no variant is
            satisfied.

    Example usage:
    ```py
    @overload
    def foo(*, a: str) -> str: ...


    @overload
    def foo(*, b: bool) -> str: ...


    # This enforces the same constraints that a static type checker would
    # i.e. that either a or b must be passed to the function
    @required_args(["a"], ["b"])
    def foo(*, a: str | None = None, b: bool | None = None) -> str: ...
    ```
    """

    def inner(func: CallableT) -> CallableT:
        params = inspect.signature(func).parameters
        positional = [
            name
            for name, param in params.items()
            if param.kind in {param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD}
        ]
        # With `*args`, extra positional arguments are legal; they just map to no named parameter.
        accepts_var_positional = any(param.kind == param.VAR_POSITIONAL for param in params.values())

        @functools.wraps(func)
        def wrapper(*args: object, **kwargs: object) -> object:
            if len(args) > len(positional) and not accepts_var_positional:
                raise TypeError(
                    f"{func.__name__}() takes {len(positional)} argument(s) but {len(args)} were given"
                )

            given_params: set[str] = set(positional[: len(args)])
            given_params.update(kwargs)

            for variant in variants:
                if all(param in given_params for param in variant):
                    break
            else:  # no break
                if len(variants) > 1:
                    variations = human_join(
                        ["(" + human_join([quote(arg) for arg in variant], final="and") + ")" for variant in variants]
                    )
                    msg = f"Missing required arguments; Expected either {variations} arguments to be given"
                else:
                    assert len(variants) > 0

                    # Keep declaration order (dropping duplicates) so the message is deterministic;
                    # a set difference would order names by string hash, which varies between runs.
                    missing = list(dict.fromkeys(arg for arg in variants[0] if arg not in given_params))
                    if len(missing) > 1:
                        msg = f"Missing required arguments: {human_join([quote(arg) for arg in missing])}"
                    else:
                        msg = f"Missing required argument: {quote(missing[0])}"
                raise TypeError(msg)
            return func(*args, **kwargs)

        return wrapper  # type: ignore

    return inner
291
292
_K = TypeVar("_K")
_V = TypeVar("_V")


@overload
def strip_not_given(obj: None) -> None: ...


@overload
def strip_not_given(obj: Mapping[_K, _V | NotGiven]) -> dict[_K, _V]: ...


@overload
def strip_not_given(obj: object) -> object: ...


def strip_not_given(obj: object | None) -> object:
    """Remove all top-level keys where their values are instances of `NotGiven`.

    `None` and non-mapping inputs are returned unchanged. For a mapping, a new
    `dict` is built (the input is not modified) and nested values are not
    inspected. Only `NotGiven` is dropped; other sentinels such as `Omit` are kept.
    """
    if obj is None or not is_mapping(obj):
        return obj

    kept: dict[object, object] = {}
    for key, value in obj.items():
        if not isinstance(value, NotGiven):
            kept[key] = value
    return kept
318
319
def coerce_integer(val: str) -> int:
    """Parse *val* as a base-10 integer; `ValueError` if it is not one."""
    return int(val, 10)


def coerce_float(val: str) -> float:
    """Parse *val* as a float; `ValueError` if it is not one."""
    return float(val)


def coerce_boolean(val: str) -> bool:
    """Interpret *val* as a flag: exactly ``"true"``, ``"1"`` or ``"on"`` are truthy.

    The comparison is case-sensitive, so e.g. ``"True"`` yields ``False``.
    """
    return val in ("true", "1", "on")


def maybe_coerce_integer(val: str | None) -> int | None:
    """Like `coerce_integer`, but `None` is passed through unchanged."""
    return None if val is None else coerce_integer(val)


def maybe_coerce_float(val: str | None) -> float | None:
    """Like `coerce_float`, but `None` is passed through unchanged."""
    return None if val is None else coerce_float(val)


def maybe_coerce_boolean(val: str | None) -> bool | None:
    """Like `coerce_boolean`, but `None` is passed through unchanged."""
    return None if val is None else coerce_boolean(val)
348
349
def removeprefix(string: str, prefix: str) -> str:
    """Return *string* without a leading *prefix*, or unchanged if it lacks one.

    Backport of `str.removeprefix` for Python < 3.9
    """
    if not string.startswith(prefix):
        return string
    return string[len(prefix) :]
358
359
def removesuffix(string: str, suffix: str) -> str:
    """Remove a suffix from a string.

    Backport of `str.removesuffix` for Python < 3.9: if *string* ends with a
    non-empty *suffix*, return it without that suffix; otherwise return *string*
    unchanged (an empty *suffix* is a no-op, matching the builtin).
    """
    # Guard the empty suffix: `string[:-0]` is `string[:0]`, i.e. "", which would
    # otherwise erase the whole string.
    if suffix and string.endswith(suffix):
        return string[: -len(suffix)]
    return string
368
369
def file_from_path(path: str) -> FileTypes:
    """Read the file at *path* and return it as a ``(file_name, contents)`` tuple.

    ``file_name`` is ``os.path.basename(path)`` and ``contents`` is the file's raw
    bytes. Errors raised while reading (e.g. a missing file) propagate to the caller.
    """
    data = Path(path).read_bytes()
    return os.path.basename(path), data
374
375
def get_required_header(headers: HeadersLike, header: str) -> str:
    """Return the value of *header* from *headers*, raising if it cannot be found.

    For mapping inputs, keys are first compared case-insensitively and the first
    `str` value found is returned. After that, ``headers.get`` is tried with several
    spellings of the name: as given, lower-case, upper-case and "intercaps"
    (e.g. ``stainless-event-id`` -> ``Stainless-Event-Id``); the first truthy value
    wins. For non-mapping inputs only those exact spellings are tried (unless the
    object's own ``get`` is case-insensitive — depends on the caller's type).

    Raises:
        ValueError: if no lookup produced a value.
    """
    wanted = header.lower()
    if is_mapping_t(headers):
        # mypy doesn't understand the type narrowing here
        for name, candidate in headers.items():  # type: ignore
            if name.lower() == wanted and isinstance(candidate, str):
                return candidate

    # to deal with the case where the header looks like Stainless-Event-Id
    intercaps = re.sub(r"([^\w])(\w)", lambda pat: pat.group(1) + pat.group(2).upper(), header.capitalize())

    spellings = (header, wanted, header.upper(), intercaps)
    for spelling in spellings:
        found = headers.get(spelling)
        if found:
            return found

    raise ValueError(f"Could not find {header} header")
393
394
def get_async_library() -> str:
    """Return the name of the async library driving the current context.

    Delegates to `sniffio.current_async_library()`. Any exception it raises
    (presumably when no async library is detected) is swallowed and the string
    ``"false"`` is returned instead.
    """
    try:
        library = sniffio.current_async_library()
    except Exception:
        library = "false"
    return library
400
401
def lru_cache(*, maxsize: int | None = 128) -> Callable[[CallableT], CallableT]:
    """Typed front-end for `functools.lru_cache`.

    Behaves exactly like ``functools.lru_cache(maxsize=maxsize)``, but is annotated
    so the decorated function keeps its original signature for type checkers.
    """
    decorator = functools.lru_cache(maxsize=maxsize)  # noqa: TID251
    return cast(Any, decorator)  # type: ignore[no-any-return]
410
411
def json_safe(data: object) -> object:
    """Translates a mapping / sequence recursively in the same fashion
    as `pydantic` v2's `model_dump(mode="json")`.

    Concretely:

    - mappings become `dict`s, with keys and values converted recursively;
    - other iterables, except `str`, `bytes` and `bytearray`, become `list`s;
    - `date` / `datetime` values become strings via `isoformat()`;
    - anything else is returned unchanged.
    """
    if is_mapping(data):
        converted: dict[object, object] = {}
        for key, value in data.items():
            safe_key = json_safe(key)
            converted[safe_key] = json_safe(value)
        return converted

    text_like = isinstance(data, (str, bytes, bytearray))
    if is_iterable(data) and not text_like:
        return [json_safe(item) for item in data]

    if isinstance(data, (datetime, date)):
        return data.isoformat()

    return data
426
427
def is_azure_client(client: object) -> TypeGuard[AzureOpenAI]:
    """Return True if *client* is an `AzureOpenAI` instance, narrowing it for type checkers.

    The class is imported at call time because the module-level import of the same
    name only happens under `TYPE_CHECKING`.
    """
    from ..lib.azure import AzureOpenAI as _AzureOpenAI

    is_azure = isinstance(client, _AzureOpenAI)
    return is_azure
432
433
def is_async_azure_client(client: object) -> TypeGuard[AsyncAzureOpenAI]:
    """Return True if *client* is an `AsyncAzureOpenAI` instance, narrowing it for type checkers.

    The class is imported at call time because the module-level import of the same
    name only happens under `TYPE_CHECKING`.
    """
    from ..lib.azure import AsyncAzureOpenAI as _AsyncAzureOpenAI

    is_async_azure = isinstance(client, _AsyncAzureOpenAI)
    return is_async_azure