main
  1from __future__ import annotations
  2
  3import io
  4import base64
  5import pathlib
  6from typing import Any, Mapping, TypeVar, cast
  7from datetime import date, datetime
  8from typing_extensions import Literal, get_args, override, get_type_hints as _get_type_hints
  9
 10import anyio
 11import pydantic
 12
 13from ._utils import (
 14    is_list,
 15    is_given,
 16    lru_cache,
 17    is_mapping,
 18    is_iterable,
 19    is_sequence,
 20)
 21from .._files import is_base64_file_input
 22from ._compat import get_origin, is_typeddict
 23from ._typing import (
 24    is_list_type,
 25    is_union_type,
 26    extract_type_arg,
 27    is_iterable_type,
 28    is_required_type,
 29    is_sequence_type,
 30    is_annotated_type,
 31    strip_annotated_type,
 32)
 33
# Generic type variable: `transform()` / `async_transform()` are annotated to return the
# same static type they were given, even though the runtime value is transformed.
_T = TypeVar("_T")


# TODO: support for drilling globals() and locals()
# TODO: ensure works correctly with forward references in all cases


# Values accepted by `PropertyInfo.format`:
# - "iso8601": `date` / `datetime` values are serialized with `.isoformat()`
# - "base64":  file inputs (`pathlib.Path` / `io.IOBase`) are read and base64-encoded to an ASCII str
# - "custom":  `date` / `datetime` values are formatted with `strftime(PropertyInfo.format_template)`
PropertyFormat = Literal["iso8601", "base64", "custom"]
 42
 43
 44class PropertyInfo:
 45    """Metadata class to be used in Annotated types to provide information about a given type.
 46
 47    For example:
 48
 49    class MyParams(TypedDict):
 50        account_holder_name: Annotated[str, PropertyInfo(alias='accountHolderName')]
 51
 52    This means that {'account_holder_name': 'Robert'} will be transformed to {'accountHolderName': 'Robert'} before being sent to the API.
 53    """
 54
 55    alias: str | None
 56    format: PropertyFormat | None
 57    format_template: str | None
 58    discriminator: str | None
 59
 60    def __init__(
 61        self,
 62        *,
 63        alias: str | None = None,
 64        format: PropertyFormat | None = None,
 65        format_template: str | None = None,
 66        discriminator: str | None = None,
 67    ) -> None:
 68        self.alias = alias
 69        self.format = format
 70        self.format_template = format_template
 71        self.discriminator = discriminator
 72
 73    @override
 74    def __repr__(self) -> str:
 75        return f"{self.__class__.__name__}(alias='{self.alias}', format={self.format}, format_template='{self.format_template}', discriminator='{self.discriminator}')"
 76
 77
 78def maybe_transform(
 79    data: object,
 80    expected_type: object,
 81) -> Any | None:
 82    """Wrapper over `transform()` that allows `None` to be passed.
 83
 84    See `transform()` for more details.
 85    """
 86    if data is None:
 87        return None
 88    return transform(data, expected_type)
 89
 90
 91# Wrapper over _transform_recursive providing fake types
 92def transform(
 93    data: _T,
 94    expected_type: object,
 95) -> _T:
 96    """Transform dictionaries based off of type information from the given type, for example:
 97
 98    ```py
 99    class Params(TypedDict, total=False):
100        card_id: Required[Annotated[str, PropertyInfo(alias="cardID")]]
101
102
103    transformed = transform({"card_id": "<my card ID>"}, Params)
104    # {'cardID': '<my card ID>'}
105    ```
106
107    Any keys / data that does not have type information given will be included as is.
108
109    It should be noted that the transformations that this function does are not represented in the type system.
110    """
111    transformed = _transform_recursive(data, annotation=cast(type, expected_type))
112    return cast(_T, transformed)
113
114
@lru_cache(maxsize=8096)
def _get_annotated_type(type_: type) -> type | None:
    """Return `type_` as an `Annotated[...]` type, or `None` when it is not one.

    A single outer `Required[...]` wrapper is peeled off first, so
    `Required[Annotated[T, ...]]` yields `Annotated[T, ...]`. Results are memoized.
    """
    candidate = get_args(type_)[0] if is_required_type(type_) else type_
    return candidate if is_annotated_type(candidate) else None
129
130
def _maybe_transform_key(key: str, type_: type) -> str:
    """Return the output key name for a field called `key` annotated with `type_`.

    When `type_` is (or wraps via `Required[...]`) an `Annotated` type whose metadata
    contains a `PropertyInfo` with an `alias`, the first such alias is returned.
    Otherwise `key` is returned unchanged.
    """
    annotated_type = _get_annotated_type(type_)
    if annotated_type is None:
        # no `Annotated` wrapper -> nothing can rename this key
        return key

    # argument 0 of `Annotated[...]` is the underlying type; the remainder is metadata
    metadata = get_args(annotated_type)[1:]
    aliases = (item.alias for item in metadata if isinstance(item, PropertyInfo) and item.alias is not None)
    return next(aliases, key)
148
149
150def _no_transform_needed(annotation: type) -> bool:
151    return annotation == float or annotation == int
152
153
def _transform_recursive(
    data: object,
    *,
    annotation: type,
    inner_type: type | None = None,
) -> object:
    """Transform the given data against the expected type.

    Args:
        annotation: The direct type annotation given to the particular piece of data.
            This may or may not be wrapped in metadata types, e.g. `Required[T]`, `Annotated[T, ...]` etc

        inner_type: If applicable, this is the "inside" type. This is useful in certain cases where the outside type
            is a container type such as `List[T]`. In that case `inner_type` should be set to `T` so that each entry in
            the list can be transformed using the metadata from the container type.

            Defaults to the same value as the `annotation` argument.

    Returns:
        - TypedDict mappings and `dict[K, V]` mappings are rebuilt with transformed values.
        - List / iterable / sequence data (other than `str` and `dict`) is returned as a list.
        - Union-typed data is run through every member type in turn.
        - Pydantic models are dumped to JSON-compatible data, honouring `__api_exclude__`.
        - Values whose `Annotated` metadata carries a `PropertyInfo.format` go through `_format_data`.
        - Anything else is returned unchanged.
    """
    from .._compat import model_dump

    if inner_type is None:
        inner_type = annotation

    stripped_type = strip_annotated_type(inner_type)
    origin = get_origin(stripped_type) or stripped_type
    if is_typeddict(stripped_type) and is_mapping(data):
        return _transform_typeddict(data, stripped_type)

    if origin == dict and is_mapping(data):
        # A bare `dict` / `Dict` annotation has no type arguments, so indexing `[1]` would raise
        # IndexError. Fall back to `object`: it carries no container/union/`Annotated` metadata,
        # so values pass through as-is (pydantic models are still dumped).
        dict_args = get_args(stripped_type)
        items_type = dict_args[1] if len(dict_args) == 2 else object
        return {key: _transform_recursive(value, annotation=items_type) for key, value in data.items()}

    if (
        # List[T]
        (is_list_type(stripped_type) and is_list(data))
        # Iterable[T]
        or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str))
        # Sequence[T]
        or (is_sequence_type(stripped_type) and is_sequence(data) and not isinstance(data, str))
    ):
        # dicts are technically iterable, but it is an iterable on the keys of the dict and is not usually
        # intended as an iterable, so we don't transform it.
        if isinstance(data, dict):
            return cast(object, data)

        item_type = extract_type_arg(stripped_type, 0)
        if _no_transform_needed(item_type):
            # for some types there is no need to transform anything, so we can get a small
            # perf boost from skipping that work.
            #
            # but we still need to convert to a list to ensure the data is json-serializable
            if is_list(data):
                return data
            return list(data)

        return [_transform_recursive(d, annotation=annotation, inner_type=item_type) for d in data]

    if is_union_type(stripped_type):
        # For union types we run the transformation against all subtypes to ensure that everything is transformed.
        #
        # TODO: there may be edge cases where the same normalized field name will transform to two different names
        # in different subtypes.
        for subtype in get_args(stripped_type):
            data = _transform_recursive(data, annotation=annotation, inner_type=subtype)
        return data

    if isinstance(data, pydantic.BaseModel):
        return model_dump(data, exclude_unset=True, mode="json", exclude=getattr(data, "__api_exclude__", None))

    annotated_type = _get_annotated_type(annotation)
    if annotated_type is None:
        return data

    # ignore the first argument as it is the actual type; the loop variable is named
    # `metadata` so it does not shadow the `annotation` parameter
    for metadata in get_args(annotated_type)[1:]:
        if isinstance(metadata, PropertyInfo) and metadata.format is not None:
            return _format_data(data, metadata.format, metadata.format_template)

    return data
234
235
def _format_data(data: object, format_: PropertyFormat, format_template: str | None) -> object:
    """Serialize `data` according to `format_`, returning it unchanged when no rule applies.

    - `date` / `datetime` with "iso8601": returns `data.isoformat()`.
    - `date` / `datetime` with "custom" and a template: returns `data.strftime(format_template)`.
    - "base64" with a base64 file input: the content is read and returned base64-encoded as an
      ASCII `str`. `pathlib.Path` inputs use `read_bytes()`. `io.IOBase` inputs use `read()`,
      and any `str` result is UTF-8 encoded.

    Raises:
        RuntimeError: if a base64 file input is neither a `Path` nor an `IOBase`,
            or reading it did not yield bytes.
    """
    if isinstance(data, (date, datetime)):
        if format_ == "iso8601":
            return data.isoformat()
        if format_ == "custom" and format_template is not None:
            return data.strftime(format_template)

    # guard clause: everything below only concerns base64-encoding file inputs
    if format_ != "base64" or not is_base64_file_input(data):
        return data

    binary: str | bytes | None = None
    if isinstance(data, pathlib.Path):
        binary = data.read_bytes()
    elif isinstance(data, io.IOBase):
        binary = data.read()
        # text-mode streams hand back `str`; normalise to bytes before encoding
        if isinstance(binary, str):  # type: ignore[unreachable]
            binary = binary.encode()

    if not isinstance(binary, bytes):
        raise RuntimeError(f"Could not read bytes from {data}; Received {type(binary)}")

    return base64.b64encode(binary).decode("ascii")
261
262
def _transform_typeddict(
    data: Mapping[str, object],
    expected_type: type,
) -> Mapping[str, object]:
    """Transform each entry of `data` using the field annotations of the TypedDict `expected_type`.

    - Entries whose value is not given (per `is_given`) are dropped.
    - Entries without an annotation are copied through unchanged.
    - Annotated entries get their key renamed via `_maybe_transform_key` and their value
      transformed with `_transform_recursive`.
    """
    field_types = get_type_hints(expected_type, include_extras=True)
    output: dict[str, object] = {}
    for name, raw_value in data.items():
        # omitted values would be stripped before the request is sent anyway, so skip them now
        if not is_given(raw_value):
            continue

        field_type = field_types.get(name)
        if field_type is None:
            # no annotation for this field: nothing to drive a transformation
            output[name] = raw_value
            continue

        output[_maybe_transform_key(name, field_type)] = _transform_recursive(raw_value, annotation=field_type)
    return output
282
283
async def async_maybe_transform(
    data: object,
    expected_type: object,
) -> Any | None:
    """Async counterpart of `maybe_transform()`: a `None` input is returned as `None`.

    Any other input is handed to `async_transform()`; see it for details.
    """
    if data is not None:
        return await async_transform(data, expected_type)
    return None
295
296
async def async_transform(
    data: _T,
    expected_type: object,
) -> _T:
    """Transform dictionaries based off of type information from the given type, for example:

    ```py
    class Params(TypedDict, total=False):
        card_id: Required[Annotated[str, PropertyInfo(alias="cardID")]]


    transformed = await async_transform({"card_id": "<my card ID>"}, Params)
    # {'cardID': '<my card ID>'}
    ```

    This is the async counterpart of `transform()`.

    Any keys / data that does not have type information given will be included as is.

    It should be noted that the transformations that this function does are not represented in the type system.
    """
    transformed = await _async_transform_recursive(data, annotation=cast(type, expected_type))
    return cast(_T, transformed)
318
319
async def _async_transform_recursive(
    data: object,
    *,
    annotation: type,
    inner_type: type | None = None,
) -> object:
    """Async counterpart of `_transform_recursive`: transform the given data against the expected type.

    Args:
        annotation: The direct type annotation given to the particular piece of data.
            This may or may not be wrapped in metadata types, e.g. `Required[T]`, `Annotated[T, ...]` etc

        inner_type: If applicable, this is the "inside" type. This is useful in certain cases where the outside type
            is a container type such as `List[T]`. In that case `inner_type` should be set to `T` so that each entry in
            the list can be transformed using the metadata from the container type.

            Defaults to the same value as the `annotation` argument.

    Returns:
        The same shapes as `_transform_recursive`. Every nested value, including `dict[K, V]`
        values, goes through this async path, so formatting uses `_async_format_data`.
    """
    from .._compat import model_dump

    if inner_type is None:
        inner_type = annotation

    stripped_type = strip_annotated_type(inner_type)
    origin = get_origin(stripped_type) or stripped_type
    if is_typeddict(stripped_type) and is_mapping(data):
        return await _async_transform_typeddict(data, stripped_type)

    if origin == dict and is_mapping(data):
        # A bare `dict` / `Dict` annotation has no type arguments, so indexing `[1]` would raise
        # IndexError. Fall back to `object`: it carries no container/union/`Annotated` metadata,
        # so values pass through as-is (pydantic models are still dumped).
        dict_args = get_args(stripped_type)
        items_type = dict_args[1] if len(dict_args) == 2 else object
        # recurse through the *async* transformer so nested values use the async formatting path
        return {key: await _async_transform_recursive(value, annotation=items_type) for key, value in data.items()}

    if (
        # List[T]
        (is_list_type(stripped_type) and is_list(data))
        # Iterable[T]
        or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str))
        # Sequence[T]
        or (is_sequence_type(stripped_type) and is_sequence(data) and not isinstance(data, str))
    ):
        # dicts are technically iterable, but it is an iterable on the keys of the dict and is not usually
        # intended as an iterable, so we don't transform it.
        if isinstance(data, dict):
            return cast(object, data)

        item_type = extract_type_arg(stripped_type, 0)
        if _no_transform_needed(item_type):
            # for some types there is no need to transform anything, so we can get a small
            # perf boost from skipping that work.
            #
            # but we still need to convert to a list to ensure the data is json-serializable
            if is_list(data):
                return data
            return list(data)

        return [await _async_transform_recursive(d, annotation=annotation, inner_type=item_type) for d in data]

    if is_union_type(stripped_type):
        # For union types we run the transformation against all subtypes to ensure that everything is transformed.
        #
        # TODO: there may be edge cases where the same normalized field name will transform to two different names
        # in different subtypes.
        for subtype in get_args(stripped_type):
            data = await _async_transform_recursive(data, annotation=annotation, inner_type=subtype)
        return data

    if isinstance(data, pydantic.BaseModel):
        # keep parity with the sync transformer: honour the model's `__api_exclude__` fields
        return model_dump(data, exclude_unset=True, mode="json", exclude=getattr(data, "__api_exclude__", None))

    annotated_type = _get_annotated_type(annotation)
    if annotated_type is None:
        return data

    # ignore the first argument as it is the actual type; the loop variable is named
    # `metadata` so it does not shadow the `annotation` parameter
    for metadata in get_args(annotated_type)[1:]:
        if isinstance(metadata, PropertyInfo) and metadata.format is not None:
            return await _async_format_data(data, metadata.format, metadata.format_template)

    return data
400
401
async def _async_format_data(data: object, format_: PropertyFormat, format_template: str | None) -> object:
    """Async counterpart of `_format_data`: serialize `data` according to `format_`.

    - `date` / `datetime` with "iso8601": returns `data.isoformat()`.
    - `date` / `datetime` with "custom" and a template: returns `data.strftime(format_template)`.
    - "base64" with a base64 file input: the content is read and returned base64-encoded as an
      ASCII `str`. `pathlib.Path` inputs are read with `anyio.Path(...).read_bytes()`.
      `io.IOBase` inputs use `read()`, and any `str` result is UTF-8 encoded.
    - Otherwise `data` is returned unchanged.

    Raises:
        RuntimeError: if a base64 file input is neither a `Path` nor an `IOBase`,
            or reading it did not yield bytes.
    """
    if isinstance(data, (date, datetime)):
        if format_ == "iso8601":
            return data.isoformat()
        if format_ == "custom" and format_template is not None:
            return data.strftime(format_template)

    # guard clause: everything below only concerns base64-encoding file inputs
    if format_ != "base64" or not is_base64_file_input(data):
        return data

    binary: str | bytes | None = None
    if isinstance(data, pathlib.Path):
        binary = await anyio.Path(data).read_bytes()
    elif isinstance(data, io.IOBase):
        # NOTE(review): this `read()` is synchronous and blocks the event loop while it runs
        binary = data.read()
        # text-mode streams hand back `str`; normalise to bytes before encoding
        if isinstance(binary, str):  # type: ignore[unreachable]
            binary = binary.encode()

    if not isinstance(binary, bytes):
        raise RuntimeError(f"Could not read bytes from {data}; Received {type(binary)}")

    return base64.b64encode(binary).decode("ascii")
427
428
async def _async_transform_typeddict(
    data: Mapping[str, object],
    expected_type: type,
) -> Mapping[str, object]:
    """Async counterpart of `_transform_typeddict`.

    - Entries whose value is not given (per `is_given`) are dropped.
    - Entries without an annotation are copied through unchanged.
    - Annotated entries get their key renamed via `_maybe_transform_key` and their value
      transformed with `_async_transform_recursive`.
    """
    field_types = get_type_hints(expected_type, include_extras=True)
    output: dict[str, object] = {}
    for name, raw_value in data.items():
        # omitted values would be stripped before the request is sent anyway, so skip them now
        if not is_given(raw_value):
            continue

        field_type = field_types.get(name)
        if field_type is None:
            # no annotation for this field: nothing to drive a transformation
            output[name] = raw_value
            continue

        output[_maybe_transform_key(name, field_type)] = await _async_transform_recursive(
            raw_value, annotation=field_type
        )
    return output
448
449
@lru_cache(maxsize=8096)
def get_type_hints(
    obj: Any,
    globalns: dict[str, Any] | None = None,
    localns: Mapping[str, Any] | None = None,
    include_extras: bool = False,
) -> dict[str, Any]:
    """Memoized wrapper around `typing_extensions.get_type_hints`.

    Repeated lookups for the same arguments return the cached result instead of
    re-resolving annotations.

    NOTE(review): assuming `lru_cache` behaves like `functools.lru_cache`, every argument
    must be hashable (a plain `dict` for `globalns` / `localns` would raise TypeError).
    The returned dict is also shared between calls, so callers should not mutate it.
    """
    hints = _get_type_hints(obj, globalns, localns, include_extras)
    return hints
457    return _get_type_hints(obj, globalns=globalns, localns=localns, include_extras=include_extras)