Commit 09350cb5
Changed files (2)
src
openai
lib
tests
src/openai/lib/_pydantic.py
@@ -53,9 +53,13 @@ def _ensure_strict_json_schema(
# intersections
all_of = json_schema.get("allOf")
if is_list(all_of):
- json_schema["allOf"] = [
- _ensure_strict_json_schema(entry, path=(*path, "anyOf", str(i))) for i, entry in enumerate(all_of)
- ]
+ if len(all_of) == 1:
+ json_schema.update(_ensure_strict_json_schema(all_of[0], path=(*path, "allOf", "0")))
+ json_schema.pop("allOf")
+ else:
+ json_schema["allOf"] = [
+ _ensure_strict_json_schema(entry, path=(*path, "allOf", str(i))) for i, entry in enumerate(all_of)
+ ]
defs = json_schema.get("$defs")
if is_dict(defs):
tests/lib/test_pydantic.py
@@ -1,5 +1,8 @@
from __future__ import annotations
+from enum import Enum
+
+from pydantic import Field, BaseModel
from inline_snapshot import snapshot
import openai
@@ -161,3 +164,63 @@ def test_most_types() -> None:
},
}
)
+
+
+class Color(Enum):
+ RED = "red"
+ BLUE = "blue"
+ GREEN = "green"
+
+
+class ColorDetection(BaseModel):
+ color: Color = Field(description="The detected color")
+ hex_color_code: str = Field(description="The hex color code of the detected color")
+
+
+def test_enums() -> None:
+ if PYDANTIC_V2:
+ assert openai.pydantic_function_tool(ColorDetection)["function"] == snapshot(
+ {
+ "name": "ColorDetection",
+ "strict": True,
+ "parameters": {
+ "$defs": {"Color": {"enum": ["red", "blue", "green"], "title": "Color", "type": "string"}},
+ "properties": {
+ "color": {"description": "The detected color", "$ref": "#/$defs/Color"},
+ "hex_color_code": {
+ "description": "The hex color code of the detected color",
+ "title": "Hex Color Code",
+ "type": "string",
+ },
+ },
+ "required": ["color", "hex_color_code"],
+ "title": "ColorDetection",
+ "type": "object",
+ "additionalProperties": False,
+ },
+ }
+ )
+ else:
+ assert openai.pydantic_function_tool(ColorDetection)["function"] == snapshot(
+ {
+ "name": "ColorDetection",
+ "strict": True,
+ "parameters": {
+ "properties": {
+ "color": {"description": "The detected color", "$ref": "#/definitions/Color"},
+ "hex_color_code": {
+ "description": "The hex color code of the detected color",
+ "title": "Hex Color Code",
+ "type": "string",
+ },
+ },
+ "required": ["color", "hex_color_code"],
+ "title": "ColorDetection",
+ "definitions": {
+ "Color": {"title": "Color", "description": "An enumeration.", "enum": ["red", "blue", "green"]}
+ },
+ "type": "object",
+ "additionalProperties": False,
+ },
+ }
+ )