main
  1from __future__ import annotations
  2
  3from typing import TYPE_CHECKING, Any, Union, Generic, TypeVar, Callable, cast, overload
  4from datetime import date, datetime
  5from typing_extensions import Self, Literal
  6
  7import pydantic
  8from pydantic.fields import FieldInfo
  9
 10from ._types import IncEx, StrBytesIntFloat
 11
# Generic type variable used by `typed_cached_property`.
_T = TypeVar("_T")
# Bound TypeVar so helpers taking `type[_ModelT]` / `_ModelT` return that same model type.
_ModelT = TypeVar("_ModelT", bound=pydantic.BaseModel)

# --------------- Pydantic v1 / v2+ compatibility ---------------

# Pyright incorrectly reports some of our functions as overriding a method when they don't
# pyright: reportIncompatibleMethodOverride=false

# True when the installed Pydantic is a 1.x release. Most helpers below branch on
# this flag to choose between the v1 API and the v2+ API.
PYDANTIC_V1 = pydantic.VERSION.startswith("1.")
 21
# Typing/date helpers whose implementation depends on the Pydantic version.
# Type checkers only see the stub signatures in the TYPE_CHECKING branch; at
# runtime the real functions are imported from Pydantic v1 or from `._utils`.
if TYPE_CHECKING:

    def parse_date(value: date | StrBytesIntFloat) -> date:  # noqa: ARG001
        ...

    def parse_datetime(value: Union[datetime, StrBytesIntFloat]) -> datetime:  # noqa: ARG001
        ...

    def get_args(t: type[Any]) -> tuple[Any, ...]:  # noqa: ARG001
        ...

    def is_union(tp: type[Any] | None) -> bool:  # noqa: ARG001
        ...

    def get_origin(t: type[Any]) -> type[Any] | None:  # noqa: ARG001
        ...

    def is_literal_type(type_: type[Any]) -> bool:  # noqa: ARG001
        ...

    def is_typeddict(type_: type[Any]) -> bool:  # noqa: ARG001
        ...

else:
    # Runtime re-exports for both Pydantic generations. The redundant
    # `name as name` aliases mark these as intentional re-exports so type
    # checkers and linters do not flag them as unused/private imports.
    if PYDANTIC_V1:
        # Pydantic v1 ships these helpers itself.
        from pydantic.typing import (
            get_args as get_args,
            is_union as is_union,
            get_origin as get_origin,
            is_typeddict as is_typeddict,
            is_literal_type as is_literal_type,
        )
        from pydantic.datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime
    else:
        # Pydantic v2+: use the package's own implementations in `_utils`.
        from ._utils import (
            get_args as get_args,
            is_union as is_union,
            get_origin as get_origin,
            parse_date as parse_date,
            is_typeddict as is_typeddict,
            parse_datetime as parse_datetime,
            is_literal_type as is_literal_type,
        )
 66
 67
 68# refactored config
 69if TYPE_CHECKING:
 70    from pydantic import ConfigDict as ConfigDict
 71else:
 72    if PYDANTIC_V1:
 73        # TODO: provide an error message here?
 74        ConfigDict = None
 75    else:
 76        from pydantic import ConfigDict as ConfigDict
 77
 78
 79# renamed methods / properties
 80def parse_obj(model: type[_ModelT], value: object) -> _ModelT:
 81    if PYDANTIC_V1:
 82        return cast(_ModelT, model.parse_obj(value))  # pyright: ignore[reportDeprecated, reportUnnecessaryCast]
 83    else:
 84        return model.model_validate(value)
 85
 86
 87def field_is_required(field: FieldInfo) -> bool:
 88    if PYDANTIC_V1:
 89        return field.required  # type: ignore
 90    return field.is_required()
 91
 92
 93def field_get_default(field: FieldInfo) -> Any:
 94    value = field.get_default()
 95    if PYDANTIC_V1:
 96        return value
 97    from pydantic_core import PydanticUndefined
 98
 99    if value == PydanticUndefined:
100        return None
101    return value
102
103
def field_outer_type(field: FieldInfo) -> Any:
    """Return the declared (outer) type of ``field``.

    Read from ``annotation`` on Pydantic v2+ and from ``outer_type_`` on v1.
    """
    if not PYDANTIC_V1:
        return field.annotation
    return field.outer_type_  # type: ignore
108
109
def get_model_config(model: type[pydantic.BaseModel]) -> Any:
    """Return the configuration attached to the model class ``model``.

    This is ``model_config`` on Pydantic v2+ and ``__config__`` on v1.
    """
    if not PYDANTIC_V1:
        return model.model_config
    return model.__config__  # type: ignore
114
115
def get_model_fields(model: type[pydantic.BaseModel]) -> dict[str, FieldInfo]:
    """Return the name-to-field mapping declared on the model class ``model``.

    Read from ``model_fields`` on Pydantic v2+ and from ``__fields__`` on v1.
    """
    if not PYDANTIC_V1:
        return model.model_fields
    return model.__fields__  # type: ignore
120
121
def model_copy(model: _ModelT, *, deep: bool = False) -> _ModelT:
    """Return a copy of ``model``; pass ``deep=True`` for a deep copy.

    Uses ``model_copy`` on Pydantic v2+ and the legacy ``copy`` on v1.
    """
    if not PYDANTIC_V1:
        return model.model_copy(deep=deep)
    return model.copy(deep=deep)  # type: ignore
126
127
def model_json(model: pydantic.BaseModel, *, indent: int | None = None) -> str:
    """Serialize ``model`` to a JSON string, forwarding ``indent`` for pretty-printing.

    Uses ``model_dump_json`` on Pydantic v2+ and the legacy ``json`` on v1.
    """
    if not PYDANTIC_V1:
        return model.model_dump_json(indent=indent)
    return model.json(indent=indent)  # type: ignore
132
133
def model_dump(
    model: pydantic.BaseModel,
    *,
    exclude: IncEx | None = None,
    exclude_unset: bool = False,
    exclude_defaults: bool = False,
    warnings: bool = True,
    mode: Literal["json", "python"] = "python",
) -> dict[str, Any]:
    """Dump ``model`` to a dict through whichever Pydantic API is available.

    ``model.model_dump`` is used on Pydantic v2+, and also on v1 when the model
    object has a ``model_dump`` attribute. Otherwise the v1 ``.dict()`` method
    is used; that path does not receive ``mode`` or ``warnings``.

    Args:
        model: the model instance to serialize.
        exclude, exclude_unset, exclude_defaults: forwarded unchanged to the
            underlying dump call.
        warnings: forwarded to ``model_dump`` on v2+; forced to ``True`` on v1.
        mode: forwarded to ``model_dump``; ignored on the ``.dict()`` path.

    Returns:
        The dumped model as a ``dict``.
    """
    has_model_dump = (not PYDANTIC_V1) or hasattr(model, "model_dump")
    if not has_model_dump:
        return cast(
            "dict[str, Any]",
            model.dict(  # pyright: ignore[reportDeprecated, reportUnnecessaryCast]
                exclude=exclude,
                exclude_unset=exclude_unset,
                exclude_defaults=exclude_defaults,
            ),
        )

    # warnings are not supported in Pydantic v1
    effective_warnings = True if PYDANTIC_V1 else warnings
    return model.model_dump(
        mode=mode,
        exclude=exclude,
        exclude_unset=exclude_unset,
        exclude_defaults=exclude_defaults,
        warnings=effective_warnings,
    )
160
161
def model_parse(model: type[_ModelT], data: Any) -> _ModelT:
    """Validate arbitrary ``data`` into an instance of ``model``.

    Uses ``model_validate`` on Pydantic v2+ and ``parse_obj`` on v1.
    """
    if not PYDANTIC_V1:
        return model.model_validate(data)
    return model.parse_obj(data)  # pyright: ignore[reportDeprecated]
166
167
def model_parse_json(model: type[_ModelT], data: str | bytes) -> _ModelT:
    """Parse a JSON document (``str`` or ``bytes``) into an instance of ``model``.

    Uses ``model_validate_json`` on Pydantic v2+ and ``parse_raw`` on v1.
    """
    if not PYDANTIC_V1:
        return model.model_validate_json(data)
    return model.parse_raw(data)  # pyright: ignore[reportDeprecated]
172
173
def model_json_schema(model: type[_ModelT]) -> dict[str, Any]:
    """Return the JSON schema for the model class ``model`` as a dict.

    Uses ``model_json_schema`` on Pydantic v2+ and ``schema`` on v1.
    """
    if not PYDANTIC_V1:
        return model.model_json_schema()
    return model.schema()  # pyright: ignore[reportDeprecated]
178
179
# generic models: a single base class for generic models on every Pydantic version.
if TYPE_CHECKING:

    class GenericModel(pydantic.BaseModel): ...

elif PYDANTIC_V1:
    import pydantic.generics

    # On Pydantic v1, generic support comes from the separate
    # `pydantic.generics.GenericModel` base class.
    class GenericModel(pydantic.generics.GenericModel, pydantic.BaseModel): ...

else:
    # Pydantic v2 no longer distinguishes generic models, but we still
    # define our own subclass to avoid inconsistent MRO ordering errors.
    class GenericModel(pydantic.BaseModel): ...
195
196
# cached properties
#
# At runtime both `cached_property` and `typed_cached_property` are
# `functools.cached_property`; the TYPE_CHECKING branch only changes how
# static type checkers see them.
if TYPE_CHECKING:
    # Type checkers treat `cached_property` exactly like `property`.
    cached_property = property

    # we define a separate type (copied from typeshed)
    # that represents that `cached_property` is `set`able
    # at runtime, which differs from `@property`.
    #
    # this is a separate type as editors likely special case
    # `@property` and we don't want to cause issues just to have
    # more helpful internal types.

    class typed_cached_property(Generic[_T]):
        func: Callable[[Any], _T]  # the wrapped getter function
        attrname: str | None  # mirrors `functools.cached_property.attrname`

        def __init__(self, func: Callable[[Any], _T]) -> None: ...

        # Class access (instance is None) yields the descriptor itself.
        @overload
        def __get__(self, instance: None, owner: type[Any] | None = None) -> Self: ...

        # Instance access yields the cached value of type _T.
        @overload
        def __get__(self, instance: object, owner: type[Any] | None = None) -> _T: ...

        def __get__(self, instance: object, owner: type[Any] | None = None) -> _T | Self:
            # Static-only stub: this class is defined only under TYPE_CHECKING.
            raise NotImplementedError()

        def __set_name__(self, owner: type[Any], name: str) -> None: ...

        # __set__ is not defined at runtime, but @cached_property is designed to be settable
        def __set__(self, instance: object, value: _T) -> None: ...
else:
    from functools import cached_property as cached_property

    typed_cached_property = cached_property
231    typed_cached_property = cached_property