main
1from __future__ import annotations
2
3import io
4import base64
5import pathlib
6from typing import Any, Mapping, TypeVar, cast
7from datetime import date, datetime
8from typing_extensions import Literal, get_args, override, get_type_hints as _get_type_hints
9
10import anyio
11import pydantic
12
13from ._utils import (
14 is_list,
15 is_given,
16 lru_cache,
17 is_mapping,
18 is_iterable,
19 is_sequence,
20)
21from .._files import is_base64_file_input
22from ._compat import get_origin, is_typeddict
23from ._typing import (
24 is_list_type,
25 is_union_type,
26 extract_type_arg,
27 is_iterable_type,
28 is_required_type,
29 is_sequence_type,
30 is_annotated_type,
31 strip_annotated_type,
32)
33
# Generic type variable: `transform()` / `async_transform()` are typed to return the same
# static type they were given, because the transformation is not represented in the type system.
_T = TypeVar("_T")


# TODO: support for drilling globals() and locals()
# TODO: ensure works correctly with forward references in all cases


# Values accepted by `PropertyInfo(format=...)`: "iso8601" and "custom" apply to
# date/datetime values, "base64" applies to file inputs (see `_format_data`).
PropertyFormat = Literal["iso8601", "base64", "custom"]
42
43
class PropertyInfo:
    """Metadata class to be used in Annotated types to provide information about a given type.

    For example:

    class MyParams(TypedDict):
        account_holder_name: Annotated[str, PropertyInfo(alias='accountHolderName')]

    This means that {'account_holder_name': 'Robert'} will be transformed to {'accountHolderName': 'Robert'} before being sent to the API.

    Attributes:
        alias: Key emitted in place of the Python field name (read by `_maybe_transform_key`).
        format: How the value is serialised: "iso8601" / "custom" apply to date and datetime
            values, "base64" applies to file inputs (see `_format_data`).
        format_template: `strftime` template used when `format` is "custom".
        discriminator: Stored as given; not consulted by the transform helpers in this module.
    """

    alias: str | None
    format: PropertyFormat | None
    format_template: str | None
    discriminator: str | None

    def __init__(
        self,
        *,
        alias: str | None = None,
        format: PropertyFormat | None = None,
        format_template: str | None = None,
        discriminator: str | None = None,
    ) -> None:
        self.alias = alias
        self.format = format
        self.format_template = format_template
        self.discriminator = discriminator

    @override
    def __repr__(self) -> str:
        # `!r` keeps an unset (None) field distinguishable from the string 'None'; the
        # previous hand-written quotes rendered e.g. `alias='None'` for a missing alias.
        return (
            f"{self.__class__.__name__}(alias={self.alias!r}, format={self.format!r}, "
            f"format_template={self.format_template!r}, discriminator={self.discriminator!r})"
        )
76
77
def maybe_transform(
    data: object,
    expected_type: object,
) -> Any | None:
    """Like `transform()`, but lets `None` pass straight through.

    When `data` is `None` nothing is transformed and `None` is returned;
    otherwise the result of `transform(data, expected_type)` is returned.
    See `transform()` for the details of the transformation.
    """
    return None if data is None else transform(data, expected_type)
89
90
# Public entry point: delegates to `_transform_recursive`, casting the loosely-typed
# `expected_type` to `type` and the result back to the caller's static type.
def transform(
    data: _T,
    expected_type: object,
) -> _T:
    """Rewrite dictionaries according to the type information in `expected_type`.

    Example:

    ```py
    class Params(TypedDict, total=False):
        card_id: Required[Annotated[str, PropertyInfo(alias="cardID")]]


    transformed = transform({"card_id": "<my card ID>"}, Params)
    # {'cardID': '<my card ID>'}
    ```

    Keys and values that have no type information are passed through unchanged.

    Note that the static return type is the input type; the actual reshaping
    performed here is not reflected in the type system.
    """
    root_annotation = cast(type, expected_type)
    return cast(_T, _transform_recursive(data, annotation=root_annotation))
113
114
@lru_cache(maxsize=8096)
def _get_annotated_type(type_: type) -> type | None:
    """Return `type_` as an `Annotated[...]` form, or `None` when it is not one.

    A single `Required[...]` wrapper is looked through first, so both
    `Annotated[T, ...]` and `Required[Annotated[T, ...]]` produce the inner
    `Annotated[T, ...]`. Results are memoised with `lru_cache`.
    """
    candidate = get_args(type_)[0] if is_required_type(type_) else type_
    return candidate if is_annotated_type(candidate) else None
129
130
def _maybe_transform_key(key: str, type_: type) -> str:
    """Return the name under which `key` should be sent.

    If `type_` (optionally wrapped in `Required[...]`) is `Annotated` with a
    `PropertyInfo` whose `alias` is set, the first such alias is returned;
    otherwise `key` is returned unchanged.
    """
    annotated = _get_annotated_type(type_)
    if annotated is None:
        # Not `Annotated`, so there is no metadata that could rename the key.
        return key

    # get_args(Annotated[T, *meta]) == (T, *meta); skip T itself.
    aliases = (
        meta.alias
        for meta in get_args(annotated)[1:]
        if isinstance(meta, PropertyInfo) and meta.alias is not None
    )
    return next(aliases, key)
148
149
def _no_transform_needed(annotation: type) -> bool:
    """Report whether items of this type can be passed through untouched.

    Only `float` and `int` qualify; list contents of these types skip the
    per-item recursive transform.
    """
    return annotation in (float, int)
152
153
def _transform_recursive(
    data: object,
    *,
    annotation: type,
    inner_type: type | None = None,
) -> object:
    """Transform the given data against the expected type.

    Args:
        annotation: The direct type annotation given to the particular piece of data.
            This may or may not be wrapped in metadata types, e.g. `Required[T]`, `Annotated[T, ...]` etc

        inner_type: If applicable, this is the "inside" type. This is useful in certain cases where the outside type
            is a container type such as `List[T]`. In that case `inner_type` should be set to `T` so that each entry in
            the list can be transformed using the metadata from the container type.

            Defaults to the same value as the `annotation` argument.

    Returns:
        The transformed value:
        - TypedDict mappings get aliased keys and recursively transformed values.
        - `dict[K, V]` values are transformed against `V`.
        - List/iterable/sequence items are transformed against the item type.
        - Pydantic models are dumped in JSON mode (honouring `__api_exclude__`).
        - Values whose `Annotated` metadata carries a `PropertyInfo.format` are formatted.
        Anything else is returned unchanged.
    """
    from .._compat import model_dump

    if inner_type is None:
        inner_type = annotation

    stripped_type = strip_annotated_type(inner_type)
    origin = get_origin(stripped_type) or stripped_type
    if is_typeddict(stripped_type) and is_mapping(data):
        return _transform_typeddict(data, stripped_type)

    if origin == dict and is_mapping(data):
        type_args = get_args(stripped_type)
        if len(type_args) < 2:
            # A bare `dict` / `Dict` annotation has no value type to transform against;
            # indexing `get_args(...)[1]` unconditionally used to raise IndexError here.
            return data
        items_type = type_args[1]
        return {key: _transform_recursive(value, annotation=items_type) for key, value in data.items()}

    if (
        # List[T]
        (is_list_type(stripped_type) and is_list(data))
        # Iterable[T]
        or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str))
        # Sequence[T]
        or (is_sequence_type(stripped_type) and is_sequence(data) and not isinstance(data, str))
    ):
        # dicts are technically iterable, but it is an iterable on the keys of the dict and is not usually
        # intended as an iterable, so we don't transform it.
        if isinstance(data, dict):
            return cast(object, data)

        inner_type = extract_type_arg(stripped_type, 0)
        if _no_transform_needed(inner_type):
            # for some types there is no need to transform anything, so we can get a small
            # perf boost from skipping that work.
            #
            # but we still need to convert to a list to ensure the data is json-serializable
            if is_list(data):
                return data
            return list(data)

        # Items keep the container's `annotation` so its `Annotated` metadata
        # (e.g. a `PropertyInfo.format`) applies to every entry.
        return [_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data]

    if is_union_type(stripped_type):
        # For union types we run the transformation against all subtypes to ensure that everything is transformed.
        #
        # TODO: there may be edge cases where the same normalized field name will transform to two different names
        # in different subtypes.
        for subtype in get_args(stripped_type):
            data = _transform_recursive(data, annotation=annotation, inner_type=subtype)
        return data

    if isinstance(data, pydantic.BaseModel):
        return model_dump(data, exclude_unset=True, mode="json", exclude=getattr(data, "__api_exclude__", None))

    annotated_type = _get_annotated_type(annotation)
    if annotated_type is None:
        return data

    # ignore the first argument as it is the actual type; a distinct loop name
    # avoids shadowing the `annotation` parameter.
    for metadata in get_args(annotated_type)[1:]:
        if isinstance(metadata, PropertyInfo) and metadata.format is not None:
            return _format_data(data, metadata.format, metadata.format_template)

    return data
234
235
def _format_data(data: object, format_: PropertyFormat, format_template: str | None) -> object:
    """Apply a `PropertyInfo.format` to one value.

    - date/datetime with "iso8601": returns `isoformat()`.
    - date/datetime with "custom" and a template: returns `strftime(template)`.
    - "base64" with a base64 file input: reads a `pathlib.Path` or `io.IOBase`
      and returns the ASCII base64 text of its bytes (str content is encoded first).

    Anything else is returned unchanged.

    Raises:
        RuntimeError: if a base64 file input did not yield bytes.
    """
    if isinstance(data, (date, datetime)):
        if format_ == "iso8601":
            return data.isoformat()
        if format_ == "custom" and format_template is not None:
            return data.strftime(format_template)

    # Guard clause: only base64 file inputs are handled past this point.
    if format_ != "base64" or not is_base64_file_input(data):
        return data

    payload: str | bytes | None
    if isinstance(data, pathlib.Path):
        payload = data.read_bytes()
    elif isinstance(data, io.IOBase):
        payload = data.read()
    else:
        payload = None

    if isinstance(payload, str):  # type: ignore[unreachable]
        payload = payload.encode()

    if not isinstance(payload, bytes):
        raise RuntimeError(f"Could not read bytes from {data}; Received {type(payload)}")

    return base64.b64encode(payload).decode("ascii")
261
262
def _transform_typeddict(
    data: Mapping[str, object],
    expected_type: type,
) -> Mapping[str, object]:
    """Transform a mapping against the field annotations of a TypedDict.

    Omitted (not-given) values are dropped. Fields without an annotation are
    copied verbatim. Annotated fields are renamed via `_maybe_transform_key`,
    and their values are transformed via `_transform_recursive`.
    """
    hints = get_type_hints(expected_type, include_extras=True)
    out: dict[str, object] = {}
    for field_name, field_value in data.items():
        # Omitted values would be stripped before the request is sent anyway.
        if not is_given(field_value):
            continue

        field_type = hints.get(field_name)
        if field_type is None:
            # No annotation for this field: keep it exactly as provided.
            out[field_name] = field_value
            continue

        out[_maybe_transform_key(field_name, field_type)] = _transform_recursive(field_value, annotation=field_type)
    return out
282
283
async def async_maybe_transform(
    data: object,
    expected_type: object,
) -> Any | None:
    """Async variant of `maybe_transform()`: `None` is returned as-is.

    For any other `data`, the awaited result of
    `async_transform(data, expected_type)` is returned.
    """
    if data is not None:
        return await async_transform(data, expected_type)
    return None
295
296
async def async_transform(
    data: _T,
    expected_type: object,
) -> _T:
    """Asynchronously transform dictionaries based off of type information from the given type, for example:

    ```py
    class Params(TypedDict, total=False):
        card_id: Required[Annotated[str, PropertyInfo(alias="cardID")]]


    transformed = await async_transform({"card_id": "<my card ID>"}, Params)
    # {'cardID': '<my card ID>'}
    ```

    Any keys / data that does not have type information given will be included as is.

    Unlike `transform()`, base64 file inputs given as `pathlib.Path` are read via
    `anyio` (see `_async_format_data`).

    It should be noted that the transformations that this function does are not represented in the type system.
    """
    transformed = await _async_transform_recursive(data, annotation=cast(type, expected_type))
    return cast(_T, transformed)
318
319
async def _async_transform_recursive(
    data: object,
    *,
    annotation: type,
    inner_type: type | None = None,
) -> object:
    """Transform the given data against the expected type (async counterpart of `_transform_recursive`).

    Args:
        annotation: The direct type annotation given to the particular piece of data.
            This may or may not be wrapped in metadata types, e.g. `Required[T]`, `Annotated[T, ...]` etc

        inner_type: If applicable, this is the "inside" type. This is useful in certain cases where the outside type
            is a container type such as `List[T]`. In that case `inner_type` should be set to `T` so that each entry in
            the list can be transformed using the metadata from the container type.

            Defaults to the same value as the `annotation` argument.

    Returns:
        The same result `_transform_recursive` would produce, except that nested values
        are always processed by the async helpers (`_async_transform_typeddict`,
        `_async_format_data`).
    """
    from .._compat import model_dump

    if inner_type is None:
        inner_type = annotation

    stripped_type = strip_annotated_type(inner_type)
    origin = get_origin(stripped_type) or stripped_type
    if is_typeddict(stripped_type) and is_mapping(data):
        return await _async_transform_typeddict(data, stripped_type)

    if origin == dict and is_mapping(data):
        type_args = get_args(stripped_type)
        if len(type_args) < 2:
            # A bare `dict` / `Dict` annotation has no value type to transform against;
            # indexing `get_args(...)[1]` unconditionally used to raise IndexError here.
            return data
        items_type = type_args[1]
        # Recurse with the async variant: the sync one would route nested base64
        # `Path` values through the blocking `read_bytes()` in `_format_data`.
        return {key: await _async_transform_recursive(value, annotation=items_type) for key, value in data.items()}

    if (
        # List[T]
        (is_list_type(stripped_type) and is_list(data))
        # Iterable[T]
        or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str))
        # Sequence[T]
        or (is_sequence_type(stripped_type) and is_sequence(data) and not isinstance(data, str))
    ):
        # dicts are technically iterable, but it is an iterable on the keys of the dict and is not usually
        # intended as an iterable, so we don't transform it.
        if isinstance(data, dict):
            return cast(object, data)

        inner_type = extract_type_arg(stripped_type, 0)
        if _no_transform_needed(inner_type):
            # for some types there is no need to transform anything, so we can get a small
            # perf boost from skipping that work.
            #
            # but we still need to convert to a list to ensure the data is json-serializable
            if is_list(data):
                return data
            return list(data)

        # Items keep the container's `annotation` so its `Annotated` metadata
        # (e.g. a `PropertyInfo.format`) applies to every entry.
        return [await _async_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data]

    if is_union_type(stripped_type):
        # For union types we run the transformation against all subtypes to ensure that everything is transformed.
        #
        # TODO: there may be edge cases where the same normalized field name will transform to two different names
        # in different subtypes.
        for subtype in get_args(stripped_type):
            data = await _async_transform_recursive(data, annotation=annotation, inner_type=subtype)
        return data

    if isinstance(data, pydantic.BaseModel):
        # Honour `__api_exclude__` exactly like the sync `_transform_recursive` does.
        return model_dump(data, exclude_unset=True, mode="json", exclude=getattr(data, "__api_exclude__", None))

    annotated_type = _get_annotated_type(annotation)
    if annotated_type is None:
        return data

    # ignore the first argument as it is the actual type; a distinct loop name
    # avoids shadowing the `annotation` parameter.
    for metadata in get_args(annotated_type)[1:]:
        if isinstance(metadata, PropertyInfo) and metadata.format is not None:
            return await _async_format_data(data, metadata.format, metadata.format_template)

    return data
400
401
async def _async_format_data(data: object, format_: PropertyFormat, format_template: str | None) -> object:
    """Async counterpart of `_format_data`.

    - date/datetime with "iso8601": returns `isoformat()`.
    - date/datetime with "custom" and a template: returns `strftime(template)`.
    - "base64" with a base64 file input: a `pathlib.Path` is read through
      `anyio.Path(...).read_bytes()`, and an `io.IOBase` through `.read()`.
      Returns the ASCII base64 text of the bytes (str content is encoded first).

    Anything else is returned unchanged.

    Raises:
        RuntimeError: if a base64 file input did not yield bytes.
    """
    if isinstance(data, (date, datetime)):
        if format_ == "iso8601":
            return data.isoformat()
        if format_ == "custom" and format_template is not None:
            return data.strftime(format_template)

    # Guard clause: only base64 file inputs are handled past this point.
    if format_ != "base64" or not is_base64_file_input(data):
        return data

    payload: str | bytes | None
    if isinstance(data, pathlib.Path):
        payload = await anyio.Path(data).read_bytes()
    elif isinstance(data, io.IOBase):
        payload = data.read()
    else:
        payload = None

    if isinstance(payload, str):  # type: ignore[unreachable]
        payload = payload.encode()

    if not isinstance(payload, bytes):
        raise RuntimeError(f"Could not read bytes from {data}; Received {type(payload)}")

    return base64.b64encode(payload).decode("ascii")
427
428
async def _async_transform_typeddict(
    data: Mapping[str, object],
    expected_type: type,
) -> Mapping[str, object]:
    """Async counterpart of `_transform_typeddict`.

    Omitted (not-given) values are dropped. Fields without an annotation are
    copied verbatim. Annotated fields are renamed via `_maybe_transform_key`,
    and their values are transformed via `_async_transform_recursive`.
    """
    hints = get_type_hints(expected_type, include_extras=True)
    out: dict[str, object] = {}
    for field_name, field_value in data.items():
        # Omitted values would be stripped before the request is sent anyway.
        if not is_given(field_value):
            continue

        field_type = hints.get(field_name)
        if field_type is None:
            # No annotation for this field: keep it exactly as provided.
            out[field_name] = field_value
            continue

        out[_maybe_transform_key(field_name, field_type)] = await _async_transform_recursive(
            field_value, annotation=field_type
        )
    return out
448
449
@lru_cache(maxsize=8096)
def get_type_hints(
    obj: Any,
    globalns: dict[str, Any] | None = None,
    localns: Mapping[str, Any] | None = None,
    include_extras: bool = False,
) -> dict[str, Any]:
    """Memoised front-end to `_get_type_hints` (imported as `typing_extensions.get_type_hints`).

    Arguments are forwarded unchanged. Because results are cached with
    `lru_cache`, every argument must be hashable. Repeated calls also return the
    same dict object, so callers should treat it as read-only.
    """
    hints = _get_type_hints(
        obj,
        globalns=globalns,
        localns=localns,
        include_extras=include_extras,
    )
    return hints