Commit 97a6895b

Stainless Bot <107565488+stainless-bot@users.noreply.github.com>
2023-11-22 02:14:58
fix(azure): ensure custom options can be passed to copy (#858)
1 parent b085b12
Changed files (3)
src
tests
src/openai/lib/azure.py
@@ -3,7 +3,7 @@ from __future__ import annotations
 import os
 import inspect
 from typing import Any, Union, Mapping, TypeVar, Callable, Awaitable, overload
-from typing_extensions import override
+from typing_extensions import Self, override
 
 import httpx
 
@@ -178,7 +178,7 @@ class AzureOpenAI(BaseAzureClient[httpx.Client, Stream[Any]], OpenAI):
         if default_query is None:
             default_query = {"api-version": api_version}
         else:
-            default_query = {"api-version": api_version, **default_query}
+            default_query = {**default_query, "api-version": api_version}
 
         if base_url is None:
             if azure_endpoint is None:
@@ -212,9 +212,53 @@ class AzureOpenAI(BaseAzureClient[httpx.Client, Stream[Any]], OpenAI):
             http_client=http_client,
             _strict_response_validation=_strict_response_validation,
         )
+        self._api_version = api_version
         self._azure_ad_token = azure_ad_token
         self._azure_ad_token_provider = azure_ad_token_provider
 
+    @override
+    def copy(
+        self,
+        *,
+        api_key: str | None = None,
+        organization: str | None = None,
+        api_version: str | None = None,
+        azure_ad_token: str | None = None,
+        azure_ad_token_provider: AzureADTokenProvider | None = None,
+        base_url: str | httpx.URL | None = None,
+        timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
+        http_client: httpx.Client | None = None,
+        max_retries: int | NotGiven = NOT_GIVEN,
+        default_headers: Mapping[str, str] | None = None,
+        set_default_headers: Mapping[str, str] | None = None,
+        default_query: Mapping[str, object] | None = None,
+        set_default_query: Mapping[str, object] | None = None,
+        _extra_kwargs: Mapping[str, Any] = {},
+    ) -> Self:
+        """
+        Create a new client instance re-using the same options given to the current client with optional overriding.
+        """
+        return super().copy(
+            api_key=api_key,
+            organization=organization,
+            base_url=base_url,
+            timeout=timeout,
+            http_client=http_client,
+            max_retries=max_retries,
+            default_headers=default_headers,
+            set_default_headers=set_default_headers,
+            default_query=default_query,
+            set_default_query=set_default_query,
+            _extra_kwargs={
+                "api_version": api_version or self._api_version,
+                "azure_ad_token": azure_ad_token or self._azure_ad_token,
+                "azure_ad_token_provider": azure_ad_token_provider or self._azure_ad_token_provider,
+                **_extra_kwargs,
+            },
+        )
+
+    with_options = copy
+
     def _get_azure_ad_token(self) -> str | None:
         if self._azure_ad_token is not None:
             return self._azure_ad_token
@@ -367,7 +411,7 @@ class AsyncAzureOpenAI(BaseAzureClient[httpx.AsyncClient, AsyncStream[Any]], Asy
         if default_query is None:
             default_query = {"api-version": api_version}
         else:
-            default_query = {"api-version": api_version, **default_query}
+            default_query = {**default_query, "api-version": api_version}
 
         if base_url is None:
             if azure_endpoint is None:
@@ -401,9 +445,53 @@ class AsyncAzureOpenAI(BaseAzureClient[httpx.AsyncClient, AsyncStream[Any]], Asy
             http_client=http_client,
             _strict_response_validation=_strict_response_validation,
         )
+        self._api_version = api_version
         self._azure_ad_token = azure_ad_token
         self._azure_ad_token_provider = azure_ad_token_provider
 
+    @override
+    def copy(
+        self,
+        *,
+        api_key: str | None = None,
+        organization: str | None = None,
+        api_version: str | None = None,
+        azure_ad_token: str | None = None,
+        azure_ad_token_provider: AsyncAzureADTokenProvider | None = None,
+        base_url: str | httpx.URL | None = None,
+        timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
+        http_client: httpx.AsyncClient | None = None,
+        max_retries: int | NotGiven = NOT_GIVEN,
+        default_headers: Mapping[str, str] | None = None,
+        set_default_headers: Mapping[str, str] | None = None,
+        default_query: Mapping[str, object] | None = None,
+        set_default_query: Mapping[str, object] | None = None,
+        _extra_kwargs: Mapping[str, Any] = {},
+    ) -> Self:
+        """
+        Create a new client instance re-using the same options given to the current client with optional overriding.
+        """
+        return super().copy(
+            api_key=api_key,
+            organization=organization,
+            base_url=base_url,
+            timeout=timeout,
+            http_client=http_client,
+            max_retries=max_retries,
+            default_headers=default_headers,
+            set_default_headers=set_default_headers,
+            default_query=default_query,
+            set_default_query=set_default_query,
+            _extra_kwargs={
+                "api_version": api_version or self._api_version,
+                "azure_ad_token": azure_ad_token or self._azure_ad_token,
+                "azure_ad_token_provider": azure_ad_token_provider or self._azure_ad_token_provider,
+                **_extra_kwargs,
+            },
+        )
+
+    with_options = copy
+
     async def _get_azure_ad_token(self) -> str | None:
         if self._azure_ad_token is not None:
             return self._azure_ad_token
src/openai/_client.py
@@ -4,8 +4,8 @@ from __future__ import annotations
 
 import os
 import asyncio
-from typing import Union, Mapping
-from typing_extensions import override
+from typing import Any, Union, Mapping
+from typing_extensions import Self, override
 
 import httpx
 
@@ -164,12 +164,10 @@ class OpenAI(SyncAPIClient):
         set_default_headers: Mapping[str, str] | None = None,
         default_query: Mapping[str, object] | None = None,
         set_default_query: Mapping[str, object] | None = None,
-    ) -> OpenAI:
+        _extra_kwargs: Mapping[str, Any] = {},
+    ) -> Self:
         """
         Create a new client instance re-using the same options given to the current client with optional overriding.
-
-        It should be noted that this does not share the underlying httpx client class which may lead
-        to performance issues.
         """
         if default_headers is not None and set_default_headers is not None:
             raise ValueError("The `default_headers` and `set_default_headers` arguments are mutually exclusive")
@@ -199,6 +197,7 @@ class OpenAI(SyncAPIClient):
             max_retries=max_retries if is_given(max_retries) else self.max_retries,
             default_headers=headers,
             default_query=params,
+            **_extra_kwargs,
         )
 
     # Alias for `copy` for nicer inline usage, e.g.
@@ -374,12 +373,10 @@ class AsyncOpenAI(AsyncAPIClient):
         set_default_headers: Mapping[str, str] | None = None,
         default_query: Mapping[str, object] | None = None,
         set_default_query: Mapping[str, object] | None = None,
-    ) -> AsyncOpenAI:
+        _extra_kwargs: Mapping[str, Any] = {},
+    ) -> Self:
         """
         Create a new client instance re-using the same options given to the current client with optional overriding.
-
-        It should be noted that this does not share the underlying httpx client class which may lead
-        to performance issues.
         """
         if default_headers is not None and set_default_headers is not None:
             raise ValueError("The `default_headers` and `set_default_headers` arguments are mutually exclusive")
@@ -409,6 +406,7 @@ class AsyncOpenAI(AsyncAPIClient):
             max_retries=max_retries if is_given(max_retries) else self.max_retries,
             default_headers=headers,
             default_query=params,
+            **_extra_kwargs,
         )
 
     # Alias for `copy` for nicer inline usage, e.g.
tests/lib/test_azure.py
@@ -1,4 +1,5 @@
 from typing import Union
+from typing_extensions import Literal
 
 import pytest
 
@@ -34,3 +35,32 @@ def test_implicit_deployment_path(client: Client) -> None:
         req.url
         == "https://example-resource.azure.openai.com/openai/deployments/my-deployment-model/chat/completions?api-version=2023-07-01"
     )
+
+
+@pytest.mark.parametrize(
+    "client,method",
+    [
+        (sync_client, "copy"),
+        (sync_client, "with_options"),
+        (async_client, "copy"),
+        (async_client, "with_options"),
+    ],
+)
+def test_client_copying(client: Client, method: Literal["copy", "with_options"]) -> None:
+    if method == "copy":
+        copied = client.copy()
+    else:
+        copied = client.with_options()
+
+    assert copied._custom_query == {"api-version": "2023-07-01"}
+
+
+@pytest.mark.parametrize(
+    "client",
+    [sync_client, async_client],
+)
+def test_client_copying_override_options(client: Client) -> None:
+    copied = client.copy(
+        api_version="2022-05-01",
+    )
+    assert copied._custom_query == {"api-version": "2022-05-01"}