main
1from __future__ import annotations
2
3from typing import TYPE_CHECKING, Any, Union, Generic, TypeVar, Callable, cast, overload
4from datetime import date, datetime
5from typing_extensions import Self, Literal
6
7import pydantic
8from pydantic.fields import FieldInfo
9
10from ._types import IncEx, StrBytesIntFloat
11
# Unconstrained type variable; parameterises `typed_cached_property` further down.
_T = TypeVar("_T")
# Type variable bounded by `pydantic.BaseModel`, so the model helpers in this
# module can return the same model type they were handed.
_ModelT = TypeVar("_ModelT", bound=pydantic.BaseModel)
14
15# --------------- Pydantic v2, v3 compatibility ---------------
16
17# Pyright incorrectly reports some of our functions as overriding a method when they don't
18# pyright: reportIncompatibleMethodOverride=false
19
# True when the installed pydantic reports a version string starting with "1.".
# The helpers in this module branch on this flag to pick which pydantic API to call.
PYDANTIC_V1 = pydantic.VERSION.startswith("1.")
21
if TYPE_CHECKING:
    # Signature-only stubs: type checkers see one fixed set of declarations,
    # independent of which implementation gets imported at runtime.

    def parse_date(value: date | StrBytesIntFloat) -> date:  # noqa: ARG001
        ...

    def parse_datetime(value: Union[datetime, StrBytesIntFloat]) -> datetime:  # noqa: ARG001
        ...

    def get_args(t: type[Any]) -> tuple[Any, ...]:  # noqa: ARG001
        ...

    def is_union(tp: type[Any] | None) -> bool:  # noqa: ARG001
        ...

    def get_origin(t: type[Any]) -> type[Any] | None:  # noqa: ARG001
        ...

    def is_literal_type(type_: type[Any]) -> bool:  # noqa: ARG001
        ...

    def is_typeddict(type_: type[Any]) -> bool:  # noqa: ARG001
        ...

elif PYDANTIC_V1:
    # v1 re-exports: the helpers come straight from pydantic itself
    from pydantic.typing import (
        get_args as get_args,
        is_union as is_union,
        get_origin as get_origin,
        is_typeddict as is_typeddict,
        is_literal_type as is_literal_type,
    )
    from pydantic.datetime_parse import (
        parse_date as parse_date,
        parse_datetime as parse_datetime,
    )
else:
    # otherwise the same names are re-exported from the package-local `._utils`
    from ._utils import (
        get_args as get_args,
        is_union as is_union,
        get_origin as get_origin,
        parse_date as parse_date,
        is_typeddict as is_typeddict,
        parse_datetime as parse_datetime,
        is_literal_type as is_literal_type,
    )
66
67
68# refactored config
if TYPE_CHECKING:
    from pydantic import ConfigDict as ConfigDict
elif PYDANTIC_V1:
    # `ConfigDict` is not imported on the v1 path; bind a placeholder so the
    # name still exists at module level.
    # TODO: provide an error message here?
    ConfigDict = None
else:
    from pydantic import ConfigDict as ConfigDict
77
78
79# renamed methods / properties
def parse_obj(model: type[_ModelT], value: object) -> _ModelT:
    """Build an instance of ``model`` from ``value``.

    Uses ``model.model_validate`` unless ``PYDANTIC_V1`` is set, in which case
    the v1 ``model.parse_obj`` classmethod is called instead.
    """
    if not PYDANTIC_V1:
        return model.model_validate(value)

    parsed = model.parse_obj(value)  # pyright: ignore[reportDeprecated]
    return cast(_ModelT, parsed)  # pyright: ignore[reportUnnecessaryCast]
85
86
def field_is_required(field: FieldInfo) -> bool:
    """Report whether ``field`` is required.

    Calls ``field.is_required()`` unless ``PYDANTIC_V1`` is set, in which case
    the ``field.required`` attribute is read instead.
    """
    if not PYDANTIC_V1:
        return field.is_required()
    return field.required  # type: ignore
91
92
def field_get_default(field: FieldInfo) -> Any:
    """Return the default value of ``field``.

    Under ``PYDANTIC_V1`` the result of ``field.get_default()`` is returned
    unchanged. Otherwise the ``PydanticUndefined`` sentinel is translated to
    ``None`` and every other value is returned as-is.
    """
    value = field.get_default()
    if PYDANTIC_V1:
        return value

    # imported lazily: this line is only reached on the non-v1 path
    from pydantic_core import PydanticUndefined

    # Compare by identity. `==` would dispatch to the default value's own
    # `__eq__`, which may return a non-bool (array-like defaults make the `if`
    # raise) or claim equality with anything, wrongly turning a real default
    # into `None`.
    if value is PydanticUndefined:
        return None
    return value
102
103
def field_outer_type(field: FieldInfo) -> Any:
    """Return the declared type of ``field``.

    Reads ``field.annotation`` unless ``PYDANTIC_V1`` is set, in which case
    ``field.outer_type_`` is read instead.
    """
    if not PYDANTIC_V1:
        return field.annotation
    return field.outer_type_  # type: ignore
108
109
def get_model_config(model: type[pydantic.BaseModel]) -> Any:
    """Return the configuration attached to the ``model`` class.

    Reads ``model.model_config`` unless ``PYDANTIC_V1`` is set, in which case
    ``model.__config__`` is read instead.
    """
    if not PYDANTIC_V1:
        return model.model_config
    return model.__config__  # type: ignore
114
115
def get_model_fields(model: type[pydantic.BaseModel]) -> dict[str, FieldInfo]:
    """Return the field mapping declared on the ``model`` class.

    Reads ``model.model_fields`` unless ``PYDANTIC_V1`` is set, in which case
    ``model.__fields__`` is read instead.
    """
    if not PYDANTIC_V1:
        return model.model_fields
    return model.__fields__  # type: ignore
120
121
def model_copy(model: _ModelT, *, deep: bool = False) -> _ModelT:
    """Return a copy of ``model``, forwarding ``deep`` to the underlying copy call.

    Calls ``model.model_copy`` unless ``PYDANTIC_V1`` is set, in which case
    ``model.copy`` is called instead.
    """
    if not PYDANTIC_V1:
        return model.model_copy(deep=deep)
    return model.copy(deep=deep)  # type: ignore
126
127
def model_json(model: pydantic.BaseModel, *, indent: int | None = None) -> str:
    """Serialise ``model`` to a JSON string, forwarding ``indent``.

    Calls ``model.model_dump_json`` unless ``PYDANTIC_V1`` is set, in which
    case ``model.json`` is called instead.
    """
    if not PYDANTIC_V1:
        return model.model_dump_json(indent=indent)
    return model.json(indent=indent)  # type: ignore
132
133
def model_dump(
    model: pydantic.BaseModel,
    *,
    exclude: IncEx | None = None,
    exclude_unset: bool = False,
    exclude_defaults: bool = False,
    warnings: bool = True,
    mode: Literal["json", "python"] = "python",
) -> dict[str, Any]:
    """Convert ``model`` into a plain dictionary.

    ``model.model_dump`` is used whenever ``PYDANTIC_V1`` is unset or the
    instance exposes a ``model_dump`` attribute. Only when running under v1
    with no such attribute does the call fall back to ``model.dict``; that
    fallback receives the ``exclude*`` options only, and ``mode`` and
    ``warnings`` are not forwarded to it.
    """
    if PYDANTIC_V1 and not hasattr(model, "model_dump"):
        legacy_dump = model.dict(  # pyright: ignore[reportDeprecated]
            exclude=exclude,
            exclude_unset=exclude_unset,
            exclude_defaults=exclude_defaults,
        )
        return cast("dict[str, Any]", legacy_dump)  # pyright: ignore[reportUnnecessaryCast]

    # warnings are not supported in Pydantic v1
    effective_warnings = True if PYDANTIC_V1 else warnings
    return model.model_dump(
        mode=mode,
        exclude=exclude,
        exclude_unset=exclude_unset,
        exclude_defaults=exclude_defaults,
        warnings=effective_warnings,
    )
160
161
def model_parse(model: type[_ModelT], data: Any) -> _ModelT:
    """Build an instance of ``model`` from already-decoded ``data``.

    Calls ``model.model_validate`` unless ``PYDANTIC_V1`` is set, in which
    case ``model.parse_obj`` is called instead.
    """
    if not PYDANTIC_V1:
        return model.model_validate(data)
    return model.parse_obj(data)  # pyright: ignore[reportDeprecated]
166
167
def model_parse_json(model: type[_ModelT], data: str | bytes) -> _ModelT:
    """Build an instance of ``model`` from raw JSON text or bytes.

    Calls ``model.model_validate_json`` unless ``PYDANTIC_V1`` is set, in
    which case ``model.parse_raw`` is called instead.
    """
    if not PYDANTIC_V1:
        return model.model_validate_json(data)
    return model.parse_raw(data)  # pyright: ignore[reportDeprecated]
172
173
def model_json_schema(model: type[_ModelT]) -> dict[str, Any]:
    """Return the JSON schema dictionary for the ``model`` class.

    Calls ``model.model_json_schema`` unless ``PYDANTIC_V1`` is set, in which
    case ``model.schema`` is called instead.
    """
    if not PYDANTIC_V1:
        return model.model_json_schema()
    return model.schema()  # pyright: ignore[reportDeprecated]
178
179
180# generic models
if TYPE_CHECKING:
    # type checkers only ever see this plain `BaseModel` subclass

    class GenericModel(pydantic.BaseModel): ...

elif PYDANTIC_V1:
    import pydantic.generics

    # under v1 the base additionally mixes in `pydantic.generics.GenericModel`
    class GenericModel(pydantic.generics.GenericModel, pydantic.BaseModel):
        pass

else:
    # there no longer needs to be a distinction in v2 but
    # we still have to create our own subclass to avoid
    # inconsistent MRO ordering errors
    class GenericModel(pydantic.BaseModel):
        pass
195
196
197# cached properties
if TYPE_CHECKING:
    # For type checkers, `cached_property` is simply an alias of the builtin
    # `property`; the real `functools.cached_property` is bound in the `else`
    # branch below.
    cached_property = property

    # we define a separate type (copied from typeshed)
    # that represents that `cached_property` is `set`able
    # at runtime, which differs from `@property`.
    #
    # this is a separate type as editors likely special case
    # `@property` and we don't want to cause issues just to have
    # more helpful internal types.

    class typed_cached_property(Generic[_T]):
        # the wrapped callable; it takes the instance and produces the value
        func: Callable[[Any], _T]
        # attribute name recorded for the descriptor, or None
        attrname: str | None

        def __init__(self, func: Callable[[Any], _T]) -> None: ...

        # access through the class (instance is None) yields the descriptor itself
        @overload
        def __get__(self, instance: None, owner: type[Any] | None = None) -> Self: ...

        # access through an instance yields the wrapped function's value type
        @overload
        def __get__(self, instance: object, owner: type[Any] | None = None) -> _T: ...

        # type-checking-only implementation; this class body is never executed
        # at runtime because it sits under `TYPE_CHECKING`
        def __get__(self, instance: object, owner: type[Any] | None = None) -> _T | Self:
            raise NotImplementedError()

        def __set_name__(self, owner: type[Any], name: str) -> None: ...

        # __set__ is not defined at runtime, but @cached_property is designed to be settable
        def __set__(self, instance: object, value: _T) -> None: ...
else:
    from functools import cached_property as cached_property

    # at runtime both names refer to the same `functools.cached_property`
    typed_cached_property = cached_property
231 typed_cached_property = cached_property