main
 1from __future__ import annotations
 2
 3import sys
 4import inspect
 5import typing_extensions
 6from typing import get_args
 7
 8import pytest
 9
10from openai import OpenAI, AsyncOpenAI
11from tests.utils import evaluate_forwardref
12from openai._utils import assert_signatures_in_sync
13from openai._compat import is_literal_type
14from openai._utils._typing import is_union_type
15from openai.types.audio_response_format import AudioResponseFormat
16
17
@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
def test_translation_create_overloads_in_sync(sync: bool, client: OpenAI, async_client: AsyncOpenAI) -> None:
    """Check that ``audio.translations.create`` overloads stay in sync with the implementation.

    Runs once against the sync client and once against the async client. For
    every registered overload it:

    * calls ``assert_signatures_in_sync`` against the implementation, ignoring
      ``response_format`` and ``stream`` (the parameters overloads vary on);
    * collects the ``Literal`` values of the overload's ``response_format``
      annotation (either a bare ``Literal`` or the ``Literal`` members of a union).

    Finally it asserts that every ``AudioResponseFormat`` value, except
    ``diarized_json``, is covered by at least one overload.
    """
    checking_client: OpenAI | AsyncOpenAI = client if sync else async_client

    fn = checking_client.audio.translations.create
    overload_response_formats: set[str] = set()

    for i, overload in enumerate(typing_extensions.get_overloads(fn)):
        assert_signatures_in_sync(
            fn,
            overload,
            exclude_params={"response_format", "stream"},
            description=f" for overload {i}",
        )

        sig = inspect.signature(overload)
        # The annotation may be a forward reference / string; resolve it against
        # the globals of the module that defines `fn`.
        typ = evaluate_forwardref(
            sig.parameters["response_format"].annotation,
            globalns=sys.modules[fn.__module__].__dict__,
        )
        if is_union_type(typ):
            # Only Literal members of the union contribute concrete format names.
            for arg in get_args(typ):
                if is_literal_type(arg):
                    overload_response_formats.update(get_args(arg))
        elif is_literal_type(typ):
            overload_response_formats.update(get_args(typ))

    # 'diarized_json' applies only to transcriptions, not translations.
    src_response_formats: set[str] = set(get_args(AudioResponseFormat)) - {"diarized_json"}
    missing = src_response_formats.difference(overload_response_formats)
    # Name the missing formats so a failure is actionable.
    assert not missing, f"some response format options don't have overloads: {sorted(missing)}"
51
52
@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
def test_transcription_create_overloads_in_sync(sync: bool, client: OpenAI, async_client: AsyncOpenAI) -> None:
    """Check that ``audio.transcriptions.create`` overloads stay in sync with the implementation.

    Runs once against the sync client and once against the async client. For
    every registered overload it:

    * resolves the overload's ``response_format`` annotation;
    * calls ``assert_signatures_in_sync`` against the implementation, ignoring
      ``response_format`` and ``stream`` plus the parameters that do not apply to
      that overload's response format (see the ``diarized_json`` branch below);
    * collects the ``Literal`` values of the ``response_format`` annotation.

    Finally it asserts that every ``AudioResponseFormat`` value is covered by at
    least one overload.
    """
    checking_client: OpenAI | AsyncOpenAI = client if sync else async_client

    fn = checking_client.audio.transcriptions.create
    overload_response_formats: set[str] = set()

    for i, overload in enumerate(typing_extensions.get_overloads(fn)):
        sig = inspect.signature(overload)
        # The annotation may be a forward reference / string; resolve it against
        # the globals of the module that defines `fn`.
        typ = evaluate_forwardref(
            sig.parameters["response_format"].annotation,
            globalns=sys.modules[fn.__module__].__dict__,
        )

        # True only for the overload whose response_format is exactly Literal["diarized_json"].
        is_diarized_only = is_literal_type(typ) and set(get_args(typ)) == {"diarized_json"}

        exclude_params = {"response_format", "stream"}
        if is_diarized_only:
            # diarized_json does not support these parameters
            exclude_params.update({"include", "prompt", "timestamp_granularities"})
        else:
            # known_speaker_names and known_speaker_references are only supported by diarized_json
            exclude_params.update({"known_speaker_names", "known_speaker_references"})

        assert_signatures_in_sync(
            fn,
            overload,
            exclude_params=exclude_params,
            description=f" for overload {i}",
        )

        if is_union_type(typ):
            # Only Literal members of the union contribute concrete format names.
            for arg in get_args(typ):
                if is_literal_type(arg):
                    overload_response_formats.update(get_args(arg))
        elif is_literal_type(typ):
            overload_response_formats.update(get_args(typ))

    src_response_formats: set[str] = set(get_args(AudioResponseFormat))
    missing = src_response_formats.difference(overload_response_formats)
    # Name the missing formats so a failure is actionable.
    assert not missing, f"some response format options don't have overloads: {sorted(missing)}"
93    assert len(diff) == 0, f"some response format options don't have overloads"