Commit 83ebf66f

Robert Craigie <robert@craigie.dev>
2024-07-15 16:57:42
chore(internal): minor options / compat functions updates (#1549)
1 parent f14f859
Changed files (3)
src/openai/lib/azure.py
@@ -10,6 +10,7 @@ import httpx
 from .._types import NOT_GIVEN, Omit, Timeout, NotGiven
 from .._utils import is_given, is_mapping
 from .._client import OpenAI, AsyncOpenAI
+from .._compat import model_copy
 from .._models import FinalRequestOptions
 from .._streaming import Stream, AsyncStream
 from .._exceptions import OpenAIError
@@ -281,8 +282,10 @@ class AzureOpenAI(BaseAzureClient[httpx.Client, Stream[Any]], OpenAI):
         return None
 
     @override
-    def _prepare_options(self, options: FinalRequestOptions) -> None:
+    def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions:
         headers: dict[str, str | Omit] = {**options.headers} if is_given(options.headers) else {}
+
+        options = model_copy(options)
         options.headers = headers
 
         azure_ad_token = self._get_azure_ad_token()
@@ -296,7 +299,7 @@ class AzureOpenAI(BaseAzureClient[httpx.Client, Stream[Any]], OpenAI):
             # should never be hit
             raise ValueError("Unable to handle auth")
 
-        return super()._prepare_options(options)
+        return options
 
 
 class AsyncAzureOpenAI(BaseAzureClient[httpx.AsyncClient, AsyncStream[Any]], AsyncOpenAI):
@@ -524,8 +527,10 @@ class AsyncAzureOpenAI(BaseAzureClient[httpx.AsyncClient, AsyncStream[Any]], Asy
         return None
 
     @override
-    async def _prepare_options(self, options: FinalRequestOptions) -> None:
+    async def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions:
         headers: dict[str, str | Omit] = {**options.headers} if is_given(options.headers) else {}
+
+        options = model_copy(options)
         options.headers = headers
 
         azure_ad_token = await self._get_azure_ad_token()
@@ -539,4 +544,4 @@ class AsyncAzureOpenAI(BaseAzureClient[httpx.AsyncClient, AsyncStream[Any]], Asy
             # should never be hit
             raise ValueError("Unable to handle auth")
 
-        return await super()._prepare_options(options)
+        return options
src/openai/_base_client.py
@@ -880,9 +880,9 @@ class SyncAPIClient(BaseClient[httpx.Client, Stream[Any]]):
     def _prepare_options(
         self,
         options: FinalRequestOptions,  # noqa: ARG002
-    ) -> None:
+    ) -> FinalRequestOptions:
         """Hook for mutating the given options"""
-        return None
+        return options
 
     def _prepare_request(
         self,
@@ -962,7 +962,7 @@ class SyncAPIClient(BaseClient[httpx.Client, Stream[Any]]):
         input_options = model_copy(options)
 
         cast_to = self._maybe_override_cast_to(cast_to, options)
-        self._prepare_options(options)
+        options = self._prepare_options(options)
 
         retries = self._remaining_retries(remaining_retries, options)
         request = self._build_request(options)
@@ -1457,9 +1457,9 @@ class AsyncAPIClient(BaseClient[httpx.AsyncClient, AsyncStream[Any]]):
     async def _prepare_options(
         self,
         options: FinalRequestOptions,  # noqa: ARG002
-    ) -> None:
+    ) -> FinalRequestOptions:
         """Hook for mutating the given options"""
-        return None
+        return options
 
     async def _prepare_request(
         self,
@@ -1544,7 +1544,7 @@ class AsyncAPIClient(BaseClient[httpx.AsyncClient, AsyncStream[Any]]):
         input_options = model_copy(options)
 
         cast_to = self._maybe_override_cast_to(cast_to, options)
-        await self._prepare_options(options)
+        options = await self._prepare_options(options)
 
         retries = self._remaining_retries(remaining_retries, options)
         request = self._build_request(options)
src/openai/_compat.py
@@ -118,10 +118,10 @@ def get_model_fields(model: type[pydantic.BaseModel]) -> dict[str, FieldInfo]:
     return model.__fields__  # type: ignore
 
 
-def model_copy(model: _ModelT) -> _ModelT:
+def model_copy(model: _ModelT, *, deep: bool = False) -> _ModelT:
     if PYDANTIC_V2:
-        return model.model_copy()
-    return model.copy()  # type: ignore
+        return model.model_copy(deep=deep)
+    return model.copy(deep=deep)  # type: ignore
 
 
 def model_json(model: pydantic.BaseModel, *, indent: int | None = None) -> str: