main
  1from __future__ import annotations
  2
  3import inspect
  4from typing import Any, TypeVar
  5from typing_extensions import TypeGuard
  6
  7import pydantic
  8
  9from .._types import NOT_GIVEN
 10from .._utils import is_dict as _is_dict, is_list
 11from .._compat import PYDANTIC_V1, model_json_schema
 12
 13_T = TypeVar("_T")
 14
 15
 16def to_strict_json_schema(model: type[pydantic.BaseModel] | pydantic.TypeAdapter[Any]) -> dict[str, Any]:
 17    if inspect.isclass(model) and is_basemodel_type(model):
 18        schema = model_json_schema(model)
 19    elif (not PYDANTIC_V1) and isinstance(model, pydantic.TypeAdapter):
 20        schema = model.json_schema()
 21    else:
 22        raise TypeError(f"Non BaseModel types are only supported with Pydantic v2 - {model}")
 23
 24    return _ensure_strict_json_schema(schema, path=(), root=schema)
 25
 26
 27def _ensure_strict_json_schema(
 28    json_schema: object,
 29    *,
 30    path: tuple[str, ...],
 31    root: dict[str, object],
 32) -> dict[str, Any]:
 33    """Mutates the given JSON schema to ensure it conforms to the `strict` standard
 34    that the API expects.
 35    """
 36    if not is_dict(json_schema):
 37        raise TypeError(f"Expected {json_schema} to be a dictionary; path={path}")
 38
 39    defs = json_schema.get("$defs")
 40    if is_dict(defs):
 41        for def_name, def_schema in defs.items():
 42            _ensure_strict_json_schema(def_schema, path=(*path, "$defs", def_name), root=root)
 43
 44    definitions = json_schema.get("definitions")
 45    if is_dict(definitions):
 46        for definition_name, definition_schema in definitions.items():
 47            _ensure_strict_json_schema(definition_schema, path=(*path, "definitions", definition_name), root=root)
 48
 49    typ = json_schema.get("type")
 50    if typ == "object" and "additionalProperties" not in json_schema:
 51        json_schema["additionalProperties"] = False
 52
 53    # object types
 54    # { 'type': 'object', 'properties': { 'a':  {...} } }
 55    properties = json_schema.get("properties")
 56    if is_dict(properties):
 57        json_schema["required"] = [prop for prop in properties.keys()]
 58        json_schema["properties"] = {
 59            key: _ensure_strict_json_schema(prop_schema, path=(*path, "properties", key), root=root)
 60            for key, prop_schema in properties.items()
 61        }
 62
 63    # arrays
 64    # { 'type': 'array', 'items': {...} }
 65    items = json_schema.get("items")
 66    if is_dict(items):
 67        json_schema["items"] = _ensure_strict_json_schema(items, path=(*path, "items"), root=root)
 68
 69    # unions
 70    any_of = json_schema.get("anyOf")
 71    if is_list(any_of):
 72        json_schema["anyOf"] = [
 73            _ensure_strict_json_schema(variant, path=(*path, "anyOf", str(i)), root=root)
 74            for i, variant in enumerate(any_of)
 75        ]
 76
 77    # intersections
 78    all_of = json_schema.get("allOf")
 79    if is_list(all_of):
 80        if len(all_of) == 1:
 81            json_schema.update(_ensure_strict_json_schema(all_of[0], path=(*path, "allOf", "0"), root=root))
 82            json_schema.pop("allOf")
 83        else:
 84            json_schema["allOf"] = [
 85                _ensure_strict_json_schema(entry, path=(*path, "allOf", str(i)), root=root)
 86                for i, entry in enumerate(all_of)
 87            ]
 88
 89    # strip `None` defaults as there's no meaningful distinction here
 90    # the schema will still be `nullable` and the model will default
 91    # to using `None` anyway
 92    if json_schema.get("default", NOT_GIVEN) is None:
 93        json_schema.pop("default")
 94
 95    # we can't use `$ref`s if there are also other properties defined, e.g.
 96    # `{"$ref": "...", "description": "my description"}`
 97    #
 98    # so we unravel the ref
 99    # `{"type": "string", "description": "my description"}`
100    ref = json_schema.get("$ref")
101    if ref and has_more_than_n_keys(json_schema, 1):
102        assert isinstance(ref, str), f"Received non-string $ref - {ref}"
103
104        resolved = resolve_ref(root=root, ref=ref)
105        if not is_dict(resolved):
106            raise ValueError(f"Expected `$ref: {ref}` to resolved to a dictionary but got {resolved}")
107
108        # properties from the json schema take priority over the ones on the `$ref`
109        json_schema.update({**resolved, **json_schema})
110        json_schema.pop("$ref")
111        # Since the schema expanded from `$ref` might not have `additionalProperties: false` applied,
112        # we call `_ensure_strict_json_schema` again to fix the inlined schema and ensure it's valid.
113        return _ensure_strict_json_schema(json_schema, path=path, root=root)
114
115    return json_schema
116
117
def resolve_ref(*, root: dict[str, object], ref: str) -> object:
    """Look up the value that a local `#/a/b/c` style reference points at.

    Args:
        root: The dictionary the reference is resolved against.
        ref: The reference string; it must begin with `#/`. The remainder is
            split on `/` and each piece is used as a key, one level at a time.

    Returns:
        The dictionary found after following every key.

    Raises:
        ValueError: If `ref` does not start with `#/`.
        KeyError: If one of the keys is missing along the way.
        AssertionError: If a looked-up value is not a dictionary.
    """
    if not ref.startswith("#/"):
        raise ValueError(f"Unexpected $ref format {ref!r}; Does not start with #/")

    resolved = root
    # drop the leading `#/` and descend one key at a time
    for segment in ref[2:].split("/"):
        child = resolved[segment]
        assert is_dict(child), f"encountered non-dictionary entry while resolving {ref} - {resolved}"
        resolved = child

    return resolved
130
131
def is_basemodel_type(typ: type) -> TypeGuard[type[pydantic.BaseModel]]:
    """Return True only when `typ` is a class that subclasses `pydantic.BaseModel`.

    Non-class values yield False; the `issubclass` check is skipped for them.
    """
    return inspect.isclass(typ) and issubclass(typ, pydantic.BaseModel)
136
137
def is_dataclass_like_type(typ: type) -> bool:
    """Returns True if the given type likely used `@pydantic.dataclass`

    The check is only whether `typ` exposes a `__pydantic_config__` attribute.
    """
    missing = object()
    return getattr(typ, "__pydantic_config__", missing) is not missing
141
142
def is_dict(obj: object) -> TypeGuard[dict[str, object]]:
    """Typed wrapper around `_is_dict` that narrows `obj` to `dict[str, object]`.

    The key type is not inspected: we simply pretend all keys are `str`,
    because verifying that is not worth the performance cost.
    """
    outcome = _is_dict(obj)
    return outcome
147
148
def has_more_than_n_keys(obj: dict[str, object], n: int) -> bool:
    """Return whether `obj` holds strictly more than `n` keys.

    Args:
        obj: The dictionary to inspect.
        n: The threshold; the result is True only when `len(obj) > n`.

    Returns:
        True if the dictionary has more than `n` keys, otherwise False.
    """
    # `len()` on a dict is constant time, so there is no need to count keys
    # by hand with an early exit.
    return len(obj) > n
155    return False