main
  1# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
  2
  3from __future__ import annotations
  4
  5import os as _os
  6
  7import httpx
  8import pytest
  9from httpx import URL
 10
 11import openai
 12from openai import DEFAULT_TIMEOUT, DEFAULT_MAX_RETRIES
 13
 14
 15def reset_state() -> None:
 16    openai._reset_client()
 17    openai.api_key = None or "My API Key"
 18    openai.organization = None
 19    openai.project = None
 20    openai.webhook_secret = None
 21    openai.base_url = None
 22    openai.timeout = DEFAULT_TIMEOUT
 23    openai.max_retries = DEFAULT_MAX_RETRIES
 24    openai.default_headers = None
 25    openai.default_query = None
 26    openai.http_client = None
 27    openai.api_type = _os.environ.get("OPENAI_API_TYPE")  # type: ignore
 28    openai.api_version = None
 29    openai.azure_endpoint = None
 30    openai.azure_ad_token = None
 31    openai.azure_ad_token_provider = None
 32
 33
 34@pytest.fixture(autouse=True)
 35def reset_state_fixture() -> None:
 36    reset_state()
 37
 38
 39def test_base_url_option() -> None:
 40    assert openai.base_url is None
 41    assert openai.completions._client.base_url == URL("https://api.openai.com/v1/")
 42
 43    openai.base_url = "http://foo.com"
 44
 45    assert openai.base_url == URL("http://foo.com")
 46    assert openai.completions._client.base_url == URL("http://foo.com")
 47
 48
 49def test_timeout_option() -> None:
 50    assert openai.timeout == openai.DEFAULT_TIMEOUT
 51    assert openai.completions._client.timeout == openai.DEFAULT_TIMEOUT
 52
 53    openai.timeout = 3
 54
 55    assert openai.timeout == 3
 56    assert openai.completions._client.timeout == 3
 57
 58
 59def test_max_retries_option() -> None:
 60    assert openai.max_retries == openai.DEFAULT_MAX_RETRIES
 61    assert openai.completions._client.max_retries == openai.DEFAULT_MAX_RETRIES
 62
 63    openai.max_retries = 1
 64
 65    assert openai.max_retries == 1
 66    assert openai.completions._client.max_retries == 1
 67
 68
 69def test_default_headers_option() -> None:
 70    assert openai.default_headers == None
 71
 72    openai.default_headers = {"Foo": "Bar"}
 73
 74    assert openai.default_headers["Foo"] == "Bar"
 75    assert openai.completions._client.default_headers["Foo"] == "Bar"
 76
 77
 78def test_default_query_option() -> None:
 79    assert openai.default_query is None
 80    assert openai.completions._client._custom_query == {}
 81
 82    openai.default_query = {"Foo": {"nested": 1}}
 83
 84    assert openai.default_query["Foo"] == {"nested": 1}
 85    assert openai.completions._client._custom_query["Foo"] == {"nested": 1}
 86
 87
 88def test_http_client_option() -> None:
 89    assert openai.http_client is None
 90
 91    original_http_client = openai.completions._client._client
 92    assert original_http_client is not None
 93
 94    new_client = httpx.Client()
 95    openai.http_client = new_client
 96
 97    assert openai.completions._client._client is new_client
 98
 99
100import contextlib
101from typing import Iterator
102
103from openai.lib.azure import AzureOpenAI
104
105
@contextlib.contextmanager
def fresh_env() -> Iterator[None]:
    """Run the ``with`` body against an empty ``os.environ``.

    A snapshot of the environment is taken on entry. On exit, whether normal
    or via an exception, variables set inside the body are discarded and the
    snapshot is restored.
    """
    snapshot = dict(_os.environ)

    try:
        _os.environ.clear()
        yield
    finally:
        # Wipe whatever the body added before putting the original values back.
        _os.environ.clear()
        _os.environ.update(snapshot)
116
117
def test_only_api_key_results_in_openai_api() -> None:
    """With an empty environment and only an API key set, the plain module client is built."""
    with fresh_env():
        openai.api_type = None
        openai.api_key = "example API key"

        built_client = openai.completions._client
        assert type(built_client).__name__ == "_ModuleClient"
124
125
def test_azure_api_key_env_without_api_version() -> None:
    """An Azure API key in the environment without an API version makes client creation fail."""
    missing_version_msg = (
        r"Must provide either the `api_version` argument or the `OPENAI_API_VERSION` environment variable"
    )

    with fresh_env():
        openai.api_type = None
        _os.environ["AZURE_OPENAI_API_KEY"] = "example API key"

        with pytest.raises(ValueError, match=missing_version_msg):
            openai.completions._client  # noqa: B018
136
137
def test_azure_api_key_and_version_env() -> None:
    """An Azure API key and version without any endpoint makes client creation fail."""
    missing_endpoint_msg = (
        r"Must provide one of the `base_url` or `azure_endpoint` arguments, or the `AZURE_OPENAI_ENDPOINT` environment variable"
    )

    with fresh_env():
        openai.api_type = None
        _os.environ.update(
            {
                "AZURE_OPENAI_API_KEY": "example API key",
                "OPENAI_API_VERSION": "example-version",
            }
        )

        with pytest.raises(ValueError, match=missing_endpoint_msg):
            openai.completions._client  # noqa: B018
149
150
def test_azure_api_key_version_and_endpoint_env() -> None:
    """Azure key, version and endpoint in the environment cause ``api_type`` to become ``"azure"``."""
    azure_env = {
        "AZURE_OPENAI_API_KEY": "example API key",
        "OPENAI_API_VERSION": "example-version",
        "AZURE_OPENAI_ENDPOINT": "https://www.example",
    }

    with fresh_env():
        openai.api_type = None
        _os.environ.update(azure_env)

        # api_type is None until the module client is accessed; the assertion
        # below expects that access to resolve it to "azure".
        openai.completions._client  # noqa: B018

        assert openai.api_type == "azure"
161
162
def test_azure_azure_ad_token_version_and_endpoint_env() -> None:
    """An Azure AD token from the environment produces an ``AzureOpenAI`` client carrying that token."""
    ad_token = "example AD token"

    with fresh_env():
        openai.api_type = None
        _os.environ.update(
            {
                "AZURE_OPENAI_AD_TOKEN": ad_token,
                "OPENAI_API_VERSION": "example-version",
                "AZURE_OPENAI_ENDPOINT": "https://www.example",
            }
        )

        azure_client = openai.completions._client
        assert isinstance(azure_client, AzureOpenAI)
        assert azure_client._azure_ad_token == ad_token
172        assert client._azure_ad_token == "example AD token"
173
174
def test_azure_azure_ad_token_provider_version_and_endpoint_env() -> None:
    """A module-level AD token provider is handed to the resulting ``AzureOpenAI`` client."""

    def token_provider() -> str:
        return "token"

    with fresh_env():
        openai.api_type = None
        _os.environ["OPENAI_API_VERSION"] = "example-version"
        _os.environ["AZURE_OPENAI_ENDPOINT"] = "https://www.example"
        openai.azure_ad_token_provider = token_provider

        azure_client = openai.completions._client
        assert isinstance(azure_client, AzureOpenAI)

        provider = azure_client._azure_ad_token_provider
        assert provider is not None
        assert provider() == "token"
185        assert client._azure_ad_token_provider() == "token"