main
1from __future__ import annotations
2
3import sys
4import inspect
5import typing_extensions
6from typing import get_args
7
8import pytest
9
10from openai import OpenAI, AsyncOpenAI
11from tests.utils import evaluate_forwardref
12from openai._utils import assert_signatures_in_sync
13from openai._compat import is_literal_type
14from openai._utils._typing import is_union_type
15from openai.types.audio_response_format import AudioResponseFormat
16
17
@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
def test_translation_create_overloads_in_sync(sync: bool, client: OpenAI, async_client: AsyncOpenAI) -> None:
    """Check that the overloads of ``audio.translations.create`` agree with its implementation.

    Every overload's signature must match the implementation's, except for
    ``response_format`` and ``stream`` (excluded from the comparison below).
    Taken together, the ``Literal`` values of the overloads' ``response_format``
    annotations must cover every ``AudioResponseFormat`` value that applies to
    translations.
    """
    checking_client: OpenAI | AsyncOpenAI = client if sync else async_client

    fn = checking_client.audio.translations.create
    # Annotations are resolved against the globals of the module defining ``create``;
    # this lookup does not depend on the overload, so do it once.
    globalns = sys.modules[fn.__module__].__dict__
    overload_response_formats: set[str] = set()

    for i, overload in enumerate(typing_extensions.get_overloads(fn)):
        assert_signatures_in_sync(
            fn,
            overload,
            exclude_params={"response_format", "stream"},
            description=f" for overload {i}",
        )

        sig = inspect.signature(overload)
        typ = evaluate_forwardref(
            sig.parameters["response_format"].annotation,
            globalns=globalns,
        )
        if is_union_type(typ):
            # Only the Literal members of a union carry response-format values;
            # any other members are ignored.
            for arg in get_args(typ):
                if is_literal_type(arg):
                    overload_response_formats.update(get_args(arg))
        elif is_literal_type(typ):
            overload_response_formats.update(get_args(typ))

    # 'diarized_json' applies only to transcriptions, not translations.
    src_response_formats: set[str] = set(get_args(AudioResponseFormat)) - {"diarized_json"}
    missing = src_response_formats - overload_response_formats
    # Name the uncovered formats so a failure is actionable.
    assert not missing, f"some response format options don't have overloads: {sorted(missing)}"
51
52
@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
def test_transcription_create_overloads_in_sync(sync: bool, client: OpenAI, async_client: AsyncOpenAI) -> None:
    """Check that the overloads of ``audio.transcriptions.create`` agree with its implementation.

    Every overload's signature must match the implementation's, apart from
    ``response_format``, ``stream`` and the parameters that depend on whether the
    overload is the ``diarized_json``-only one. Taken together, the ``Literal``
    values of the overloads' ``response_format`` annotations must cover every
    ``AudioResponseFormat`` value.
    """
    checking_client: OpenAI | AsyncOpenAI = client if sync else async_client

    fn = checking_client.audio.transcriptions.create
    # Annotations are resolved against the globals of the module defining ``create``;
    # this lookup does not depend on the overload, so do it once.
    globalns = sys.modules[fn.__module__].__dict__
    overload_response_formats: set[str] = set()

    for i, overload in enumerate(typing_extensions.get_overloads(fn)):
        sig = inspect.signature(overload)
        typ = evaluate_forwardref(
            sig.parameters["response_format"].annotation,
            globalns=globalns,
        )

        # An overload whose response_format is exactly Literal["diarized_json"].
        is_diarized = is_literal_type(typ) and set(get_args(typ)) == {"diarized_json"}

        exclude_params = {"response_format", "stream"}
        if is_diarized:
            # diarized_json does not support these parameters
            exclude_params.update({"include", "prompt", "timestamp_granularities"})
        else:
            # known_speaker_names and known_speaker_references are only supported by diarized_json
            exclude_params.update({"known_speaker_names", "known_speaker_references"})

        assert_signatures_in_sync(
            fn,
            overload,
            exclude_params=exclude_params,
            description=f" for overload {i}",
        )

        if is_union_type(typ):
            # Only the Literal members of a union carry response-format values;
            # any other members are ignored.
            for arg in get_args(typ):
                if is_literal_type(arg):
                    overload_response_formats.update(get_args(arg))
        elif is_literal_type(typ):
            overload_response_formats.update(get_args(typ))

    src_response_formats: set[str] = set(get_args(AudioResponseFormat))
    missing = src_response_formats - overload_response_formats
    # Name the uncovered formats so a failure is actionable.
    assert not missing, f"some response format options don't have overloads: {sorted(missing)}"