main
  1# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
  2
  3from __future__ import annotations
  4
  5import os as _os
  6import typing as _t
  7from typing_extensions import override
  8
  9from . import types
 10from ._types import NOT_GIVEN, Omit, NoneType, NotGiven, Transport, ProxiesTypes, omit, not_given
 11from ._utils import file_from_path
 12from ._client import Client, OpenAI, Stream, Timeout, Transport, AsyncClient, AsyncOpenAI, AsyncStream, RequestOptions
 13from ._models import BaseModel
 14from ._version import __title__, __version__
 15from ._response import APIResponse as APIResponse, AsyncAPIResponse as AsyncAPIResponse
 16from ._constants import DEFAULT_TIMEOUT, DEFAULT_MAX_RETRIES, DEFAULT_CONNECTION_LIMITS
 17from ._exceptions import (
 18    APIError,
 19    OpenAIError,
 20    ConflictError,
 21    NotFoundError,
 22    APIStatusError,
 23    RateLimitError,
 24    APITimeoutError,
 25    BadRequestError,
 26    APIConnectionError,
 27    AuthenticationError,
 28    InternalServerError,
 29    PermissionDeniedError,
 30    LengthFinishReasonError,
 31    UnprocessableEntityError,
 32    APIResponseValidationError,
 33    InvalidWebhookSignatureError,
 34    ContentFilterFinishReasonError,
 35)
 36from ._base_client import DefaultHttpxClient, DefaultAioHttpClient, DefaultAsyncHttpxClient
 37from ._utils._logs import setup_logging as _setup_logging
 38from ._legacy_response import HttpxBinaryResponseContent as HttpxBinaryResponseContent
 39
# Public names exported by `from openai import *`. The loop later in this
# module also walks this list to rewrite each exported object's `__module__`
# attribute, so an entry here must be bound in the module namespace by then.
__all__ = [
    "types",
    "__version__",
    "__title__",
    # Names imported from `._types`.
    "NoneType",
    "Transport",
    "ProxiesTypes",
    "NotGiven",
    "NOT_GIVEN",
    "not_given",
    "Omit",
    "omit",
    # Exception classes imported from `._exceptions`.
    "OpenAIError",
    "APIError",
    "APIStatusError",
    "APITimeoutError",
    "APIConnectionError",
    "APIResponseValidationError",
    "BadRequestError",
    "AuthenticationError",
    "PermissionDeniedError",
    "NotFoundError",
    "ConflictError",
    "UnprocessableEntityError",
    "RateLimitError",
    "InternalServerError",
    "LengthFinishReasonError",
    "ContentFilterFinishReasonError",
    "InvalidWebhookSignatureError",
    # Names imported from `._client`.
    "Timeout",
    "RequestOptions",
    "Client",
    "AsyncClient",
    "Stream",
    "AsyncStream",
    "OpenAI",
    "AsyncOpenAI",
    # Helpers, base model and defaults (`._utils`, `._models`, `._constants`).
    "file_from_path",
    "BaseModel",
    "DEFAULT_TIMEOUT",
    "DEFAULT_MAX_RETRIES",
    "DEFAULT_CONNECTION_LIMITS",
    # HTTP client classes imported from `._base_client`.
    "DefaultHttpxClient",
    "DefaultAsyncHttpxClient",
    "DefaultAioHttpClient",
]
 86
 87if not _t.TYPE_CHECKING:
 88    from ._utils._resources_proxy import resources as resources
 89
 90from .lib import azure as _azure, pydantic_function_tool as pydantic_function_tool
 91from .version import VERSION as VERSION
 92from .lib.azure import AzureOpenAI as AzureOpenAI, AsyncAzureOpenAI as AsyncAzureOpenAI
 93from .lib._old_api import *
 94from .lib.streaming import (
 95    AssistantEventHandler as AssistantEventHandler,
 96    AsyncAssistantEventHandler as AsyncAssistantEventHandler,
 97)
 98
 99_setup_logging()
100
101# Update the __module__ attribute for exported symbols so that
102# error messages point to this module instead of the module
103# it was originally defined in, e.g.
104# openai._exceptions.NotFoundError -> openai.NotFoundError
105__locals = locals()
106for __name in __all__:
107    if not __name.startswith("__"):
108        try:
109            __locals[__name].__module__ = "openai"
110        except (TypeError, AttributeError):
111            # Some of our exported symbols are builtins which we can't set attributes for.
112            pass
113
114# ------ Module level client ------
115import typing as _t
116import typing_extensions as _te
117
118import httpx as _httpx
119
120from ._base_client import DEFAULT_TIMEOUT, DEFAULT_MAX_RETRIES
121
# Module-level settings for the lazily created module client. `_load_client`
# passes these values to the client constructor, and `_ModuleClient`'s
# properties read and write these same globals.
api_key: str | None = None

organization: str | None = None

project: str | None = None

webhook_secret: str | None = None

# When set, `_ModuleClient.base_url` returns this value wrapped in `httpx.URL`.
base_url: str | _httpx.URL | None = None

timeout: float | Timeout | None = DEFAULT_TIMEOUT

max_retries: int = DEFAULT_MAX_RETRIES

# Exposed on the client as `_custom_headers`.
default_headers: _t.Mapping[str, str] | None = None

# Exposed on the client as `_custom_query`.
default_query: _t.Mapping[str, object] | None = None

# When truthy, `_ModuleClient._client` returns this instead of the parent's.
http_client: _httpx.Client | None = None

_ApiType = _te.Literal["openai", "azure"]

# Read from OPENAI_API_TYPE at import time. `cast` only informs the type
# checker, so the environment value is not validated here. When this is
# still None, `_load_client` infers the type from the available credentials.
api_type: _ApiType | None = _t.cast(_ApiType, _os.environ.get("OPENAI_API_TYPE"))

# The next three are read from the environment at import time; `_load_client`
# re-reads the same variables if the value is still None when it runs.
api_version: str | None = _os.environ.get("OPENAI_API_VERSION")

azure_endpoint: str | None = _os.environ.get("AZURE_OPENAI_ENDPOINT")

azure_ad_token: str | None = _os.environ.get("AZURE_OPENAI_AD_TOKEN")

# Passed to the Azure client by `_load_client`; has no environment default.
azure_ad_token_provider: _azure.AzureADTokenProvider | None = None
153
154
class _ModuleClient(OpenAI):
    """`OpenAI` client whose settings are backed by this module's globals.

    Each overridden property reads the corresponding module-level variable
    (`api_key`, `organization`, `timeout`, ...) on every access, and its
    setter writes back to that variable. A change to the module-level value
    is therefore seen by an already-constructed instance.
    """

    # Note: we have to use type: ignores here as overriding class members
    # with properties is technically unsafe but it is fine for our use case

    @property  # type: ignore
    @override
    def api_key(self) -> str | None:
        return api_key

    @api_key.setter  # type: ignore
    def api_key(self, value: str | None) -> None:  # type: ignore
        global api_key

        api_key = value

    @property  # type: ignore
    @override
    def organization(self) -> str | None:
        return organization

    @organization.setter  # type: ignore
    def organization(self, value: str | None) -> None:  # type: ignore
        global organization

        organization = value

    @property  # type: ignore
    @override
    def project(self) -> str | None:
        return project

    @project.setter  # type: ignore
    def project(self, value: str | None) -> None:  # type: ignore
        global project

        project = value

    @property  # type: ignore
    @override
    def webhook_secret(self) -> str | None:
        return webhook_secret

    @webhook_secret.setter  # type: ignore
    def webhook_secret(self, value: str | None) -> None:  # type: ignore
        global webhook_secret

        webhook_secret = value

    @property
    @override
    def base_url(self) -> _httpx.URL:
        """Return the module-level `base_url` if set, else the parent's value."""
        if base_url is not None:
            return _httpx.URL(base_url)

        return super().base_url

    @base_url.setter
    def base_url(self, url: _httpx.URL | str) -> None:
        # `super().base_url = url` can never work: a `super()` proxy only
        # forwards attribute *lookup*, so assigning through it raises
        # AttributeError. Resolve the inherited property on the class and
        # call its setter explicitly instead.
        # NOTE(review): assumes the parent defines `base_url` as a property
        # with a setter, as the original delegation implies - confirm.
        super(_ModuleClient, type(self)).base_url.fset(self, url)  # type: ignore

    @property  # type: ignore
    @override
    def timeout(self) -> float | Timeout | None:
        return timeout

    @timeout.setter  # type: ignore
    def timeout(self, value: float | Timeout | None) -> None:  # type: ignore
        global timeout

        timeout = value

    @property  # type: ignore
    @override
    def max_retries(self) -> int:
        return max_retries

    @max_retries.setter  # type: ignore
    def max_retries(self, value: int) -> None:  # type: ignore
        global max_retries

        max_retries = value

    @property  # type: ignore
    @override
    def _custom_headers(self) -> _t.Mapping[str, str] | None:
        # Backed by the module-level `default_headers`.
        return default_headers

    @_custom_headers.setter  # type: ignore
    def _custom_headers(self, value: _t.Mapping[str, str] | None) -> None:  # type: ignore
        global default_headers

        default_headers = value

    @property  # type: ignore
    @override
    def _custom_query(self) -> _t.Mapping[str, object] | None:
        # Backed by the module-level `default_query`.
        return default_query

    @_custom_query.setter  # type: ignore
    def _custom_query(self, value: _t.Mapping[str, object] | None) -> None:  # type: ignore
        global default_query

        default_query = value

    @property  # type: ignore
    @override
    def _client(self) -> _httpx.Client:
        # Prefer the module-level `http_client`; fall back to the parent's.
        return http_client or super()._client

    @_client.setter  # type: ignore
    def _client(self, value: _httpx.Client) -> None:  # type: ignore
        global http_client

        http_client = value
270
class _AzureModuleClient(_ModuleClient, AzureOpenAI):  # type: ignore
    """Module client for Azure: `_ModuleClient`'s properties on top of `AzureOpenAI`."""
273
274
class _AmbiguousModuleClientUsageError(OpenAIError):
    """Raised when the module client cannot decide between OpenAI and Azure."""

    def __init__(self) -> None:
        message = "Ambiguous use of module client; please set `openai.api_type` or the `OPENAI_API_TYPE` environment variable to `openai` or `azure`"
        super().__init__(message)
280
281
def _has_openai_credentials() -> bool:
    """Return True when the `OPENAI_API_KEY` environment variable is set."""
    return "OPENAI_API_KEY" in _os.environ
284
285
def _has_azure_credentials() -> bool:
    """Return True when an Azure endpoint or `AZURE_OPENAI_API_KEY` is available."""
    if azure_endpoint is not None:
        return True

    return "AZURE_OPENAI_API_KEY" in _os.environ
288
289
def _has_azure_ad_credentials() -> bool:
    """Return True when an Azure AD token or token provider is available.

    The token may come from the module-level `azure_ad_token`, the
    module-level `azure_ad_token_provider`, or the `AZURE_OPENAI_AD_TOKEN`
    environment variable.
    """
    if azure_ad_token is not None or azure_ad_token_provider is not None:
        return True

    return "AZURE_OPENAI_AD_TOKEN" in _os.environ
296
297
# Cached module client; created on first use by `_load_client`.
_client: OpenAI | None = None


def _load_client() -> OpenAI:  # type: ignore[reportUnusedFunction]
    """Return the cached module client, building it on the first call.

    On the first call this fills any unset Azure settings from the
    environment, infers `api_type` from the available credentials when it
    is not set, and then constructs either the Azure or the plain OpenAI
    module client from the module-level settings.

    Raises:
        _AmbiguousModuleClientUsageError: if `api_type` is unset and the
            available credentials do not identify a single backend.
    """
    global _client, api_type, azure_endpoint, azure_ad_token, api_version

    if _client is not None:
        return _client

    # Fall back to the environment for Azure settings that are still unset.
    azure_endpoint = _os.environ.get("AZURE_OPENAI_ENDPOINT") if azure_endpoint is None else azure_endpoint
    azure_ad_token = _os.environ.get("AZURE_OPENAI_AD_TOKEN") if azure_ad_token is None else azure_ad_token
    api_version = _os.environ.get("OPENAI_API_VERSION") if api_version is None else api_version

    if api_type is None:
        openai_creds = _has_openai_credentials()
        azure_creds = _has_azure_credentials()
        azure_ad_creds = _has_azure_ad_credentials()
        any_azure_creds = azure_creds or azure_ad_creds

        # Credentials for both backends: refuse to guess.
        if openai_creds and any_azure_creds:
            raise _AmbiguousModuleClientUsageError()

        # An AD token (or provider) together with an Azure API key is
        # likewise treated as ambiguous.
        ad_token_configured = azure_ad_token is not None or azure_ad_token_provider is not None
        if ad_token_configured and _os.environ.get("AZURE_OPENAI_API_KEY") is not None:
            raise _AmbiguousModuleClientUsageError()

        api_type = "azure" if any_azure_creds else "openai"

    if api_type == "azure":
        _client = _AzureModuleClient(  # type: ignore
            api_version=api_version,
            azure_endpoint=azure_endpoint,
            api_key=api_key,
            azure_ad_token=azure_ad_token,
            azure_ad_token_provider=azure_ad_token_provider,
            organization=organization,
            base_url=base_url,
            timeout=timeout,
            max_retries=max_retries,
            default_headers=default_headers,
            default_query=default_query,
            http_client=http_client,
        )
    else:
        _client = _ModuleClient(
            api_key=api_key,
            organization=organization,
            project=project,
            webhook_secret=webhook_secret,
            base_url=base_url,
            timeout=timeout,
            max_retries=max_retries,
            default_headers=default_headers,
            default_query=default_query,
            http_client=http_client,
        )

    return _client
366
367
def _reset_client() -> None:  # type: ignore[reportUnusedFunction]
    """Drop the cached module client so the next `_load_client` call rebuilds it."""
    global _client
    _client = None
372
373
374from ._module_client import (
375    beta as beta,
376    chat as chat,
377    audio as audio,
378    evals as evals,
379    files as files,
380    images as images,
381    models as models,
382    videos as videos,
383    batches as batches,
384    uploads as uploads,
385    realtime as realtime,
386    webhooks as webhooks,
387    responses as responses,
388    containers as containers,
389    embeddings as embeddings,
390    completions as completions,
391    fine_tuning as fine_tuning,
392    moderations as moderations,
393    conversations as conversations,
394    vector_stores as vector_stores,
395)