main
1# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
2
3from __future__ import annotations
4
5import os as _os
6import typing as _t
7from typing_extensions import override
8
9from . import types
10from ._types import NOT_GIVEN, Omit, NoneType, NotGiven, Transport, ProxiesTypes, omit, not_given
11from ._utils import file_from_path
12from ._client import Client, OpenAI, Stream, Timeout, Transport, AsyncClient, AsyncOpenAI, AsyncStream, RequestOptions
13from ._models import BaseModel
14from ._version import __title__, __version__
15from ._response import APIResponse as APIResponse, AsyncAPIResponse as AsyncAPIResponse
16from ._constants import DEFAULT_TIMEOUT, DEFAULT_MAX_RETRIES, DEFAULT_CONNECTION_LIMITS
17from ._exceptions import (
18 APIError,
19 OpenAIError,
20 ConflictError,
21 NotFoundError,
22 APIStatusError,
23 RateLimitError,
24 APITimeoutError,
25 BadRequestError,
26 APIConnectionError,
27 AuthenticationError,
28 InternalServerError,
29 PermissionDeniedError,
30 LengthFinishReasonError,
31 UnprocessableEntityError,
32 APIResponseValidationError,
33 InvalidWebhookSignatureError,
34 ContentFilterFinishReasonError,
35)
36from ._base_client import DefaultHttpxClient, DefaultAioHttpClient, DefaultAsyncHttpxClient
37from ._utils._logs import setup_logging as _setup_logging
38from ._legacy_response import HttpxBinaryResponseContent as HttpxBinaryResponseContent
39
# Public API of the `openai` package. Besides controlling `from openai import *`,
# this list drives the loop further down that rewrites each exported object's
# `__module__` to "openai" (entries starting with "__" are skipped there).
__all__ = [
    "types",
    "__version__",
    "__title__",
    "NoneType",
    "Transport",
    "ProxiesTypes",
    "NotGiven",
    "NOT_GIVEN",
    "not_given",
    "Omit",
    "omit",
    "OpenAIError",
    "APIError",
    "APIStatusError",
    "APITimeoutError",
    "APIConnectionError",
    "APIResponseValidationError",
    "BadRequestError",
    "AuthenticationError",
    "PermissionDeniedError",
    "NotFoundError",
    "ConflictError",
    "UnprocessableEntityError",
    "RateLimitError",
    "InternalServerError",
    "LengthFinishReasonError",
    "ContentFilterFinishReasonError",
    "InvalidWebhookSignatureError",
    "Timeout",
    "RequestOptions",
    "Client",
    "AsyncClient",
    "Stream",
    "AsyncStream",
    "OpenAI",
    "AsyncOpenAI",
    "file_from_path",
    "BaseModel",
    "DEFAULT_TIMEOUT",
    "DEFAULT_MAX_RETRIES",
    "DEFAULT_CONNECTION_LIMITS",
    "DefaultHttpxClient",
    "DefaultAsyncHttpxClient",
    "DefaultAioHttpClient",
]
86
87if not _t.TYPE_CHECKING:
88 from ._utils._resources_proxy import resources as resources
89
90from .lib import azure as _azure, pydantic_function_tool as pydantic_function_tool
91from .version import VERSION as VERSION
92from .lib.azure import AzureOpenAI as AzureOpenAI, AsyncAzureOpenAI as AsyncAzureOpenAI
93from .lib._old_api import *
94from .lib.streaming import (
95 AssistantEventHandler as AssistantEventHandler,
96 AsyncAssistantEventHandler as AsyncAssistantEventHandler,
97)
98
_setup_logging()

# Re-home the exported symbols: setting `__module__` to "openai" makes error
# messages and class reprs name this package instead of the private module
# that defines them, e.g. openai._exceptions.NotFoundError -> openai.NotFoundError.
__locals = locals()
for __name in __all__:
    # Dunder entries such as "__version__" / "__title__" are left alone.
    if __name.startswith("__"):
        continue
    try:
        __locals[__name].__module__ = "openai"
    except (TypeError, AttributeError):
        # Some exports (e.g. builtins) reject attribute assignment; skip them.
        pass
113
114# ------ Module level client ------
115import typing as _t
116import typing_extensions as _te
117
118import httpx as _httpx
119
120from ._base_client import DEFAULT_TIMEOUT, DEFAULT_MAX_RETRIES
121
# Module-level client settings. `_load_client()` passes these to the client it
# builds, and `_ModuleClient`'s properties read (and their setters write) most
# of them live, so assigning e.g. `openai.api_key = "..."` is how callers
# configure the module-level client.
api_key: str | None = None

organization: str | None = None

project: str | None = None

webhook_secret: str | None = None

# When set, overrides the client's base URL (see `_ModuleClient.base_url`).
base_url: str | _httpx.URL | None = None

timeout: float | Timeout | None = DEFAULT_TIMEOUT

max_retries: int = DEFAULT_MAX_RETRIES

# Surfaced on the module client as `_custom_headers` / `_custom_query`.
default_headers: _t.Mapping[str, str] | None = None

default_query: _t.Mapping[str, object] | None = None

# When set, used as the module client's underlying httpx client
# (see `_ModuleClient._client`).
http_client: _httpx.Client | None = None

# Accepted values of `api_type`.
_ApiType = _te.Literal["openai", "azure"]

# Which backend the module client targets. When None, `_load_client()` infers it
# from the credentials present. NOTE: the environment value is only cast, not
# validated; `_load_client()` selects the Azure client only for exactly "azure",
# so any other string yields the plain OpenAI client.
api_type: _ApiType | None = _t.cast(_ApiType, _os.environ.get("OPENAI_API_TYPE"))

# Settings used only when building the Azure client. `api_version`,
# `azure_endpoint` and `azure_ad_token` are seeded from the environment at import
# time, and `_load_client()` re-reads the environment for any still None.
api_version: str | None = _os.environ.get("OPENAI_API_VERSION")

azure_endpoint: str | None = _os.environ.get("AZURE_OPENAI_ENDPOINT")

azure_ad_token: str | None = _os.environ.get("AZURE_OPENAI_AD_TOKEN")

azure_ad_token_provider: _azure.AzureADTokenProvider | None = None
153
154
class _ModuleClient(OpenAI):
    """`OpenAI` client whose settings live in this module's globals.

    Each overridden property returns the module-level variable it mirrors:
    `api_key`, `organization`, `project`, `webhook_secret`, `timeout`,
    `max_retries`, `default_headers` (as `_custom_headers`), `default_query`
    (as `_custom_query`) and `http_client` (as `_client`). Each setter writes
    the new value back to that global. Assigning e.g. `openai.api_key = ...`
    therefore changes what an already-constructed module client reports.

    `base_url` is different: a non-None module-level `base_url` only overrides
    the getter, and the setter delegates to the next `base_url` property in
    the MRO.
    """

    # Note: we have to use type: ignores here as overriding class members
    # with properties is technically unsafe but it is fine for our use case

    @property  # type: ignore
    @override
    def api_key(self) -> str | None:
        return api_key

    @api_key.setter  # type: ignore
    def api_key(self, value: str | None) -> None:  # type: ignore
        global api_key

        api_key = value

    @property  # type: ignore
    @override
    def organization(self) -> str | None:
        return organization

    @organization.setter  # type: ignore
    def organization(self, value: str | None) -> None:  # type: ignore
        global organization

        organization = value

    @property  # type: ignore
    @override
    def project(self) -> str | None:
        return project

    @project.setter  # type: ignore
    def project(self, value: str | None) -> None:  # type: ignore
        global project

        project = value

    @property  # type: ignore
    @override
    def webhook_secret(self) -> str | None:
        return webhook_secret

    @webhook_secret.setter  # type: ignore
    def webhook_secret(self, value: str | None) -> None:  # type: ignore
        global webhook_secret

        webhook_secret = value

    @property
    @override
    def base_url(self) -> _httpx.URL:
        # A module-level `openai.base_url` takes precedence over the value the
        # client was constructed with.
        if base_url is not None:
            return _httpx.URL(base_url)

        return super().base_url

    @base_url.setter
    def base_url(self, url: _httpx.URL | str) -> None:
        # `super().base_url = url` cannot work: a `super()` proxy supports
        # attribute lookup only, so assigning through it always raises
        # AttributeError. Look up the next `base_url` property in the MRO of
        # the concrete class (so Azure subclasses are respected) and call its
        # setter explicitly.
        super(_ModuleClient, type(self)).base_url.__set__(self, url)  # type: ignore

    @property  # type: ignore
    @override
    def timeout(self) -> float | Timeout | None:
        return timeout

    @timeout.setter  # type: ignore
    def timeout(self, value: float | Timeout | None) -> None:  # type: ignore
        global timeout

        timeout = value

    @property  # type: ignore
    @override
    def max_retries(self) -> int:
        return max_retries

    @max_retries.setter  # type: ignore
    def max_retries(self, value: int) -> None:  # type: ignore
        global max_retries

        max_retries = value

    @property  # type: ignore
    @override
    def _custom_headers(self) -> _t.Mapping[str, str] | None:
        # Backed by the module-level `default_headers`.
        return default_headers

    @_custom_headers.setter  # type: ignore
    def _custom_headers(self, value: _t.Mapping[str, str] | None) -> None:  # type: ignore
        global default_headers

        default_headers = value

    @property  # type: ignore
    @override
    def _custom_query(self) -> _t.Mapping[str, object] | None:
        # Backed by the module-level `default_query`.
        return default_query

    @_custom_query.setter  # type: ignore
    def _custom_query(self, value: _t.Mapping[str, object] | None) -> None:  # type: ignore
        global default_query

        default_query = value

    @property  # type: ignore
    @override
    def _client(self) -> _httpx.Client:
        # Prefer the module-level `openai.http_client`.
        # NOTE(review): if no class further up the MRO defines `_client` as a
        # class attribute or property, this `super()._client` fallback raises
        # AttributeError (instance attributes are not visible through super()).
        return http_client or super()._client

    @_client.setter  # type: ignore
    def _client(self, value: _httpx.Client) -> None:  # type: ignore
        global http_client

        http_client = value
269
270
class _AzureModuleClient(_ModuleClient, AzureOpenAI):  # type: ignore
    """Azure variant of the module client.

    `_ModuleClient` precedes `AzureOpenAI` in the bases, so its global-backed
    property overrides take precedence over the Azure client's members.
    """

    pass
273
274
class _AmbiguousModuleClientUsageError(OpenAIError):
    """Raised when the module client cannot decide between OpenAI and Azure."""

    def __init__(self) -> None:
        message = (
            "Ambiguous use of module client; please set `openai.api_type` or the "
            "`OPENAI_API_TYPE` environment variable to `openai` or `azure`"
        )
        super().__init__(message)
280
281
282def _has_openai_credentials() -> bool:
283 return _os.environ.get("OPENAI_API_KEY") is not None
284
285
def _has_azure_credentials() -> bool:
    """Report whether an Azure endpoint or Azure API key is configured.

    True if the module-level `azure_endpoint` is set, or if
    `AZURE_OPENAI_API_KEY` is present in the environment.
    """
    if azure_endpoint is not None:
        return True
    return "AZURE_OPENAI_API_KEY" in _os.environ
288
289
def _has_azure_ad_credentials() -> bool:
    """Report whether any Azure AD credential is available.

    Checks, in order: `AZURE_OPENAI_AD_TOKEN` in the environment, then the
    module-level `azure_ad_token`, then `azure_ad_token_provider`.
    """
    if "AZURE_OPENAI_AD_TOKEN" in _os.environ:
        return True
    return azure_ad_token is not None or azure_ad_token_provider is not None
295 )
296
297
# Cached module-level client: built lazily by `_load_client()` and cleared by
# `_reset_client()`.
_client: OpenAI | None = None
299
300
def _load_client() -> OpenAI:  # type: ignore[reportUnusedFunction]
    """Return the module-level client, constructing and caching it on first use.

    On the first call (or the first after `_reset_client()`):

    * `azure_endpoint`, `azure_ad_token` and `api_version` that are still None
      are re-read from `AZURE_OPENAI_ENDPOINT`, `AZURE_OPENAI_AD_TOKEN` and
      `OPENAI_API_VERSION`.
    * If `api_type` is None it is inferred and stored. It becomes "azure" when
      Azure or Azure AD credentials are present, otherwise "openai". Only the
      `OPENAI_API_KEY` environment variable counts as OpenAI credentials here.
    * An `_AzureModuleClient` is built when `api_type == "azure"`; any other
      value builds an `_ModuleClient`. The result is cached in `_client`.

    Raises:
        _AmbiguousModuleClientUsageError: if `api_type` is None and either
            OpenAI and Azure credentials are both present, or an Azure AD
            token/token provider is configured while `AZURE_OPENAI_API_KEY`
            is also set.
    """
    global _client, api_type, azure_endpoint, azure_ad_token, api_version

    if _client is not None:
        return _client

    # The environment may have been populated after import, so refresh any
    # Azure settings that are still unset.
    if azure_endpoint is None:
        azure_endpoint = _os.environ.get("AZURE_OPENAI_ENDPOINT")

    if azure_ad_token is None:
        azure_ad_token = _os.environ.get("AZURE_OPENAI_AD_TOKEN")

    if api_version is None:
        api_version = _os.environ.get("OPENAI_API_VERSION")

    if api_type is None:
        has_openai = _has_openai_credentials()
        has_azure = _has_azure_credentials()
        has_azure_ad = _has_azure_ad_credentials()

        # Both OpenAI and Azure credentials present: the target is undecidable.
        if has_openai and (has_azure or has_azure_ad):
            raise _AmbiguousModuleClientUsageError()

        # An Azure AD token (or provider) alongside an Azure API key is also ambiguous.
        if (azure_ad_token is not None or azure_ad_token_provider is not None) and _os.environ.get(
            "AZURE_OPENAI_API_KEY"
        ) is not None:
            raise _AmbiguousModuleClientUsageError()

        api_type = "azure" if (has_azure or has_azure_ad) else "openai"

    if api_type == "azure":
        _client = _AzureModuleClient(  # type: ignore
            api_version=api_version,
            azure_endpoint=azure_endpoint,
            api_key=api_key,
            azure_ad_token=azure_ad_token,
            azure_ad_token_provider=azure_ad_token_provider,
            organization=organization,
            # Forward these like the OpenAI branch below does. Otherwise the
            # client is built from the constructor defaults and, because the
            # `project` / `webhook_secret` property setters write to the module
            # globals, construction could overwrite values the user assigned to
            # `openai.project` / `openai.webhook_secret`.
            project=project,
            webhook_secret=webhook_secret,
            base_url=base_url,
            timeout=timeout,
            max_retries=max_retries,
            default_headers=default_headers,
            default_query=default_query,
            http_client=http_client,
        )
    else:
        _client = _ModuleClient(
            api_key=api_key,
            organization=organization,
            project=project,
            webhook_secret=webhook_secret,
            base_url=base_url,
            timeout=timeout,
            max_retries=max_retries,
            default_headers=default_headers,
            default_query=default_query,
            http_client=http_client,
        )

    return _client
366
367
def _reset_client() -> None:  # type: ignore[reportUnusedFunction]
    """Discard the cached module-level client.

    The next `_load_client()` call constructs a new client from the current
    module-level settings.
    """
    global _client
    _client = None
371 _client = None
372
373
374from ._module_client import (
375 beta as beta,
376 chat as chat,
377 audio as audio,
378 evals as evals,
379 files as files,
380 images as images,
381 models as models,
382 videos as videos,
383 batches as batches,
384 uploads as uploads,
385 realtime as realtime,
386 webhooks as webhooks,
387 responses as responses,
388 containers as containers,
389 embeddings as embeddings,
390 completions as completions,
391 fine_tuning as fine_tuning,
392 moderations as moderations,
393 conversations as conversations,
394 vector_stores as vector_stores,
395)