Commit ad5ac9fc

Robert Craigie <robert@craigie.dev>
2024-08-12 15:48:17
fix(json schema): unravel `$ref`s alongside additional keys
1 parent 09350cb
Changed files (3)
src/openai/lib/_pydantic.py
@@ -10,12 +10,15 @@ from .._compat import model_json_schema
 
 
 def to_strict_json_schema(model: type[pydantic.BaseModel]) -> dict[str, Any]:
-    return _ensure_strict_json_schema(model_json_schema(model), path=())
+    schema = model_json_schema(model)
+    return _ensure_strict_json_schema(schema, path=(), root=schema)
 
 
 def _ensure_strict_json_schema(
     json_schema: object,
+    *,
     path: tuple[str, ...],
+    root: dict[str, object],
 ) -> dict[str, Any]:
     """Mutates the given JSON schema to ensure it conforms to the `strict` standard
     that the API expects.
@@ -23,6 +26,16 @@ def _ensure_strict_json_schema(
     if not is_dict(json_schema):
         raise TypeError(f"Expected {json_schema} to be a dictionary; path={path}")
 
+    defs = json_schema.get("$defs")
+    if is_dict(defs):
+        for def_name, def_schema in defs.items():
+            _ensure_strict_json_schema(def_schema, path=(*path, "$defs", def_name), root=root)
+
+    definitions = json_schema.get("definitions")
+    if is_dict(definitions):
+        for definition_name, definition_schema in definitions.items():
+            _ensure_strict_json_schema(definition_schema, path=(*path, "definitions", definition_name), root=root)
+
     typ = json_schema.get("type")
     if typ == "object" and "additionalProperties" not in json_schema:
         json_schema["additionalProperties"] = False
@@ -33,7 +46,7 @@ def _ensure_strict_json_schema(
     if is_dict(properties):
         json_schema["required"] = [prop for prop in properties.keys()]
         json_schema["properties"] = {
-            key: _ensure_strict_json_schema(prop_schema, path=(*path, "properties", key))
+            key: _ensure_strict_json_schema(prop_schema, path=(*path, "properties", key), root=root)
             for key, prop_schema in properties.items()
         }
 
@@ -41,40 +54,72 @@ def _ensure_strict_json_schema(
     # { 'type': 'array', 'items': {...} }
     items = json_schema.get("items")
     if is_dict(items):
-        json_schema["items"] = _ensure_strict_json_schema(items, path=(*path, "items"))
+        json_schema["items"] = _ensure_strict_json_schema(items, path=(*path, "items"), root=root)
 
     # unions
     any_of = json_schema.get("anyOf")
     if is_list(any_of):
         json_schema["anyOf"] = [
-            _ensure_strict_json_schema(variant, path=(*path, "anyOf", str(i))) for i, variant in enumerate(any_of)
+            _ensure_strict_json_schema(variant, path=(*path, "anyOf", str(i)), root=root)
+            for i, variant in enumerate(any_of)
         ]
 
     # intersections
     all_of = json_schema.get("allOf")
     if is_list(all_of):
         if len(all_of) == 1:
-            json_schema.update(_ensure_strict_json_schema(all_of[0], path=(*path, "allOf", "0")))
+            json_schema.update(_ensure_strict_json_schema(all_of[0], path=(*path, "allOf", "0"), root=root))
             json_schema.pop("allOf")
         else:
             json_schema["allOf"] = [
-                _ensure_strict_json_schema(entry, path=(*path, "allOf", str(i))) for i, entry in enumerate(all_of)
+                _ensure_strict_json_schema(entry, path=(*path, "allOf", str(i)), root=root)
+                for i, entry in enumerate(all_of)
             ]
 
-    defs = json_schema.get("$defs")
-    if is_dict(defs):
-        for def_name, def_schema in defs.items():
-            _ensure_strict_json_schema(def_schema, path=(*path, "$defs", def_name))
+    # we can't use `$ref`s if there are also other properties defined, e.g.
+    # `{"$ref": "...", "description": "my description"}`
+    #
+    # so we unravel the ref
+    # `{"type": "string", "description": "my description"}`
+    ref = json_schema.get("$ref")
+    if ref and has_more_than_n_keys(json_schema, 1):
+        assert isinstance(ref, str), f"Received non-string $ref - {ref}"
 
-    definitions = json_schema.get("definitions")
-    if is_dict(definitions):
-        for definition_name, definition_schema in definitions.items():
-            _ensure_strict_json_schema(definition_schema, path=(*path, "definitions", definition_name))
+        resolved = resolve_ref(root=root, ref=ref)
+        if not is_dict(resolved):
+            raise ValueError(f"Expected `$ref: {ref}` to resolved to a dictionary but got {resolved}")
+
+        # properties from the json schema take priority over the ones on the `$ref`
+        json_schema.update({**resolved, **json_schema})
+        json_schema.pop("$ref")
 
     return json_schema
 
 
+def resolve_ref(*, root: dict[str, object], ref: str) -> object:
+    if not ref.startswith("#/"):
+        raise ValueError(f"Unexpected $ref format {ref!r}; Does not start with #/")
+
+    path = ref[2:].split("/")
+    resolved = root
+    for key in path:
+        value = resolved[key]
+        assert is_dict(value), f"encountered non-dictionary entry while resolving {ref} - {resolved}"
+        resolved = value
+
+    return resolved
+
+
 def is_dict(obj: object) -> TypeGuard[dict[str, object]]:
     # just pretend that we know there are only `str` keys
     # as that check is not worth the performance cost
     return _is_dict(obj)
+
+
+def has_more_than_n_keys(obj: dict[str, object], n: int) -> bool:
+    i = 0
+    for _ in obj.keys():
+        i += 1
+        if i > n:
+            return True
+    return False
tests/lib/chat/test_completions.py
@@ -2,13 +2,14 @@ from __future__ import annotations
 
 import os
 import json
+from enum import Enum
 from typing import Any, Callable
 from typing_extensions import Literal, TypeVar
 
 import httpx
 import pytest
 from respx import MockRouter
-from pydantic import BaseModel
+from pydantic import Field, BaseModel
 from inline_snapshot import snapshot
 
 import openai
@@ -133,6 +134,53 @@ ParsedChatCompletion[Location](
     )
 
 
+@pytest.mark.respx(base_url=base_url)
+def test_parse_pydantic_model_enum(client: OpenAI, respx_mock: MockRouter, monkeypatch: pytest.MonkeyPatch) -> None:
+    class Color(Enum):
+        """The detected color"""
+
+        RED = "red"
+        BLUE = "blue"
+        GREEN = "green"
+
+    class ColorDetection(BaseModel):
+        color: Color
+        hex_color_code: str = Field(description="The hex color code of the detected color")
+
+    completion = _make_snapshot_request(
+        lambda c: c.beta.chat.completions.parse(
+            model="gpt-4o-2024-08-06",
+            messages=[
+                {"role": "user", "content": "What color is a Coke can?"},
+            ],
+            response_format=ColorDetection,
+        ),
+        content_snapshot=snapshot(
+            '{"id": "chatcmpl-9vK4UZVr385F2UgZlP1ShwPn2nFxG", "object": "chat.completion", "created": 1723448878, "model": "gpt-4o-2024-08-06", "choices": [{"index": 0, "message": {"role": "assistant", "content": "{\\"color\\":\\"red\\",\\"hex_color_code\\":\\"#FF0000\\"}", "refusal": null}, "logprobs": null, "finish_reason": "stop"}], "usage": {"prompt_tokens": 18, "completion_tokens": 14, "total_tokens": 32}, "system_fingerprint": "fp_845eaabc1f"}'
+        ),
+        mock_client=client,
+        respx_mock=respx_mock,
+    )
+
+    assert print_obj(completion.choices[0], monkeypatch) == snapshot(
+        """\
+ParsedChoice[ColorDetection](
+    finish_reason='stop',
+    index=0,
+    logprobs=None,
+    message=ParsedChatCompletionMessage[ColorDetection](
+        content='{"color":"red","hex_color_code":"#FF0000"}',
+        function_call=None,
+        parsed=ColorDetection(color=<Color.RED: 'red'>, hex_color_code='#FF0000'),
+        refusal=None,
+        role='assistant',
+        tool_calls=[]
+    )
+)
+"""
+    )
+
+
 @pytest.mark.respx(base_url=base_url)
 def test_parse_pydantic_model_multiple_choices(
     client: OpenAI, respx_mock: MockRouter, monkeypatch: pytest.MonkeyPatch
tests/lib/test_pydantic.py
@@ -186,7 +186,12 @@ def test_enums() -> None:
                 "parameters": {
                     "$defs": {"Color": {"enum": ["red", "blue", "green"], "title": "Color", "type": "string"}},
                     "properties": {
-                        "color": {"description": "The detected color", "$ref": "#/$defs/Color"},
+                        "color": {
+                            "description": "The detected color",
+                            "enum": ["red", "blue", "green"],
+                            "title": "Color",
+                            "type": "string",
+                        },
                         "hex_color_code": {
                             "description": "The hex color code of the detected color",
                             "title": "Hex Color Code",
@@ -207,7 +212,11 @@ def test_enums() -> None:
                 "strict": True,
                 "parameters": {
                     "properties": {
-                        "color": {"description": "The detected color", "$ref": "#/definitions/Color"},
+                        "color": {
+                            "description": "The detected color",
+                            "title": "Color",
+                            "enum": ["red", "blue", "green"],
+                        },
                         "hex_color_code": {
                             "description": "The hex color code of the detected color",
                             "title": "Hex Color Code",