Commit d8901d28

Seth Gilchrist <seth@sethgilchrist.com>
2024-11-18 20:41:32
fix(asyncify): avoid hanging process under certain conditions (#1853)
1 parent 0d6185e
src/openai/_utils/_sync.py
@@ -1,56 +1,60 @@
 from __future__ import annotations
 
+import sys
+import asyncio
 import functools
-from typing import TypeVar, Callable, Awaitable
+import contextvars
+from typing import Any, TypeVar, Callable, Awaitable
 from typing_extensions import ParamSpec
 
-import anyio
-import anyio.to_thread
-
-from ._reflection import function_has_argument
-
 T_Retval = TypeVar("T_Retval")
 T_ParamSpec = ParamSpec("T_ParamSpec")
 
 
-# copied from `asyncer`, https://github.com/tiangolo/asyncer
-def asyncify(
-    function: Callable[T_ParamSpec, T_Retval],
-    *,
-    cancellable: bool = False,
-    limiter: anyio.CapacityLimiter | None = None,
-) -> Callable[T_ParamSpec, Awaitable[T_Retval]]:
-    """
-    Take a blocking function and create an async one that receives the same
-    positional and keyword arguments, and that when called, calls the original function
-    in a worker thread using `anyio.to_thread.run_sync()`. Internally,
-    `asyncer.asyncify()` uses the same `anyio.to_thread.run_sync()`, but it supports
-    keyword arguments additional to positional arguments and it adds better support for
-    autocompletion and inline errors for the arguments of the function called and the
-    return value.
+if sys.version_info >= (3, 9):
+    to_thread = asyncio.to_thread
+else:
+    async def _to_thread(
+        func: Callable[T_ParamSpec, T_Retval], /, *args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs
+    ) -> Any:
+        """Asynchronously run function *func* in a separate thread.
 
-    If the `cancellable` option is enabled and the task waiting for its completion is
-    cancelled, the thread will still run its course but its return value (or any raised
-    exception) will be ignored.
+        Any *args and **kwargs supplied for this function are directly passed
+        to *func*. Also, the current :class:`contextvars.Context` is propagated,
+        allowing context variables from the main thread to be accessed in the
+        separate thread.
 
-    Use it like this:
+        Returns a coroutine that can be awaited to get the eventual result of *func*.
+        """
+        loop = asyncio.events.get_running_loop()
+        ctx = contextvars.copy_context()
+        func_call = functools.partial(ctx.run, func, *args, **kwargs)
+        return await loop.run_in_executor(None, func_call)
+
+    to_thread = _to_thread
+
+# inspired by `asyncer`, https://github.com/tiangolo/asyncer
+def asyncify(function: Callable[T_ParamSpec, T_Retval]) -> Callable[T_ParamSpec, Awaitable[T_Retval]]:
+    """
+    Take a blocking function and create an async one that receives the same
+    positional and keyword arguments. For python version 3.9 and above, it uses
+    asyncio.to_thread to run the function in a separate thread. For python version
+    3.8, it uses a locally defined copy of the asyncio.to_thread function which was
+    introduced in python 3.9.
 
-    ```Python
-    def do_work(arg1, arg2, kwarg1="", kwarg2="") -> str:
-        # Do work
-        return "Some result"
+    Usage:
 
+    ```python
+    def blocking_func(arg1, arg2, kwarg1=None):
+        # blocking code
+        return result
 
-    result = await to_thread.asyncify(do_work)("spam", "ham", kwarg1="a", kwarg2="b")
-    print(result)
+    result = await asyncify(blocking_func)(arg1, arg2, kwarg1=value1)
     ```
 
     ## Arguments
 
     `function`: a blocking regular callable (e.g. a function)
-    `cancellable`: `True` to allow cancellation of the operation
-    `limiter`: capacity limiter to use to limit the total amount of threads running
-        (if omitted, the default limiter is used)
 
     ## Return
 
@@ -60,22 +64,6 @@ def asyncify(
     """
 
     async def wrapper(*args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs) -> T_Retval:
-        partial_f = functools.partial(function, *args, **kwargs)
-
-        # In `v4.1.0` anyio added the `abandon_on_cancel` argument and deprecated the old
-        # `cancellable` argument, so we need to use the new `abandon_on_cancel` to avoid
-        # surfacing deprecation warnings.
-        if function_has_argument(anyio.to_thread.run_sync, "abandon_on_cancel"):
-            return await anyio.to_thread.run_sync(
-                partial_f,
-                abandon_on_cancel=cancellable,
-                limiter=limiter,
-            )
-
-        return await anyio.to_thread.run_sync(
-            partial_f,
-            cancellable=cancellable,
-            limiter=limiter,
-        )
+        return await to_thread(function, *args, **kwargs)
 
     return wrapper
tests/test_client.py
@@ -4,11 +4,14 @@ from __future__ import annotations
 
 import gc
 import os
+import sys
 import json
 import asyncio
 import inspect
+import subprocess
 import tracemalloc
 from typing import Any, Union, cast
+from textwrap import dedent
 from unittest import mock
 from typing_extensions import Literal
 
@@ -1766,3 +1769,43 @@ class TestAsyncOpenAI:
         ) as response:
             assert response.retries_taken == failures_before_success
             assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success
+
+    def test_get_platform(self) -> None:
+        # Issue https://github.com/openai/openai-python/issues/1827 was caused by
+        # asyncify leaving threads unterminated when used with nest_asyncio.
+        # Since nest_asyncio.apply() is global and cannot be un-applied, this
+        # test is run in a separate process to avoid affecting other tests.
+        test_code = dedent("""\
+        import asyncio
+        import nest_asyncio
+
+        import threading
+
+        from openai._base_client import get_platform
+        from openai._utils import asyncify
+
+        async def test_main() -> None:
+            result = await asyncify(get_platform)()
+            print(result)
+            for thread in threading.enumerate():
+                print(thread.name)
+
+        nest_asyncio.apply()
+        asyncio.run(test_main())
+        """)
+        with subprocess.Popen(
+            [sys.executable, "-c", test_code],
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+            text=True,
+        ) as process:
+            try:
+                stdout, stderr = process.communicate(timeout=2)  # drains both pipes while waiting
+            except subprocess.TimeoutExpired as e:
+                process.kill()
+                raise AssertionError("calling get_platform using asyncify resulted in a hung process") from e
+            if process.returncode:
+                print(stdout)
+                print(stderr)
+                raise AssertionError("calling get_platform using asyncify resulted in a non-zero exit code")
+
pyproject.toml
@@ -65,7 +65,9 @@ dev-dependencies = [
     "azure-identity >=1.14.1",
     "types-tqdm > 4",
     "types-pyaudio > 0",
-    "trio >=0.22.2"
+    "trio >=0.22.2",
+    "nest_asyncio==1.6.0"
+
 ]
 
 [tool.rye.scripts]
requirements-dev.lock
@@ -7,6 +7,7 @@
 #   all-features: true
 #   with-sources: false
 #   generate-hashes: false
+#   universal: false
 
 -e file:.
 annotated-types==0.6.0
@@ -87,6 +88,7 @@ mypy==1.13.0
 mypy-extensions==1.0.0
     # via black
     # via mypy
+nest-asyncio==1.6.0
 nodeenv==1.8.0
     # via pyright
 nox==2023.4.22
requirements.lock
@@ -7,6 +7,7 @@
 #   all-features: true
 #   with-sources: false
 #   generate-hashes: false
+#   universal: false
 
 -e file:.
 annotated-types==0.6.0