from __future__ import annotations

import os
import re
import inspect
import functools
from typing import (
    TYPE_CHECKING,
    Any,
    Tuple,
    Mapping,
    TypeVar,
    Callable,
    Iterable,
    Sequence,
    cast,
    overload,
)
from pathlib import Path
from datetime import date, datetime
from typing_extensions import TypeGuard

import sniffio

from .._types import Omit, NotGiven, FileTypes, HeadersLike

_T = TypeVar("_T")
_TupleT = TypeVar("_TupleT", bound=Tuple[object, ...])
_MappingT = TypeVar("_MappingT", bound=Mapping[str, object])
_SequenceT = TypeVar("_SequenceT", bound=Sequence[object])
CallableT = TypeVar("CallableT", bound=Callable[..., Any])

if TYPE_CHECKING:
    from ..lib.azure import AzureOpenAI, AsyncAzureOpenAI


def flatten(t: Iterable[Iterable[_T]]) -> list[_T]:
    return [item for sublist in t for item in sublist]


def extract_files(
    # TODO: this needs to take Dict but variance issues.....
    # create protocol type ?
    query: Mapping[str, object],
    *,
    paths: Sequence[Sequence[str]],
) -> list[tuple[str, FileTypes]]:
    """Recursively extract files from the given dictionary based on specified paths.

    A path may look like this: ['foo', 'files', '<array>', 'data'].

    Note: this mutates the given dictionary.
    """
    files: list[tuple[str, FileTypes]] = []
    for path in paths:
        files.extend(_extract_items(query, path, index=0, flattened_key=None))
    return files
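# Illustrative walk-through of extract_files / _extract_items (example data assumed):
#   query = {"foo": {"files": [{"data": b"raw bytes", "name": "a.txt"}]}}
#   extract_files(query, paths=[["foo", "files", "<array>", "data"]])
#   -> [("foo[files][][data]", b"raw bytes")]
# The "data" key is popped from the inner dict (the documented mutation); the rest of
# the query is left untouched. "<array>" in a path fans out over list entries and
# appends "[]" to the flattened key.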

def _extract_items(
    obj: object,
    path: Sequence[str],
    *,
    index: int,
    flattened_key: str | None,
) -> list[tuple[str, FileTypes]]:
    try:
        key = path[index]
    except IndexError:
        if not is_given(obj):
            # no value was provided - we can safely ignore
            return []

        # cyclical import
        from .._files import assert_is_file_content

        # We have exhausted the path, return the entry we found.
        assert flattened_key is not None

        if is_list(obj):
            files: list[tuple[str, FileTypes]] = []
            for entry in obj:
                assert_is_file_content(entry, key=flattened_key + "[]" if flattened_key else "")
                files.append((flattened_key + "[]", cast(FileTypes, entry)))
            return files

        assert_is_file_content(obj, key=flattened_key)
        return [(flattened_key, cast(FileTypes, obj))]

    index += 1
    if is_dict(obj):
        try:
            # We are at the last entry in the path so we must remove the field
            if len(path) == index:
                item = obj.pop(key)
            else:
                item = obj[key]
        except KeyError:
            # The key was not present in the dictionary; this is not indicative of an error
            # as the given path may not point to a required field. We also do not want
            # to enforce required fields as the API may differ from the spec in some cases.
            return []
        if flattened_key is None:
            flattened_key = key
        else:
            flattened_key += f"[{key}]"
        return _extract_items(
            item,
            path,
            index=index,
            flattened_key=flattened_key,
        )
    elif is_list(obj):
        if key != "<array>":
            return []

        return flatten(
            [
                _extract_items(
                    item,
                    path,
                    index=index,
                    flattened_key=flattened_key + "[]" if flattened_key is not None else "[]",
                )
                for item in obj
            ]
        )

    # Something unexpected was passed, just ignore it.
    return []


def is_given(obj: _T | NotGiven | Omit) -> TypeGuard[_T]:
    return not isinstance(obj, NotGiven) and not isinstance(obj, Omit)


# Type safe methods for narrowing types with TypeVars.
# The default narrowing for isinstance(obj, dict) is dict[unknown, unknown],
# however this causes Pyright to rightfully report errors. As we know we don't
# care about the contained types, we can safely use `object` in its place.
#
# There are two separate functions defined, `is_*` and `is_*_t`, for different use cases.
# `is_*` is for when you're dealing with an unknown input
# `is_*_t` is for when you're narrowing a known union type to a specific subset
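#
# Illustrative difference (hypothetical variable), given `value: tuple[int, int] | list[int]`:
#   `is_tuple(value)`   narrows to `tuple[object, ...]` (input treated as unknown)
#   `is_tuple_t(value)` narrows to `tuple[int, int]` (the tuple member of the known union)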


def is_tuple(obj: object) -> TypeGuard[tuple[object, ...]]:
    return isinstance(obj, tuple)


def is_tuple_t(obj: _TupleT | object) -> TypeGuard[_TupleT]:
    return isinstance(obj, tuple)


def is_sequence(obj: object) -> TypeGuard[Sequence[object]]:
    return isinstance(obj, Sequence)


def is_sequence_t(obj: _SequenceT | object) -> TypeGuard[_SequenceT]:
    return isinstance(obj, Sequence)


def is_mapping(obj: object) -> TypeGuard[Mapping[str, object]]:
    return isinstance(obj, Mapping)


def is_mapping_t(obj: _MappingT | object) -> TypeGuard[_MappingT]:
    return isinstance(obj, Mapping)


def is_dict(obj: object) -> TypeGuard[dict[object, object]]:
    return isinstance(obj, dict)


def is_list(obj: object) -> TypeGuard[list[object]]:
    return isinstance(obj, list)


def is_iterable(obj: object) -> TypeGuard[Iterable[object]]:
    return isinstance(obj, Iterable)


def deepcopy_minimal(item: _T) -> _T:
    """Minimal reimplementation of copy.deepcopy() that will only copy certain object types:

    - mappings, e.g. `dict`
    - list

    This is done for performance reasons.
    """
    if is_mapping(item):
        return cast(_T, {k: deepcopy_minimal(v) for k, v in item.items()})
    if is_list(item):
        return cast(_T, [deepcopy_minimal(entry) for entry in item])
    return item
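# Illustrative note: only mappings and lists are copied recursively; any other value
# (tuples, sets, custom objects, ...) is returned by reference, e.g.
#   deepcopy_minimal({"a": [1, 2], "b": (3, 4)}) -> new dict and new list, same tuple object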

# copied from https://github.com/Rapptz/RoboDanny
def human_join(seq: Sequence[str], *, delim: str = ", ", final: str = "or") -> str:
    size = len(seq)
    if size == 0:
        return ""

    if size == 1:
        return seq[0]

    if size == 2:
        return f"{seq[0]} {final} {seq[1]}"

    return delim.join(seq[:-1]) + f" {final} {seq[-1]}"
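# Illustrative examples (using the defaults above):
#   human_join(["a"]) -> "a"
#   human_join(["a", "b"]) -> "a or b"
#   human_join(["a", "b", "c"], final="and") -> "a, b and c"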

def quote(string: str) -> str:
    """Add single quotation marks around the given string. Does *not* do any escaping."""
    return f"'{string}'"


def required_args(*variants: Sequence[str]) -> Callable[[CallableT], CallableT]:
    """Decorator to enforce that a given set of arguments, or one of several variants, is passed to the decorated function.

    Useful for enforcing runtime validation of overloaded functions.

    Example usage:
    ```py
    @overload
    def foo(*, a: str) -> str: ...


    @overload
    def foo(*, b: bool) -> str: ...


    # This enforces the same constraints that a static type checker would
    # i.e. that either a or b must be passed to the function
    @required_args(["a"], ["b"])
    def foo(*, a: str | None = None, b: bool | None = None) -> str: ...
    ```
    """

    def inner(func: CallableT) -> CallableT:
        params = inspect.signature(func).parameters
        positional = [
            name
            for name, param in params.items()
            if param.kind
            in {
                param.POSITIONAL_ONLY,
                param.POSITIONAL_OR_KEYWORD,
            }
        ]

        @functools.wraps(func)
        def wrapper(*args: object, **kwargs: object) -> object:
            given_params: set[str] = set()
            for i, _ in enumerate(args):
                try:
                    given_params.add(positional[i])
                except IndexError:
                    raise TypeError(
                        f"{func.__name__}() takes {len(positional)} argument(s) but {len(args)} were given"
                    ) from None

            for key in kwargs.keys():
                given_params.add(key)

            for variant in variants:
                matches = all((param in given_params for param in variant))
                if matches:
                    break
            else:  # no break
                if len(variants) > 1:
                    variations = human_join(
                        ["(" + human_join([quote(arg) for arg in variant], final="and") + ")" for variant in variants]
                    )
                    msg = f"Missing required arguments; Expected either {variations} arguments to be given"
                else:
                    assert len(variants) > 0

                    # TODO: this error message is not deterministic
                    missing = list(set(variants[0]) - given_params)
                    if len(missing) > 1:
                        msg = f"Missing required arguments: {human_join([quote(arg) for arg in missing])}"
                    else:
                        msg = f"Missing required argument: {quote(missing[0])}"
                raise TypeError(msg)
            return func(*args, **kwargs)

        return wrapper  # type: ignore

    return inner
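# Illustrative runtime behavior for the docstring example above: calling `foo()` with
# neither `a` nor `b` raises a TypeError listing the accepted variants, while
# `foo(a="x")` or `foo(b=True)` satisfies a variant and calls through to the function.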

_K = TypeVar("_K")
_V = TypeVar("_V")


@overload
def strip_not_given(obj: None) -> None: ...


@overload
def strip_not_given(obj: Mapping[_K, _V | NotGiven]) -> dict[_K, _V]: ...


@overload
def strip_not_given(obj: object) -> object: ...


def strip_not_given(obj: object | None) -> object:
    """Remove all top-level keys whose values are instances of `NotGiven`"""
    if obj is None:
        return None

    if not is_mapping(obj):
        return obj

    return {key: value for key, value in obj.items() if not isinstance(value, NotGiven)}
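# Illustrative example, using the `NotGiven` sentinel type imported above:
#   strip_not_given({"api_version": "2024-06-01", "timeout": NotGiven()})
#   -> {"api_version": "2024-06-01"}
# Non-mapping inputs (and None) are returned unchanged.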

def coerce_integer(val: str) -> int:
    return int(val, base=10)


def coerce_float(val: str) -> float:
    return float(val)


def coerce_boolean(val: str) -> bool:
    return val == "true" or val == "1" or val == "on"


def maybe_coerce_integer(val: str | None) -> int | None:
    if val is None:
        return None
    return coerce_integer(val)


def maybe_coerce_float(val: str | None) -> float | None:
    if val is None:
        return None
    return coerce_float(val)


def maybe_coerce_boolean(val: str | None) -> bool | None:
    if val is None:
        return None
    return coerce_boolean(val)


def removeprefix(string: str, prefix: str) -> str:
    """Remove a prefix from a string.

    Backport of `str.removeprefix` for Python < 3.9
    """
    if string.startswith(prefix):
        return string[len(prefix) :]
    return string


def removesuffix(string: str, suffix: str) -> str:
    """Remove a suffix from a string.

    Backport of `str.removesuffix` for Python < 3.9
    """
    if string.endswith(suffix):
        return string[: -len(suffix)]
    return string


def file_from_path(path: str) -> FileTypes:
    contents = Path(path).read_bytes()
    file_name = os.path.basename(path)
    return (file_name, contents)


def get_required_header(headers: HeadersLike, header: str) -> str:
    lower_header = header.lower()
    if is_mapping_t(headers):
        # mypy doesn't understand the type narrowing here
        for k, v in headers.items():  # type: ignore
            if k.lower() == lower_header and isinstance(v, str):
                return v

    # to deal with the case where the header looks like Stainless-Event-Id
    intercaps_header = re.sub(r"([^\w])(\w)", lambda pat: pat.group(1) + pat.group(2).upper(), header.capitalize())

    for normalized_header in [header, lower_header, header.upper(), intercaps_header]:
        value = headers.get(normalized_header)
        if value:
            return value

    raise ValueError(f"Could not find {header} header")
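# Illustrative example: for header "stainless-event-id", a mapping is scanned
# case-insensitively first; otherwise `.get()` is tried with "stainless-event-id",
# "STAINLESS-EVENT-ID" and the intercaps form "Stainless-Event-Id" (built by the
# regex above from `header.capitalize()`).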

def get_async_library() -> str:
    try:
        return sniffio.current_async_library()
    except Exception:
        return "false"


def lru_cache(*, maxsize: int | None = 128) -> Callable[[CallableT], CallableT]:
    """A version of functools.lru_cache that retains the type signature
    for the wrapped function arguments.
    """
    wrapper = functools.lru_cache(  # noqa: TID251
        maxsize=maxsize,
    )
    return cast(Any, wrapper)  # type: ignore[no-any-return]


def json_safe(data: object) -> object:
    """Translates a mapping / sequence recursively in the same fashion
    as `pydantic` v2's `model_dump(mode="json")`.
    """
    if is_mapping(data):
        return {json_safe(key): json_safe(value) for key, value in data.items()}

    if is_iterable(data) and not isinstance(data, (str, bytes, bytearray)):
        return [json_safe(item) for item in data]

    if isinstance(data, (datetime, date)):
        return data.isoformat()

    return data
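# Illustrative example:
#   json_safe({"when": datetime(2024, 1, 2, 3, 4, 5), "tags": ("a", "b")})
#   -> {"when": "2024-01-02T03:04:05", "tags": ["a", "b"]}
# Strings and bytes are returned as-is rather than being treated as iterables.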

def is_azure_client(client: object) -> TypeGuard[AzureOpenAI]:
    from ..lib.azure import AzureOpenAI

    return isinstance(client, AzureOpenAI)


def is_async_azure_client(client: object) -> TypeGuard[AsyncAzureOpenAI]:
    from ..lib.azure import AsyncAzureOpenAI

    return isinstance(client, AsyncAzureOpenAI)