Commit dea7ee14
Changed files (4)
src
openai
src/openai/_utils/__init__.py
@@ -49,3 +49,4 @@ from ._transform import (
maybe_transform as maybe_transform,
async_maybe_transform as async_maybe_transform,
)
+from ._reflection import function_has_argument as function_has_argument
src/openai/_utils/_reflection.py
@@ -0,0 +1,8 @@
+import inspect
+from typing import Any, Callable
+
+
+def function_has_argument(func: Callable[..., Any], arg_name: str) -> bool:
+ """Returns whether or not the given function has a specific parameter"""
+ sig = inspect.signature(func)
+ return arg_name in sig.parameters
src/openai/_utils/_sync.py
@@ -7,6 +7,8 @@ from typing_extensions import ParamSpec
import anyio
import anyio.to_thread
+from ._reflection import function_has_argument
+
T_Retval = TypeVar("T_Retval")
T_ParamSpec = ParamSpec("T_ParamSpec")
@@ -59,6 +61,21 @@ def asyncify(
async def wrapper(*args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs) -> T_Retval:
partial_f = functools.partial(function, *args, **kwargs)
- return await anyio.to_thread.run_sync(partial_f, cancellable=cancellable, limiter=limiter)
+
+ # In `v4.1.0` anyio added the `abandon_on_cancel` argument and deprecated the old
+ # `cancellable` argument, so we need to use the new `abandon_on_cancel` to avoid
+ # surfacing deprecation warnings.
+ if function_has_argument(anyio.to_thread.run_sync, "abandon_on_cancel"):
+ return await anyio.to_thread.run_sync(
+ partial_f,
+ abandon_on_cancel=cancellable,
+ limiter=limiter,
+ )
+
+ return await anyio.to_thread.run_sync(
+ partial_f,
+ cancellable=cancellable,
+ limiter=limiter,
+ )
return wrapper
src/openai/_base_client.py
@@ -60,7 +60,7 @@ from ._types import (
RequestOptions,
ModelBuilderProtocol,
)
-from ._utils import is_dict, is_list, is_given, lru_cache, is_mapping
+from ._utils import is_dict, is_list, asyncify, is_given, lru_cache, is_mapping
from ._compat import model_copy, model_dump
from ._models import GenericModel, FinalRequestOptions, validate_type, construct_type
from ._response import (
@@ -359,6 +359,7 @@ class BaseClient(Generic[_HttpxClientT, _DefaultStreamT]):
self._custom_query = custom_query or {}
self._strict_response_validation = _strict_response_validation
self._idempotency_header = None
+ self._platform: Platform | None = None
if max_retries is None: # pyright: ignore[reportUnnecessaryComparison]
raise TypeError(
@@ -623,7 +624,10 @@ class BaseClient(Generic[_HttpxClientT, _DefaultStreamT]):
self._base_url = self._enforce_trailing_slash(url if isinstance(url, URL) else URL(url))
def platform_headers(self) -> Dict[str, str]:
- return platform_headers(self._version)
+ # the actual implementation is in a separate `lru_cache` decorated
+ # function because adding `lru_cache` to methods will leak memory
+ # https://github.com/python/cpython/issues/88476
+ return platform_headers(self._version, platform=self._platform)
def _parse_retry_after_header(self, response_headers: Optional[httpx.Headers] = None) -> float | None:
"""Returns a float of the number of seconds (not milliseconds) to wait after retrying, or None if unspecified.
@@ -1513,6 +1517,11 @@ class AsyncAPIClient(BaseClient[httpx.AsyncClient, AsyncStream[Any]]):
stream_cls: type[_AsyncStreamT] | None,
remaining_retries: int | None,
) -> ResponseT | _AsyncStreamT:
+ if self._platform is None:
+ # `get_platform` can make blocking IO calls so we
+ # execute it earlier while we are in an async context
+ self._platform = await asyncify(get_platform)()
+
cast_to = self._maybe_override_cast_to(cast_to, options)
await self._prepare_options(options)
@@ -1949,11 +1958,11 @@ def get_platform() -> Platform:
@lru_cache(maxsize=None)
-def platform_headers(version: str) -> Dict[str, str]:
+def platform_headers(version: str, *, platform: Platform | None) -> Dict[str, str]:
return {
"X-Stainless-Lang": "python",
"X-Stainless-Package-Version": version,
- "X-Stainless-OS": str(get_platform()),
+ "X-Stainless-OS": str(platform or get_platform()),
"X-Stainless-Arch": str(get_architecture()),
"X-Stainless-Runtime": get_python_runtime(),
"X-Stainless-Runtime-Version": get_python_version(),