Commit dea7ee14

stainless-app[bot] <142633134+stainless-app[bot]@users.noreply.github.com>
2024-06-19 19:18:26
fix(client/async): avoid blocking io call for platform headers (#1488)
1 parent 6aa2a80
Changed files (4)
src/openai/_utils/__init__.py
@@ -49,3 +49,4 @@ from ._transform import (
     maybe_transform as maybe_transform,
     async_maybe_transform as async_maybe_transform,
 )
+from ._reflection import function_has_argument as function_has_argument
src/openai/_utils/_reflection.py
@@ -0,0 +1,8 @@
+import inspect
+from typing import Any, Callable
+
+
+def function_has_argument(func: Callable[..., Any], arg_name: str) -> bool:
+    """Returns whether or not the given function has a specific parameter"""
+    sig = inspect.signature(func)
+    return arg_name in sig.parameters
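
For a sense of how this helper behaves, here is a small usage sketch. It relies on the re-export added to _utils/__init__.py above; the upload function is invented purely for illustration and is not part of the SDK.

from openai._utils import function_has_argument

# Hypothetical callable, used only to demonstrate the signature check.
def upload(path: str, *, abandon_on_cancel: bool = False) -> None:
    ...

assert function_has_argument(upload, "abandon_on_cancel")
assert not function_has_argument(upload, "cancellable")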
src/openai/_utils/_sync.py
@@ -7,6 +7,8 @@ from typing_extensions import ParamSpec
 import anyio
 import anyio.to_thread
 
+from ._reflection import function_has_argument
+
 T_Retval = TypeVar("T_Retval")
 T_ParamSpec = ParamSpec("T_ParamSpec")
 
@@ -59,6 +61,21 @@ def asyncify(
 
     async def wrapper(*args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs) -> T_Retval:
         partial_f = functools.partial(function, *args, **kwargs)
-        return await anyio.to_thread.run_sync(partial_f, cancellable=cancellable, limiter=limiter)
+
+        # In `v4.1.0` anyio added the `abandon_on_cancel` argument and deprecated the old
+        # `cancellable` argument, so we need to use the new `abandon_on_cancel` to avoid
+        # surfacing deprecation warnings.
+        if function_has_argument(anyio.to_thread.run_sync, "abandon_on_cancel"):
+            return await anyio.to_thread.run_sync(
+                partial_f,
+                abandon_on_cancel=cancellable,
+                limiter=limiter,
+            )
+
+        return await anyio.to_thread.run_sync(
+            partial_f,
+            cancellable=cancellable,
+            limiter=limiter,
+        )
 
     return wrapper
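
As context for the change above: asyncify wraps a blocking function so it can be awaited, by handing it to anyio.to_thread.run_sync on a worker thread; the function_has_argument check selects abandon_on_cancel on anyio >= 4.1.0 and falls back to cancellable on older versions, so no deprecation warning is raised either way. A rough usage sketch, assuming asyncify is re-exported from openai._utils (as the _base_client.py import further down suggests); read_config and its path are invented for illustration.

import anyio
from openai._utils import asyncify

def read_config(path: str) -> str:
    # ordinary blocking file IO
    with open(path) as f:
        return f.read()

async def main() -> None:
    # the blocking read runs on a worker thread, so the event loop stays free
    text = await asyncify(read_config)("config.toml")
    print(len(text))

anyio.run(main)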
src/openai/_base_client.py
@@ -60,7 +60,7 @@ from ._types import (
     RequestOptions,
     ModelBuilderProtocol,
 )
-from ._utils import is_dict, is_list, is_given, lru_cache, is_mapping
+from ._utils import is_dict, is_list, asyncify, is_given, lru_cache, is_mapping
 from ._compat import model_copy, model_dump
 from ._models import GenericModel, FinalRequestOptions, validate_type, construct_type
 from ._response import (
@@ -359,6 +359,7 @@ class BaseClient(Generic[_HttpxClientT, _DefaultStreamT]):
         self._custom_query = custom_query or {}
         self._strict_response_validation = _strict_response_validation
         self._idempotency_header = None
+        self._platform: Platform | None = None
 
         if max_retries is None:  # pyright: ignore[reportUnnecessaryComparison]
             raise TypeError(
@@ -623,7 +624,10 @@ class BaseClient(Generic[_HttpxClientT, _DefaultStreamT]):
         self._base_url = self._enforce_trailing_slash(url if isinstance(url, URL) else URL(url))
 
     def platform_headers(self) -> Dict[str, str]:
-        return platform_headers(self._version)
+        # the actual implementation is in a separate `lru_cache` decorated
+        # function because adding `lru_cache` to methods will leak memory
+        # https://github.com/python/cpython/issues/88476
+        return platform_headers(self._version, platform=self._platform)
 
     def _parse_retry_after_header(self, response_headers: Optional[httpx.Headers] = None) -> float | None:
         """Returns a float of the number of seconds (not milliseconds) to wait after retrying, or None if unspecified.
@@ -1513,6 +1517,11 @@ class AsyncAPIClient(BaseClient[httpx.AsyncClient, AsyncStream[Any]]):
         stream_cls: type[_AsyncStreamT] | None,
         remaining_retries: int | None,
     ) -> ResponseT | _AsyncStreamT:
+        if self._platform is None:
+            # `get_platform` can make blocking IO calls so we
+            # execute it earlier while we are in an async context
+            self._platform = await asyncify(get_platform)()
+
         cast_to = self._maybe_override_cast_to(cast_to, options)
         await self._prepare_options(options)
 
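The effect of this hunk: on the first async request the (potentially blocking) platform lookup is pushed to a worker thread via asyncify, and the result is cached on the client so later requests skip it entirely. A condensed sketch of the pattern, with a stand-in class rather than the real client; the imports assume get_platform and Platform remain module-level in openai._base_client.

from __future__ import annotations

from openai._base_client import Platform, get_platform
from openai._utils import asyncify

class _ClientSketch:
    def __init__(self) -> None:
        self._platform: Platform | None = None

    async def _request(self) -> None:
        if self._platform is None:
            # first request: resolve the platform off the event loop
            self._platform = await asyncify(get_platform)()
        # later requests reuse self._platform; no thread hop, no blocking IO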
@@ -1949,11 +1958,11 @@ def get_platform() -> Platform:
 
 
 @lru_cache(maxsize=None)
-def platform_headers(version: str) -> Dict[str, str]:
+def platform_headers(version: str, *, platform: Platform | None) -> Dict[str, str]:
     return {
         "X-Stainless-Lang": "python",
         "X-Stainless-Package-Version": version,
-        "X-Stainless-OS": str(get_platform()),
+        "X-Stainless-OS": str(platform or get_platform()),
         "X-Stainless-Arch": str(get_architecture()),
         "X-Stainless-Runtime": get_python_runtime(),
         "X-Stainless-Runtime-Version": get_python_version(),