main
1from __future__ import annotations
2
3from enum import Enum
4
5from pydantic import Field, BaseModel
6from inline_snapshot import snapshot
7
8import openai
9from openai._compat import PYDANTIC_V1
10from openai.lib._pydantic import to_strict_json_schema
11
12from .schema_types.query import Query
13
14
def test_most_types() -> None:
    """Snapshot the strict function-tool schema generated for the ``Query`` model.

    The expected output differs between pydantic major versions (``definitions``
    vs ``$defs``, enum ``description`` entries, how the optional ``name`` field is
    rendered), so one snapshot is kept per version and ``PYDANTIC_V1`` picks it.
    """
    function = openai.pydantic_function_tool(Query)["function"]

    if PYDANTIC_V1:
        # pydantic v1: references point into "definitions" and every enum carries
        # the "An enumeration." description.
        assert function == snapshot(
            {
                "name": "Query",
                "strict": True,
                "parameters": {
                    "title": "Query",
                    "type": "object",
                    "properties": {
                        "name": {"title": "Name", "type": "string"},
                        "table_name": {"$ref": "#/definitions/Table"},
                        "columns": {"type": "array", "items": {"$ref": "#/definitions/Column"}},
                        "conditions": {
                            "title": "Conditions",
                            "type": "array",
                            "items": {"$ref": "#/definitions/Condition"},
                        },
                        "order_by": {"$ref": "#/definitions/OrderBy"},
                    },
                    "required": ["name", "table_name", "columns", "conditions", "order_by"],
                    "definitions": {
                        "Table": {
                            "title": "Table",
                            "description": "An enumeration.",
                            "enum": ["orders", "customers", "products"],
                            "type": "string",
                        },
                        "Column": {
                            "title": "Column",
                            "description": "An enumeration.",
                            "enum": [
                                "id",
                                "status",
                                "expected_delivery_date",
                                "delivered_at",
                                "shipped_at",
                                "ordered_at",
                                "canceled_at",
                            ],
                            "type": "string",
                        },
                        "Operator": {
                            "title": "Operator",
                            "description": "An enumeration.",
                            "enum": ["=", ">", "<", "<=", ">=", "!="],
                            "type": "string",
                        },
                        "DynamicValue": {
                            "title": "DynamicValue",
                            "type": "object",
                            "properties": {"column_name": {"title": "Column Name", "type": "string"}},
                            "required": ["column_name"],
                            "additionalProperties": False,
                        },
                        "Condition": {
                            "title": "Condition",
                            "type": "object",
                            "properties": {
                                "column": {"title": "Column", "type": "string"},
                                "operator": {"$ref": "#/definitions/Operator"},
                                "value": {
                                    "title": "Value",
                                    "anyOf": [
                                        {"type": "string"},
                                        {"type": "integer"},
                                        {"$ref": "#/definitions/DynamicValue"},
                                    ],
                                },
                            },
                            "required": ["column", "operator", "value"],
                            "additionalProperties": False,
                        },
                        "OrderBy": {
                            "title": "OrderBy",
                            "description": "An enumeration.",
                            "enum": ["asc", "desc"],
                            "type": "string",
                        },
                    },
                    "additionalProperties": False,
                },
            }
        )
    else:
        # pydantic v2: references point into "$defs"; the optional `name` field is
        # rendered as a string-or-null union but is still listed as required.
        assert function == snapshot(
            {
                "name": "Query",
                "strict": True,
                "parameters": {
                    "$defs": {
                        "Column": {
                            "enum": [
                                "id",
                                "status",
                                "expected_delivery_date",
                                "delivered_at",
                                "shipped_at",
                                "ordered_at",
                                "canceled_at",
                            ],
                            "title": "Column",
                            "type": "string",
                        },
                        "Condition": {
                            "properties": {
                                "column": {"title": "Column", "type": "string"},
                                "operator": {"$ref": "#/$defs/Operator"},
                                "value": {
                                    "anyOf": [
                                        {"type": "string"},
                                        {"type": "integer"},
                                        {"$ref": "#/$defs/DynamicValue"},
                                    ],
                                    "title": "Value",
                                },
                            },
                            "required": ["column", "operator", "value"],
                            "title": "Condition",
                            "type": "object",
                            "additionalProperties": False,
                        },
                        "DynamicValue": {
                            "properties": {"column_name": {"title": "Column Name", "type": "string"}},
                            "required": ["column_name"],
                            "title": "DynamicValue",
                            "type": "object",
                            "additionalProperties": False,
                        },
                        "Operator": {"enum": ["=", ">", "<", "<=", ">=", "!="], "title": "Operator", "type": "string"},
                        "OrderBy": {"enum": ["asc", "desc"], "title": "OrderBy", "type": "string"},
                        "Table": {"enum": ["orders", "customers", "products"], "title": "Table", "type": "string"},
                    },
                    "properties": {
                        "name": {"anyOf": [{"type": "string"}, {"type": "null"}], "title": "Name"},
                        "table_name": {"$ref": "#/$defs/Table"},
                        "columns": {
                            "items": {"$ref": "#/$defs/Column"},
                            "title": "Columns",
                            "type": "array",
                        },
                        "conditions": {
                            "items": {"$ref": "#/$defs/Condition"},
                            "title": "Conditions",
                            "type": "array",
                        },
                        "order_by": {"$ref": "#/$defs/OrderBy"},
                    },
                    "required": ["name", "table_name", "columns", "conditions", "order_by"],
                    "title": "Query",
                    "type": "object",
                    "additionalProperties": False,
                },
            }
        )
170
171
# Enum used as the type of `ColorDetection.color` in `test_enums`.
# The member order here is the order of the "enum" lists in that test's
# snapshots (["red", "blue", "green"]), so reordering members changes them.
# NOTE(review): intentionally no class docstring — the v1 snapshot shows the
# default "An enumeration." description, which suggests pydantic derives the
# schema "description" from the enum's docstring; adding one would presumably
# break the snapshots. Confirm before documenting this class with a docstring.
class Color(Enum):
    RED = "red"
    BLUE = "blue"
    GREEN = "green"
176
177
# Model exercised by `test_enums`: an enum-typed field plus a plain string field,
# each with a field-level description that appears in the expected schemas.
# The snapshots' "required" list (["color", "hex_color_code"]) follows the field
# declaration order, so reordering fields would presumably change it.
# NOTE(review): no class docstring on purpose — it would likely surface as a
# schema "description" and break the snapshots; confirm before adding one.
class ColorDetection(BaseModel):
    color: Color = Field(description="The detected color")
    hex_color_code: str = Field(description="The hex color code of the detected color")
181
182
def test_enums() -> None:
    """Snapshot the strict function-tool schema generated for ``ColorDetection``.

    ``ColorDetection.color`` is an enum field with a field-level description. The
    snapshots show it expanded in place under ``properties`` (enum values plus the
    description) while ``Color`` is also kept in the definitions section.
    """
    function = openai.pydantic_function_tool(ColorDetection)["function"]

    if PYDANTIC_V1:
        # pydantic v1: the inlined `color` property carries no "type" key and the
        # `Color` definition has the "An enumeration." description.
        assert function == snapshot(
            {
                "name": "ColorDetection",
                "strict": True,
                "parameters": {
                    "properties": {
                        "color": {
                            "description": "The detected color",
                            "title": "Color",
                            "enum": ["red", "blue", "green"],
                        },
                        "hex_color_code": {
                            "description": "The hex color code of the detected color",
                            "title": "Hex Color Code",
                            "type": "string",
                        },
                    },
                    "required": ["color", "hex_color_code"],
                    "title": "ColorDetection",
                    "definitions": {
                        "Color": {"title": "Color", "description": "An enumeration.", "enum": ["red", "blue", "green"]}
                    },
                    "type": "object",
                    "additionalProperties": False,
                },
            }
        )
    else:
        # pydantic v2: the inlined `color` property is typed as "string".
        assert function == snapshot(
            {
                "name": "ColorDetection",
                "strict": True,
                "parameters": {
                    "$defs": {"Color": {"enum": ["red", "blue", "green"], "title": "Color", "type": "string"}},
                    "properties": {
                        "color": {
                            "description": "The detected color",
                            "enum": ["red", "blue", "green"],
                            "title": "Color",
                            "type": "string",
                        },
                        "hex_color_code": {
                            "description": "The hex color code of the detected color",
                            "title": "Hex Color Code",
                            "type": "string",
                        },
                    },
                    "required": ["color", "hex_color_code"],
                    "title": "ColorDetection",
                    "type": "object",
                    "additionalProperties": False,
                },
            }
        )
239
240
# Innermost model of the Universe -> Galaxy -> Star nesting used by
# `test_nested_inline_ref_expansion`.
# NOTE(review): no class docstring on purpose — it would likely become a schema
# "description" and break that test's snapshots; confirm before adding one.
class Star(BaseModel):
    name: str = Field(description="The name of the star.")
243
244
# Middle level of the nesting: holds a `Star` through a described field, which
# the expected schemas show expanded in place.
# Field order matches the snapshots' "required" list (["name", "largest_star"]).
# NOTE(review): no class docstring on purpose — it would likely become a schema
# "description" and break the snapshots; confirm before adding one.
class Galaxy(BaseModel):
    name: str = Field(description="The name of the galaxy.")
    largest_star: Star = Field(description="The largest star in the galaxy.")
248
249
# Root model passed to `to_strict_json_schema` in `test_nested_inline_ref_expansion`.
# Field order matches the snapshots' "required" list (["name", "galaxy"]).
# NOTE(review): no class docstring on purpose — it would likely become a schema
# "description" and break the snapshots; confirm before adding one.
class Universe(BaseModel):
    name: str = Field(description="The name of the universe.")
    galaxy: Galaxy = Field(description="A galaxy in the universe.")
253
254
def test_nested_inline_ref_expansion() -> None:
    """Check that nested model references are expanded in place at every level.

    ``Universe.galaxy`` and ``Galaxy.largest_star`` both carry descriptions. In the
    expected schemas they appear fully expanded under ``properties``, both at the
    top level and inside the ``Galaxy`` definition, rather than as bare ``$ref``s.
    """
    schema = to_strict_json_schema(Universe)

    if PYDANTIC_V1:
        # pydantic v1: definitions live under "definitions" and the expanded
        # `largest_star` keeps the field-derived title "Largest Star".
        assert schema == snapshot(
            {
                "title": "Universe",
                "type": "object",
                "definitions": {
                    "Star": {
                        "title": "Star",
                        "type": "object",
                        "properties": {
                            "name": {"title": "Name", "description": "The name of the star.", "type": "string"}
                        },
                        "required": ["name"],
                        "additionalProperties": False,
                    },
                    "Galaxy": {
                        "title": "Galaxy",
                        "type": "object",
                        "properties": {
                            "name": {"title": "Name", "description": "The name of the galaxy.", "type": "string"},
                            "largest_star": {
                                "title": "Largest Star",
                                "description": "The largest star in the galaxy.",
                                "type": "object",
                                "properties": {
                                    "name": {"title": "Name", "description": "The name of the star.", "type": "string"}
                                },
                                "required": ["name"],
                                "additionalProperties": False,
                            },
                        },
                        "required": ["name", "largest_star"],
                        "additionalProperties": False,
                    },
                },
                "properties": {
                    "name": {
                        "title": "Name",
                        "description": "The name of the universe.",
                        "type": "string",
                    },
                    "galaxy": {
                        "title": "Galaxy",
                        "description": "A galaxy in the universe.",
                        "type": "object",
                        "properties": {
                            "name": {
                                "title": "Name",
                                "description": "The name of the galaxy.",
                                "type": "string",
                            },
                            "largest_star": {
                                "title": "Largest Star",
                                "description": "The largest star in the galaxy.",
                                "type": "object",
                                "properties": {
                                    "name": {"title": "Name", "description": "The name of the star.", "type": "string"}
                                },
                                "required": ["name"],
                                "additionalProperties": False,
                            },
                        },
                        "required": ["name", "largest_star"],
                        "additionalProperties": False,
                    },
                },
                "required": ["name", "galaxy"],
                "additionalProperties": False,
            }
        )
    else:
        # pydantic v2: definitions live under "$defs" and the expanded
        # `largest_star` carries the referenced model's title "Star".
        assert schema == snapshot(
            {
                "title": "Universe",
                "type": "object",
                "$defs": {
                    "Star": {
                        "title": "Star",
                        "type": "object",
                        "properties": {
                            "name": {
                                "type": "string",
                                "title": "Name",
                                "description": "The name of the star.",
                            }
                        },
                        "required": ["name"],
                        "additionalProperties": False,
                    },
                    "Galaxy": {
                        "title": "Galaxy",
                        "type": "object",
                        "properties": {
                            "name": {
                                "type": "string",
                                "title": "Name",
                                "description": "The name of the galaxy.",
                            },
                            "largest_star": {
                                "title": "Star",
                                "type": "object",
                                "properties": {
                                    "name": {
                                        "type": "string",
                                        "title": "Name",
                                        "description": "The name of the star.",
                                    }
                                },
                                "required": ["name"],
                                "description": "The largest star in the galaxy.",
                                "additionalProperties": False,
                            },
                        },
                        "required": ["name", "largest_star"],
                        "additionalProperties": False,
                    },
                },
                "properties": {
                    "name": {
                        "type": "string",
                        "title": "Name",
                        "description": "The name of the universe.",
                    },
                    "galaxy": {
                        "title": "Galaxy",
                        "type": "object",
                        "properties": {
                            "name": {
                                "type": "string",
                                "title": "Name",
                                "description": "The name of the galaxy.",
                            },
                            "largest_star": {
                                "title": "Star",
                                "type": "object",
                                "properties": {
                                    "name": {
                                        "type": "string",
                                        "title": "Name",
                                        "description": "The name of the star.",
                                    }
                                },
                                "required": ["name"],
                                "description": "The largest star in the galaxy.",
                                "additionalProperties": False,
                            },
                        },
                        "required": ["name", "largest_star"],
                        "description": "A galaxy in the universe.",
                        "additionalProperties": False,
                    },
                },
                "required": ["name", "galaxy"],
                "additionalProperties": False,
            }
        )