Commit 09350cb5

Robert Craigie <robert@craigie.dev>
2024-08-12 15:30:48
fix(json schema): unwrap `allOf`s with one entry
1 parent 1a388a1
Changed files (2)
src
openai
tests
src/openai/lib/_pydantic.py
@@ -53,9 +53,13 @@ def _ensure_strict_json_schema(
     # intersections
     all_of = json_schema.get("allOf")
     if is_list(all_of):
-        json_schema["allOf"] = [
-            _ensure_strict_json_schema(entry, path=(*path, "anyOf", str(i))) for i, entry in enumerate(all_of)
-        ]
+        if len(all_of) == 1:
+            json_schema.update(_ensure_strict_json_schema(all_of[0], path=(*path, "allOf", "0")))
+            json_schema.pop("allOf")
+        else:
+            json_schema["allOf"] = [
+                _ensure_strict_json_schema(entry, path=(*path, "allOf", str(i))) for i, entry in enumerate(all_of)
+            ]
 
     defs = json_schema.get("$defs")
     if is_dict(defs):
tests/lib/test_pydantic.py
@@ -1,5 +1,8 @@
 from __future__ import annotations
 
+from enum import Enum
+
+from pydantic import Field, BaseModel
 from inline_snapshot import snapshot
 
 import openai
@@ -161,3 +164,63 @@ def test_most_types() -> None:
                 },
             }
         )
+
+
+class Color(Enum):
+    RED = "red"
+    BLUE = "blue"
+    GREEN = "green"
+
+
+class ColorDetection(BaseModel):
+    color: Color = Field(description="The detected color")
+    hex_color_code: str = Field(description="The hex color code of the detected color")
+
+
+def test_enums() -> None:
+    if PYDANTIC_V2:
+        assert openai.pydantic_function_tool(ColorDetection)["function"] == snapshot(
+            {
+                "name": "ColorDetection",
+                "strict": True,
+                "parameters": {
+                    "$defs": {"Color": {"enum": ["red", "blue", "green"], "title": "Color", "type": "string"}},
+                    "properties": {
+                        "color": {"description": "The detected color", "$ref": "#/$defs/Color"},
+                        "hex_color_code": {
+                            "description": "The hex color code of the detected color",
+                            "title": "Hex Color Code",
+                            "type": "string",
+                        },
+                    },
+                    "required": ["color", "hex_color_code"],
+                    "title": "ColorDetection",
+                    "type": "object",
+                    "additionalProperties": False,
+                },
+            }
+        )
+    else:
+        assert openai.pydantic_function_tool(ColorDetection)["function"] == snapshot(
+            {
+                "name": "ColorDetection",
+                "strict": True,
+                "parameters": {
+                    "properties": {
+                        "color": {"description": "The detected color", "$ref": "#/definitions/Color"},
+                        "hex_color_code": {
+                            "description": "The hex color code of the detected color",
+                            "title": "Hex Color Code",
+                            "type": "string",
+                        },
+                    },
+                    "required": ["color", "hex_color_code"],
+                    "title": "ColorDetection",
+                    "definitions": {
+                        "Color": {"title": "Color", "description": "An enumeration.", "enum": ["red", "blue", "green"]}
+                    },
+                    "type": "object",
+                    "additionalProperties": False,
+                },
+            }
+        )