Commit 97a6895b
Changed files (3)
src
openai
tests
src/openai/lib/azure.py
@@ -3,7 +3,7 @@ from __future__ import annotations
import os
import inspect
from typing import Any, Union, Mapping, TypeVar, Callable, Awaitable, overload
-from typing_extensions import override
+from typing_extensions import Self, override
import httpx
@@ -178,7 +178,7 @@ class AzureOpenAI(BaseAzureClient[httpx.Client, Stream[Any]], OpenAI):
if default_query is None:
default_query = {"api-version": api_version}
else:
- default_query = {"api-version": api_version, **default_query}
+ default_query = {**default_query, "api-version": api_version}
if base_url is None:
if azure_endpoint is None:
@@ -212,9 +212,53 @@ class AzureOpenAI(BaseAzureClient[httpx.Client, Stream[Any]], OpenAI):
http_client=http_client,
_strict_response_validation=_strict_response_validation,
)
+ self._api_version = api_version
self._azure_ad_token = azure_ad_token
self._azure_ad_token_provider = azure_ad_token_provider
+ @override
+ def copy(
+ self,
+ *,
+ api_key: str | None = None,
+ organization: str | None = None,
+ api_version: str | None = None,
+ azure_ad_token: str | None = None,
+ azure_ad_token_provider: AzureADTokenProvider | None = None,
+ base_url: str | httpx.URL | None = None,
+ timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
+ http_client: httpx.Client | None = None,
+ max_retries: int | NotGiven = NOT_GIVEN,
+ default_headers: Mapping[str, str] | None = None,
+ set_default_headers: Mapping[str, str] | None = None,
+ default_query: Mapping[str, object] | None = None,
+ set_default_query: Mapping[str, object] | None = None,
+ _extra_kwargs: Mapping[str, Any] = {},
+ ) -> Self:
+ """
+ Create a new client instance re-using the same options given to the current client with optional overriding.
+ """
+ return super().copy(
+ api_key=api_key,
+ organization=organization,
+ base_url=base_url,
+ timeout=timeout,
+ http_client=http_client,
+ max_retries=max_retries,
+ default_headers=default_headers,
+ set_default_headers=set_default_headers,
+ default_query=default_query,
+ set_default_query=set_default_query,
+ _extra_kwargs={
+ "api_version": api_version or self._api_version,
+ "azure_ad_token": azure_ad_token or self._azure_ad_token,
+ "azure_ad_token_provider": azure_ad_token_provider or self._azure_ad_token_provider,
+ **_extra_kwargs,
+ },
+ )
+
+ with_options = copy
+
def _get_azure_ad_token(self) -> str | None:
if self._azure_ad_token is not None:
return self._azure_ad_token
@@ -367,7 +411,7 @@ class AsyncAzureOpenAI(BaseAzureClient[httpx.AsyncClient, AsyncStream[Any]], Asy
if default_query is None:
default_query = {"api-version": api_version}
else:
- default_query = {"api-version": api_version, **default_query}
+ default_query = {**default_query, "api-version": api_version}
if base_url is None:
if azure_endpoint is None:
@@ -401,9 +445,53 @@ class AsyncAzureOpenAI(BaseAzureClient[httpx.AsyncClient, AsyncStream[Any]], Asy
http_client=http_client,
_strict_response_validation=_strict_response_validation,
)
+ self._api_version = api_version
self._azure_ad_token = azure_ad_token
self._azure_ad_token_provider = azure_ad_token_provider
+ @override
+ def copy(
+ self,
+ *,
+ api_key: str | None = None,
+ organization: str | None = None,
+ api_version: str | None = None,
+ azure_ad_token: str | None = None,
+ azure_ad_token_provider: AsyncAzureADTokenProvider | None = None,
+ base_url: str | httpx.URL | None = None,
+ timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
+ http_client: httpx.AsyncClient | None = None,
+ max_retries: int | NotGiven = NOT_GIVEN,
+ default_headers: Mapping[str, str] | None = None,
+ set_default_headers: Mapping[str, str] | None = None,
+ default_query: Mapping[str, object] | None = None,
+ set_default_query: Mapping[str, object] | None = None,
+ _extra_kwargs: Mapping[str, Any] = {},
+ ) -> Self:
+ """
+ Create a new client instance re-using the same options given to the current client with optional overriding.
+ """
+ return super().copy(
+ api_key=api_key,
+ organization=organization,
+ base_url=base_url,
+ timeout=timeout,
+ http_client=http_client,
+ max_retries=max_retries,
+ default_headers=default_headers,
+ set_default_headers=set_default_headers,
+ default_query=default_query,
+ set_default_query=set_default_query,
+ _extra_kwargs={
+ "api_version": api_version or self._api_version,
+ "azure_ad_token": azure_ad_token or self._azure_ad_token,
+ "azure_ad_token_provider": azure_ad_token_provider or self._azure_ad_token_provider,
+ **_extra_kwargs,
+ },
+ )
+
+ with_options = copy
+
async def _get_azure_ad_token(self) -> str | None:
if self._azure_ad_token is not None:
return self._azure_ad_token
src/openai/_client.py
@@ -4,8 +4,8 @@ from __future__ import annotations
import os
import asyncio
-from typing import Union, Mapping
-from typing_extensions import override
+from typing import Any, Union, Mapping
+from typing_extensions import Self, override
import httpx
@@ -164,12 +164,10 @@ class OpenAI(SyncAPIClient):
set_default_headers: Mapping[str, str] | None = None,
default_query: Mapping[str, object] | None = None,
set_default_query: Mapping[str, object] | None = None,
- ) -> OpenAI:
+ _extra_kwargs: Mapping[str, Any] = {},
+ ) -> Self:
"""
Create a new client instance re-using the same options given to the current client with optional overriding.
-
- It should be noted that this does not share the underlying httpx client class which may lead
- to performance issues.
"""
if default_headers is not None and set_default_headers is not None:
raise ValueError("The `default_headers` and `set_default_headers` arguments are mutually exclusive")
@@ -199,6 +197,7 @@ class OpenAI(SyncAPIClient):
max_retries=max_retries if is_given(max_retries) else self.max_retries,
default_headers=headers,
default_query=params,
+ **_extra_kwargs,
)
# Alias for `copy` for nicer inline usage, e.g.
@@ -374,12 +373,10 @@ class AsyncOpenAI(AsyncAPIClient):
set_default_headers: Mapping[str, str] | None = None,
default_query: Mapping[str, object] | None = None,
set_default_query: Mapping[str, object] | None = None,
- ) -> AsyncOpenAI:
+ _extra_kwargs: Mapping[str, Any] = {},
+ ) -> Self:
"""
Create a new client instance re-using the same options given to the current client with optional overriding.
-
- It should be noted that this does not share the underlying httpx client class which may lead
- to performance issues.
"""
if default_headers is not None and set_default_headers is not None:
raise ValueError("The `default_headers` and `set_default_headers` arguments are mutually exclusive")
@@ -409,6 +406,7 @@ class AsyncOpenAI(AsyncAPIClient):
max_retries=max_retries if is_given(max_retries) else self.max_retries,
default_headers=headers,
default_query=params,
+ **_extra_kwargs,
)
# Alias for `copy` for nicer inline usage, e.g.
tests/lib/test_azure.py
@@ -1,4 +1,5 @@
from typing import Union
+from typing_extensions import Literal
import pytest
@@ -34,3 +35,32 @@ def test_implicit_deployment_path(client: Client) -> None:
req.url
== "https://example-resource.azure.openai.com/openai/deployments/my-deployment-model/chat/completions?api-version=2023-07-01"
)
+
+
+@pytest.mark.parametrize(
+ "client,method",
+ [
+ (sync_client, "copy"),
+ (sync_client, "with_options"),
+ (async_client, "copy"),
+ (async_client, "with_options"),
+ ],
+)
+def test_client_copying(client: Client, method: Literal["copy", "with_options"]) -> None:
+ if method == "copy":
+ copied = client.copy()
+ else:
+ copied = client.with_options()
+
+ assert copied._custom_query == {"api-version": "2023-07-01"}
+
+
+@pytest.mark.parametrize(
+ "client",
+ [sync_client, async_client],
+)
+def test_client_copying_override_options(client: Client) -> None:
+ copied = client.copy(
+ api_version="2022-05-01",
+ )
+ assert copied._custom_query == {"api-version": "2022-05-01"}