from __future__ import annotations

import os
import inspect
import weakref
from typing import TYPE_CHECKING, Any, Type, Tuple, Union, Generic, TypeVar, Callable, Optional, cast
from datetime import date, datetime
from typing_extensions import (
    List,
    Unpack,
    Literal,
    ClassVar,
    Protocol,
    Required,
    Sequence,
    ParamSpec,
    TypedDict,
    TypeGuard,
    final,
    override,
    runtime_checkable,
)

import pydantic
from pydantic.fields import FieldInfo

from ._types import (
    Body,
    IncEx,
    Query,
    ModelT,
    Headers,
    Timeout,
    NotGiven,
    AnyMapping,
    HttpxRequestFiles,
)
from ._utils import (
    PropertyInfo,
    is_list,
    is_given,
    json_safe,
    lru_cache,
    is_mapping,
    parse_date,
    coerce_boolean,
    parse_datetime,
    strip_not_given,
    extract_type_arg,
    is_annotated_type,
    is_type_alias_type,
    strip_annotated_type,
)
from ._compat import (
    PYDANTIC_V1,
    ConfigDict,
    GenericModel as BaseGenericModel,
    get_args,
    is_union,
    parse_obj,
    get_origin,
    is_literal_type,
    get_model_config,
    get_model_fields,
    field_get_default,
)
from ._constants import RAW_RESPONSE_HEADER

if TYPE_CHECKING:
    from pydantic_core.core_schema import ModelField, ModelSchema, LiteralSchema, ModelFieldsSchema

__all__ = ["BaseModel", "GenericModel"]

_T = TypeVar("_T")
_BaseModelT = TypeVar("_BaseModelT", bound="BaseModel")

P = ParamSpec("P")

ReprArgs = Sequence[Tuple[Optional[str], Any]]


@runtime_checkable
class _ConfigProtocol(Protocol):
    allow_population_by_field_name: bool


class BaseModel(pydantic.BaseModel):
    if PYDANTIC_V1:

        @property
        @override
        def model_fields_set(self) -> set[str]:
            # a forwards-compat shim for pydantic v2
            return self.__fields_set__  # type: ignore

        class Config(pydantic.BaseConfig):  # pyright: ignore[reportDeprecated]
            extra: Any = pydantic.Extra.allow  # type: ignore

        @override
        def __repr_args__(self) -> ReprArgs:
            # we don't want these attributes to be included when something like `rich.print` is used
            return [arg for arg in super().__repr_args__() if arg[0] not in {"_request_id", "__exclude_fields__"}]
    else:
        model_config: ClassVar[ConfigDict] = ConfigDict(
            extra="allow", defer_build=coerce_boolean(os.environ.get("DEFER_PYDANTIC_BUILD", "true"))
        )

    if TYPE_CHECKING:
        _request_id: Optional[str] = None
        """The ID of the request, returned via the X-Request-ID header. Useful for debugging requests and reporting issues to OpenAI.

        This will **only** be set for the top-level response object; it will not be defined for nested objects. For example:

        ```py
        completion = await client.chat.completions.create(...)
        completion._request_id  # req_id_xxx
        completion.usage._request_id  # raises `AttributeError`
        ```

        Note: unlike other properties that use an `_` prefix, this property
        *is* public. Unless documented otherwise, all other `_` prefix properties,
        methods and modules are *private*.
        """

    def to_dict(
        self,
        *,
        mode: Literal["json", "python"] = "python",
        use_api_names: bool = True,
        exclude_unset: bool = True,
        exclude_defaults: bool = False,
        exclude_none: bool = False,
        warnings: bool = True,
    ) -> dict[str, object]:
        """Recursively generate a dictionary representation of the model, optionally specifying which fields to include or exclude.

        By default, fields that were not set by the API will not be included,
        and keys will match the API response, *not* the property names from the model.

        For example, if the API responds with `"fooBar": true` but we've defined a `foo_bar: bool` property,
        the output will use the `"fooBar"` key (unless `use_api_names=False` is passed).

        Args:
            mode:
                If mode is 'json', the dictionary will only contain JSON serializable types. e.g. `datetime` will be turned into a string, `"2024-3-22T18:11:19.117000Z"`.
                If mode is 'python', the dictionary may contain any Python objects. e.g. `datetime(2024, 3, 22)`

            use_api_names: Whether to use the key that the API responded with or the property name. Defaults to `True`.
            exclude_unset: Whether to exclude fields that have not been explicitly set.
            exclude_defaults: Whether to exclude fields that are set to their default value from the output.
            exclude_none: Whether to exclude fields that have a value of `None` from the output.
            warnings: Whether to log warnings when invalid fields are encountered. This is only supported in Pydantic v2.
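
        For example (illustrative model instance; assumes a field declared as `foo_bar: bool = Field(alias="fooBar")`):

        ```py
        model.to_dict()  # {"fooBar": True}
        model.to_dict(use_api_names=False)  # {"foo_bar": True}
        ```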
        """
        return self.model_dump(
            mode=mode,
            by_alias=use_api_names,
            exclude_unset=exclude_unset,
            exclude_defaults=exclude_defaults,
            exclude_none=exclude_none,
            warnings=warnings,
        )

    def to_json(
        self,
        *,
        indent: int | None = 2,
        use_api_names: bool = True,
        exclude_unset: bool = True,
        exclude_defaults: bool = False,
        exclude_none: bool = False,
        warnings: bool = True,
    ) -> str:
        """Generates a JSON string representing this model as it would be received from or sent to the API (but with indentation).

        By default, fields that were not set by the API will not be included,
        and keys will match the API response, *not* the property names from the model.

        For example, if the API responds with `"fooBar": true` but we've defined a `foo_bar: bool` property,
        the output will use the `"fooBar"` key (unless `use_api_names=False` is passed).

        Args:
            indent: Indentation to use in the JSON output. If `None` is passed, the output will be compact. Defaults to `2`.
            use_api_names: Whether to use the key that the API responded with or the property name. Defaults to `True`.
            exclude_unset: Whether to exclude fields that have not been explicitly set.
            exclude_defaults: Whether to exclude fields that have the default value.
            exclude_none: Whether to exclude fields that have a value of `None`.
            warnings: Whether to show any warnings that occurred during serialization. This is only supported in Pydantic v2.
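
        For example (illustrative):

        ```py
        print(model.to_json())  # keys match the API response, indented by 2 spaces
        print(model.to_json(indent=None))  # compact, single-line output
        ```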
        """
        return self.model_dump_json(
            indent=indent,
            by_alias=use_api_names,
            exclude_unset=exclude_unset,
            exclude_defaults=exclude_defaults,
            exclude_none=exclude_none,
            warnings=warnings,
        )

    @override
    def __str__(self) -> str:
        # mypy complains about an invalid self arg
        return f"{self.__repr_name__()}({self.__repr_str__(', ')})"  # type: ignore[misc]

    # Override the 'construct' method in a way that supports recursive parsing without validation.
    # Based on https://github.com/samuelcolvin/pydantic/issues/1168#issuecomment-817742836.
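    # For example (illustrative model), if `Foo` declares a `bar: int` field then
    # `Foo.construct(bar="not-an-int")` will not raise a validation error; the raw value is
    # stored as-is, while nested models and containers are built recursively via `construct_type()`.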
    @classmethod
    @override
    def construct(  # pyright: ignore[reportIncompatibleMethodOverride]
        __cls: Type[ModelT],
        _fields_set: set[str] | None = None,
        **values: object,
    ) -> ModelT:
        m = __cls.__new__(__cls)
        fields_values: dict[str, object] = {}

        config = get_model_config(__cls)
        populate_by_name = (
            config.allow_population_by_field_name
            if isinstance(config, _ConfigProtocol)
            else config.get("populate_by_name")
        )

        if _fields_set is None:
            _fields_set = set()

        model_fields = get_model_fields(__cls)
        for name, field in model_fields.items():
            key = field.alias
            if key is None or (key not in values and populate_by_name):
                key = name

            if key in values:
                fields_values[name] = _construct_field(value=values[key], field=field, key=key)
                _fields_set.add(name)
            else:
                fields_values[name] = field_get_default(field)

        extra_field_type = _get_extra_fields_type(__cls)

        _extra = {}
        for key, value in values.items():
            if key not in model_fields:
                parsed = construct_type(value=value, type_=extra_field_type) if extra_field_type is not None else value

                if PYDANTIC_V1:
                    _fields_set.add(key)
                    fields_values[key] = parsed
                else:
                    _extra[key] = parsed

        object.__setattr__(m, "__dict__", fields_values)

        if PYDANTIC_V1:
            # init_private_attributes() does not exist in v2
            m._init_private_attributes()  # type: ignore

            # copied from Pydantic v1's `construct()` method
            object.__setattr__(m, "__fields_set__", _fields_set)
        else:
            # these properties are copied from Pydantic's `model_construct()` method
            object.__setattr__(m, "__pydantic_private__", None)
            object.__setattr__(m, "__pydantic_extra__", _extra)
            object.__setattr__(m, "__pydantic_fields_set__", _fields_set)

        return m

    if not TYPE_CHECKING:
        # type checkers incorrectly complain about this assignment
        # because the type signatures are technically different
        # although not in practice
        model_construct = construct

    if PYDANTIC_V1:
        # we define aliases for some of the new pydantic v2 methods so
        # that we can just document these methods without having to specify
        # a specific pydantic version as some users may not know which
        # pydantic version they are currently using

        @override
        def model_dump(
            self,
            *,
            mode: Literal["json", "python"] | str = "python",
            include: IncEx | None = None,
            exclude: IncEx | None = None,
            context: Any | None = None,
            by_alias: bool | None = None,
            exclude_unset: bool = False,
            exclude_defaults: bool = False,
            exclude_none: bool = False,
            exclude_computed_fields: bool = False,
            round_trip: bool = False,
            warnings: bool | Literal["none", "warn", "error"] = True,
            fallback: Callable[[Any], Any] | None = None,
            serialize_as_any: bool = False,
        ) -> dict[str, Any]:
            """Usage docs: https://docs.pydantic.dev/2.4/concepts/serialization/#modelmodel_dump

            Generate a dictionary representation of the model, optionally specifying which fields to include or exclude.

            Args:
                mode: The mode in which `to_python` should run.
                    If mode is 'json', the output will only contain JSON serializable types.
                    If mode is 'python', the output may contain non-JSON-serializable Python objects.
                include: A set of fields to include in the output.
                exclude: A set of fields to exclude from the output.
                context: Additional context to pass to the serializer.
                by_alias: Whether to use the field's alias in the dictionary key if defined.
                exclude_unset: Whether to exclude fields that have not been explicitly set.
                exclude_defaults: Whether to exclude fields that are set to their default value.
                exclude_none: Whether to exclude fields that have a value of `None`.
                exclude_computed_fields: Whether to exclude computed fields.
                    While this can be useful for round-tripping, it is usually recommended to use the dedicated
                    `round_trip` parameter instead.
                round_trip: If True, dumped values should be valid as input for non-idempotent types such as Json[T].
                warnings: How to handle serialization errors. False/"none" ignores them, True/"warn" logs errors,
                    "error" raises a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError].
                fallback: A function to call when an unknown value is encountered. If not provided,
                    a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError] error is raised.
                serialize_as_any: Whether to serialize fields with duck-typing serialization behavior.

            Returns:
                A dictionary representation of the model.
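
            For example (illustrative):

            ```py
            data = model.model_dump(mode="json", exclude_none=True)
            ```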
            """
            if mode not in {"json", "python"}:
                raise ValueError("mode must be either 'json' or 'python'")
            if round_trip != False:
                raise ValueError("round_trip is only supported in Pydantic v2")
            if warnings != True:
                raise ValueError("warnings is only supported in Pydantic v2")
            if context is not None:
                raise ValueError("context is only supported in Pydantic v2")
            if serialize_as_any != False:
                raise ValueError("serialize_as_any is only supported in Pydantic v2")
            if fallback is not None:
                raise ValueError("fallback is only supported in Pydantic v2")
            if exclude_computed_fields != False:
                raise ValueError("exclude_computed_fields is only supported in Pydantic v2")
            dumped = super().dict(  # pyright: ignore[reportDeprecated]
                include=include,
                exclude=exclude,
                by_alias=by_alias if by_alias is not None else False,
                exclude_unset=exclude_unset,
                exclude_defaults=exclude_defaults,
                exclude_none=exclude_none,
            )

            return cast("dict[str, Any]", json_safe(dumped)) if mode == "json" else dumped

        @override
        def model_dump_json(
            self,
            *,
            indent: int | None = None,
            ensure_ascii: bool = False,
            include: IncEx | None = None,
            exclude: IncEx | None = None,
            context: Any | None = None,
            by_alias: bool | None = None,
            exclude_unset: bool = False,
            exclude_defaults: bool = False,
            exclude_none: bool = False,
            exclude_computed_fields: bool = False,
            round_trip: bool = False,
            warnings: bool | Literal["none", "warn", "error"] = True,
            fallback: Callable[[Any], Any] | None = None,
            serialize_as_any: bool = False,
        ) -> str:
            """Usage docs: https://docs.pydantic.dev/2.4/concepts/serialization/#modelmodel_dump_json

            Generates a JSON representation of the model using Pydantic's `to_json` method.

            Args:
                indent: Indentation to use in the JSON output. If None is passed, the output will be compact.
                include: Field(s) to include in the JSON output. Can take either a string or set of strings.
                exclude: Field(s) to exclude from the JSON output. Can take either a string or set of strings.
                by_alias: Whether to serialize using field aliases.
                exclude_unset: Whether to exclude fields that have not been explicitly set.
                exclude_defaults: Whether to exclude fields that have the default value.
                exclude_none: Whether to exclude fields that have a value of `None`.
                round_trip: Whether to use serialization/deserialization between JSON and class instance.
                warnings: Whether to show any warnings that occurred during serialization.

            Returns:
                A JSON string representation of the model.
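
            For example (illustrative):

            ```py
            json_str = model.model_dump_json(indent=2, exclude_unset=True)
            ```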
            """
            if round_trip != False:
                raise ValueError("round_trip is only supported in Pydantic v2")
            if warnings != True:
                raise ValueError("warnings is only supported in Pydantic v2")
            if context is not None:
                raise ValueError("context is only supported in Pydantic v2")
            if serialize_as_any != False:
                raise ValueError("serialize_as_any is only supported in Pydantic v2")
            if fallback is not None:
                raise ValueError("fallback is only supported in Pydantic v2")
            if ensure_ascii != False:
                raise ValueError("ensure_ascii is only supported in Pydantic v2")
            if exclude_computed_fields != False:
                raise ValueError("exclude_computed_fields is only supported in Pydantic v2")
            return super().json(  # type: ignore[reportDeprecated]
                indent=indent,
                include=include,
                exclude=exclude,
                by_alias=by_alias if by_alias is not None else False,
                exclude_unset=exclude_unset,
                exclude_defaults=exclude_defaults,
                exclude_none=exclude_none,
            )


def _construct_field(value: object, field: FieldInfo, key: str) -> object:
    if value is None:
        return field_get_default(field)

    if PYDANTIC_V1:
        type_ = cast(type, field.outer_type_)  # type: ignore
    else:
        type_ = field.annotation  # type: ignore

    if type_ is None:
        raise RuntimeError(f"Unexpected field type is None for {key}")

    return construct_type(value=value, type_=type_, metadata=getattr(field, "metadata", None))


def _get_extra_fields_type(cls: type[pydantic.BaseModel]) -> type | None:
    if PYDANTIC_V1:
        # TODO
        return None

    schema = cls.__pydantic_core_schema__
    if schema["type"] == "model":
        fields = schema["schema"]
        if fields["type"] == "model-fields":
            extras = fields.get("extras_schema")
            if extras and "cls" in extras:
                # mypy can't narrow the type
                return extras["cls"]  # type: ignore[no-any-return]

    return None


def is_basemodel(type_: type) -> bool:
    """Returns whether or not the given type is either a `BaseModel` or a union of `BaseModel`"""
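    # For example (illustrative): `is_basemodel(Union[MyModel, None])` is True because one of
    # the union variants is a `BaseModel` subclass, whereas `is_basemodel(dict)` is False.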
    if is_union(type_):
        for variant in get_args(type_):
            if is_basemodel(variant):
                return True

        return False

    return is_basemodel_type(type_)


def is_basemodel_type(type_: type) -> TypeGuard[type[BaseModel] | type[GenericModel]]:
    origin = get_origin(type_) or type_
    if not inspect.isclass(origin):
        return False
    return issubclass(origin, BaseModel) or issubclass(origin, GenericModel)


def build(
    base_model_cls: Callable[P, _BaseModelT],
    *args: P.args,
    **kwargs: P.kwargs,
) -> _BaseModelT:
    """Construct a BaseModel class without validation.

    This is useful for cases where you need to instantiate a `BaseModel`
    from an API response as this provides type-safe params which isn't supported
    by helpers like `construct_type()`.

    ```py
    build(MyModel, my_field_a="foo", my_field_b=123)
    ```
    """
    if args:
        raise TypeError(
            "Received positional arguments which are not supported; Keyword arguments must be used instead",
        )

    return cast(_BaseModelT, construct_type(type_=base_model_cls, value=kwargs))


def construct_type_unchecked(*, value: object, type_: type[_T]) -> _T:
    """Loose coercion to the expected type with construction of nested values.

    Note: the returned value from this function is not guaranteed to match the
    given type.
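
    For example (illustrative), an uncoercible value is returned unchanged even though the
    return type is narrowed to `_T`:

    ```py
    construct_type_unchecked(value="not-a-number", type_=int)  # "not-a-number"
    ```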
    """
    return cast(_T, construct_type(value=value, type_=type_))


def construct_type(*, value: object, type_: object, metadata: Optional[List[Any]] = None) -> object:
    """Loose coercion to the expected type with construction of nested values.

    If the given value does not match the expected type then it is returned as-is.
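
    For example (illustrative):

    ```py
    construct_type(value="2024-03-22", type_=date)  # datetime.date(2024, 3, 22)
    construct_type(value="oops", type_=date)  # "oops" is returned as-is
    ```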
    """

    # store a reference to the original type we were given before we extract any inner
    # types so that we can properly resolve forward references in `TypeAliasType` annotations
    original_type = None

    # we allow `object` as the input type because otherwise, passing things like
    # `Literal['value']` will be reported as a type error by type checkers
    type_ = cast("type[object]", type_)
    if is_type_alias_type(type_):
        original_type = type_  # type: ignore[unreachable]
        type_ = type_.__value__  # type: ignore[unreachable]

    # unwrap `Annotated[T, ...]` -> `T`
    if metadata is not None and len(metadata) > 0:
        meta: tuple[Any, ...] = tuple(metadata)
    elif is_annotated_type(type_):
        meta = get_args(type_)[1:]
        type_ = extract_type_arg(type_, 0)
    else:
        meta = tuple()

    # we need to use the origin class for any types that are subscripted generics
    # e.g. Dict[str, object]
    origin = get_origin(type_) or type_
    args = get_args(type_)

    if is_union(origin):
        try:
            return validate_type(type_=cast("type[object]", original_type or type_), value=value)
        except Exception:
            pass

        # if the type is a discriminated union then we want to construct the right variant
        # in the union, even if the data doesn't match exactly, otherwise we'd break code
        # that relies on the constructed class types, e.g.
        #
        # class FooType:
        #   kind: Literal['foo']
        #   value: str
        #
        # class BarType:
        #   kind: Literal['bar']
        #   value: int
        #
        # without this block, if the data we get is something like `{'kind': 'bar', 'value': 'foo'}` then
        # we'd end up constructing `FooType` when it should be `BarType`.
        discriminator = _build_discriminated_union_meta(union=type_, meta_annotations=meta)
        if discriminator and is_mapping(value):
            variant_value = value.get(discriminator.field_alias_from or discriminator.field_name)
            if variant_value and isinstance(variant_value, str):
                variant_type = discriminator.mapping.get(variant_value)
                if variant_type:
                    return construct_type(type_=variant_type, value=value)

        # if the data is not valid, use the first variant that doesn't fail while deserializing
        for variant in args:
            try:
                return construct_type(value=value, type_=variant)
            except Exception:
                continue

        raise RuntimeError(f"Could not convert data into a valid instance of {type_}")

    if origin == dict:
        if not is_mapping(value):
            return value

        _, items_type = get_args(type_)  # Dict[_, items_type]
        return {key: construct_type(value=item, type_=items_type) for key, item in value.items()}

    if (
        not is_literal_type(type_)
        and inspect.isclass(origin)
        and (issubclass(origin, BaseModel) or issubclass(origin, GenericModel))
    ):
        if is_list(value):
            return [cast(Any, type_).construct(**entry) if is_mapping(entry) else entry for entry in value]

        if is_mapping(value):
            if issubclass(type_, BaseModel):
                return type_.construct(**value)  # type: ignore[arg-type]

            return cast(Any, type_).construct(**value)

    if origin == list:
        if not is_list(value):
            return value

        inner_type = args[0]  # List[inner_type]
        return [construct_type(value=entry, type_=inner_type) for entry in value]

    if origin == float:
        if isinstance(value, int):
            coerced = float(value)
            if coerced != value:
                return value
            return coerced

        return value

    if type_ == datetime:
        try:
            return parse_datetime(value)  # type: ignore
        except Exception:
            return value

    if type_ == date:
        try:
            return parse_date(value)  # type: ignore
        except Exception:
            return value

    return value


@runtime_checkable
class CachedDiscriminatorType(Protocol):
    __discriminator__: DiscriminatorDetails


DISCRIMINATOR_CACHE: weakref.WeakKeyDictionary[type, DiscriminatorDetails] = weakref.WeakKeyDictionary()


class DiscriminatorDetails:
    field_name: str
    """The name of the discriminator field in the variant class, e.g.

    ```py
    class Foo(BaseModel):
        type: Literal['foo']
    ```

    Will result in field_name='type'
    """

    field_alias_from: str | None
    """The name of the discriminator field in the API response, e.g.

    ```py
    class Foo(BaseModel):
        type: Literal['foo'] = Field(alias='type_from_api')
    ```

    Will result in field_alias_from='type_from_api'
    """

    mapping: dict[str, type]
    """Mapping of discriminator value to variant type, e.g.

    {'foo': FooVariant, 'bar': BarVariant}
    """

    def __init__(
        self,
        *,
        mapping: dict[str, type],
        discriminator_field: str,
        discriminator_alias: str | None,
    ) -> None:
        self.mapping = mapping
        self.field_name = discriminator_field
        self.field_alias_from = discriminator_alias


def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any, ...]) -> DiscriminatorDetails | None:
    cached = DISCRIMINATOR_CACHE.get(union)
    if cached is not None:
        return cached

    discriminator_field_name: str | None = None

    for annotation in meta_annotations:
        if isinstance(annotation, PropertyInfo) and annotation.discriminator is not None:
            discriminator_field_name = annotation.discriminator
            break

    if not discriminator_field_name:
        return None

    mapping: dict[str, type] = {}
    discriminator_alias: str | None = None

    for variant in get_args(union):
        variant = strip_annotated_type(variant)
        if is_basemodel_type(variant):
            if PYDANTIC_V1:
                field_info = cast("dict[str, FieldInfo]", variant.__fields__).get(discriminator_field_name)  # pyright: ignore[reportDeprecated, reportUnnecessaryCast]
                if not field_info:
                    continue

                # Note: if one variant defines an alias then they all should
                discriminator_alias = field_info.alias

                if (annotation := getattr(field_info, "annotation", None)) and is_literal_type(annotation):
                    for entry in get_args(annotation):
                        if isinstance(entry, str):
                            mapping[entry] = variant
            else:
                field = _extract_field_schema_pv2(variant, discriminator_field_name)
                if not field:
                    continue

                # Note: if one variant defines an alias then they all should
                discriminator_alias = field.get("serialization_alias")

                field_schema = field["schema"]

                if field_schema["type"] == "literal":
                    for entry in cast("LiteralSchema", field_schema)["expected"]:
                        if isinstance(entry, str):
                            mapping[entry] = variant

    if not mapping:
        return None

    details = DiscriminatorDetails(
        mapping=mapping,
        discriminator_field=discriminator_field_name,
        discriminator_alias=discriminator_alias,
    )
    DISCRIMINATOR_CACHE.setdefault(union, details)
    return details


def _extract_field_schema_pv2(model: type[BaseModel], field_name: str) -> ModelField | None:
    schema = model.__pydantic_core_schema__
    if schema["type"] == "definitions":
        schema = schema["schema"]

    if schema["type"] != "model":
        return None

    schema = cast("ModelSchema", schema)
    fields_schema = schema["schema"]
    if fields_schema["type"] != "model-fields":
        return None

    fields_schema = cast("ModelFieldsSchema", fields_schema)
    field = fields_schema["fields"].get(field_name)
    if not field:
        return None

    return cast("ModelField", field)  # pyright: ignore[reportUnnecessaryCast]


def validate_type(*, type_: type[_T], value: object) -> _T:
    """Strict validation that the given value matches the expected type"""
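    # Unlike `construct_type()`, a mismatch raises instead of being returned as-is,
    # e.g. (illustrative) `validate_type(type_=int, value="foo")` raises a validation error.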
    if inspect.isclass(type_) and issubclass(type_, pydantic.BaseModel):
        return cast(_T, parse_obj(type_, value))

    return cast(_T, _validate_non_model_type(type_=type_, value=value))


def set_pydantic_config(typ: Any, config: pydantic.ConfigDict) -> None:
    """Add a pydantic config for the given type.

    Note: this is a no-op on Pydantic v1.
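
    For example (illustrative `TypedDict` name):

    ```py
    class MyParams(TypedDict, total=False):
        name: str

    set_pydantic_config(MyParams, pydantic.ConfigDict(extra="allow"))
    ```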
    """
    setattr(typ, "__pydantic_config__", config)  # noqa: B010


def add_request_id(obj: BaseModel, request_id: str | None) -> None:
    obj._request_id = request_id

    # in Pydantic v1, using setattr like we do above causes the attribute
    # to be included when serializing the model which we don't want in this
    # case so we need to explicitly exclude it
    if PYDANTIC_V1:
        try:
            exclude_fields = obj.__exclude_fields__  # type: ignore
        except AttributeError:
            cast(Any, obj).__exclude_fields__ = {"_request_id", "__exclude_fields__"}
        else:
            cast(Any, obj).__exclude_fields__ = {*(exclude_fields or {}), "_request_id", "__exclude_fields__"}


# our use of subclassing here causes weirdness for type checkers,
# so we just pretend that we don't subclass
if TYPE_CHECKING:
    GenericModel = BaseModel
else:

    class GenericModel(BaseGenericModel, BaseModel):
        pass


if not PYDANTIC_V1:
    from pydantic import TypeAdapter as _TypeAdapter

    _CachedTypeAdapter = cast("TypeAdapter[object]", lru_cache(maxsize=None)(_TypeAdapter))

    if TYPE_CHECKING:
        from pydantic import TypeAdapter
    else:
        TypeAdapter = _CachedTypeAdapter

    def _validate_non_model_type(*, type_: type[_T], value: object) -> _T:
        return TypeAdapter(type_).validate_python(value)

elif not TYPE_CHECKING:  # TODO: condition is weird

    class RootModel(GenericModel, Generic[_T]):
        """Used as a placeholder to easily convert runtime types to a Pydantic format
        to provide validation.

        For example:
        ```py
        validated = RootModel[int](__root__="5").__root__
        # validated: 5
        ```
        """

        __root__: _T

    def _validate_non_model_type(*, type_: type[_T], value: object) -> _T:
        model = _create_pydantic_model(type_).validate(value)
        return cast(_T, model.__root__)

    def _create_pydantic_model(type_: _T) -> Type[RootModel[_T]]:
        return RootModel[type_]  # type: ignore


class FinalRequestOptionsInput(TypedDict, total=False):
    method: Required[str]
    url: Required[str]
    params: Query
    headers: Headers
    max_retries: int
    timeout: float | Timeout | None
    files: HttpxRequestFiles | None
    idempotency_key: str
    json_data: Body
    extra_json: AnyMapping
    follow_redirects: bool


@final
class FinalRequestOptions(pydantic.BaseModel):
    method: str
    url: str
    params: Query = {}
    headers: Union[Headers, NotGiven] = NotGiven()
    max_retries: Union[int, NotGiven] = NotGiven()
    timeout: Union[float, Timeout, None, NotGiven] = NotGiven()
    files: Union[HttpxRequestFiles, None] = None
    idempotency_key: Union[str, None] = None
    post_parser: Union[Callable[[Any], Any], NotGiven] = NotGiven()
    follow_redirects: Union[bool, None] = None

    # It should be noted that we cannot use `json` here as that would override
    # a BaseModel method in an incompatible fashion.
    json_data: Union[Body, None] = None
    extra_json: Union[AnyMapping, None] = None

    if PYDANTIC_V1:

        class Config(pydantic.BaseConfig):  # pyright: ignore[reportDeprecated]
            arbitrary_types_allowed: bool = True
    else:
        model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True)

    def get_max_retries(self, max_retries: int) -> int:
        if isinstance(self.max_retries, NotGiven):
            return max_retries
        return self.max_retries

    def _strip_raw_response_header(self) -> None:
        if not is_given(self.headers):
            return

        if self.headers.get(RAW_RESPONSE_HEADER):
            self.headers = {**self.headers}
            self.headers.pop(RAW_RESPONSE_HEADER)

    # override the `construct` method so that we can run custom transformations.
    # this is necessary as we don't want to do any actual runtime type checking
    # (which means we can't use validators) but we do want to ensure that `NotGiven`
    # values are not present
    #
    # type ignore required because we're adding explicit types to `**values`
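    # For example (illustrative values), passing `headers={"X-Foo": NotGiven()}` results in
    # `headers == {}` since `strip_not_given` drops `NotGiven` entries from mapping values.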
    @classmethod
    def construct(  # type: ignore
        cls,
        _fields_set: set[str] | None = None,
        **values: Unpack[FinalRequestOptionsInput],
    ) -> FinalRequestOptions:
        kwargs: dict[str, Any] = {
            # we unconditionally call `strip_not_given` on any value
            # as it will just ignore any non-mapping types
            key: strip_not_given(value)
            for key, value in values.items()
        }
        if PYDANTIC_V1:
            return cast(FinalRequestOptions, super().construct(_fields_set, **kwargs))  # pyright: ignore[reportDeprecated]
        return super().model_construct(_fields_set, **kwargs)

    if not TYPE_CHECKING:
        # type checkers incorrectly complain about this assignment
        model_construct = construct