Commit d052708a

Stainless Bot <107565488+stainless-bot@users.noreply.github.com>
2023-12-09 03:02:27
fix: avoid leaking memory when Client.with_options is used (#956)
Fixes https://github.com/openai/openai-python/issues/865.
1 parent d468f30
src/openai/_base_client.py
@@ -403,14 +403,12 @@ class BaseClient(Generic[_HttpxClientT, _DefaultStreamT]):
         headers_dict = _merge_mappings(self.default_headers, custom_headers)
         self._validate_headers(headers_dict, custom_headers)
 
+        # headers are case-insensitive while dictionaries are not.
         headers = httpx.Headers(headers_dict)
 
         idempotency_header = self._idempotency_header
         if idempotency_header and options.method.lower() != "get" and idempotency_header not in headers:
-            if not options.idempotency_key:
-                options.idempotency_key = self._idempotency_key()
-
-            headers[idempotency_header] = options.idempotency_key
+            headers[idempotency_header] = options.idempotency_key or self._idempotency_key()
 
         return headers
 
@@ -594,16 +592,8 @@ class BaseClient(Generic[_HttpxClientT, _DefaultStreamT]):
     def base_url(self, url: URL | str) -> None:
         self._base_url = self._enforce_trailing_slash(url if isinstance(url, URL) else URL(url))
 
-    @lru_cache(maxsize=None)
     def platform_headers(self) -> Dict[str, str]:
-        return {
-            "X-Stainless-Lang": "python",
-            "X-Stainless-Package-Version": self._version,
-            "X-Stainless-OS": str(get_platform()),
-            "X-Stainless-Arch": str(get_architecture()),
-            "X-Stainless-Runtime": platform.python_implementation(),
-            "X-Stainless-Runtime-Version": platform.python_version(),
-        }
+        return platform_headers(self._version)  # the bare name resolves to the module-level function, not this method
 
     def _calculate_retry_timeout(
         self,
@@ -1691,6 +1681,18 @@ def get_platform() -> Platform:
     return "Unknown"
 
 
+@lru_cache(maxsize=None)  # cache lives on a plain function, so its key is `version` only and never a client instance
+def platform_headers(version: str) -> Dict[str, str]:
+    return {  # NOTE: lru_cache hands back this same dict object for repeated calls with one version; do not mutate it
+        "X-Stainless-Lang": "python",
+        "X-Stainless-Package-Version": version,
+        "X-Stainless-OS": str(get_platform()),
+        "X-Stainless-Arch": str(get_architecture()),
+        "X-Stainless-Runtime": platform.python_implementation(),
+        "X-Stainless-Runtime-Version": platform.python_version(),
+    }
+
+
 class OtherArch:
     def __init__(self, name: str) -> None:
         self.name = name
src/openai/_client.py
@@ -192,7 +192,7 @@ class OpenAI(SyncAPIClient):
         return self.__class__(
             api_key=api_key or self.api_key,
             organization=organization or self.organization,
-            base_url=base_url or str(self.base_url),
+            base_url=base_url or self.base_url,
             timeout=self.timeout if isinstance(timeout, NotGiven) else timeout,
             http_client=http_client,
             max_retries=max_retries if is_given(max_retries) else self.max_retries,
@@ -402,7 +402,7 @@ class AsyncOpenAI(AsyncAPIClient):
         return self.__class__(
             api_key=api_key or self.api_key,
             organization=organization or self.organization,
-            base_url=base_url or str(self.base_url),
+            base_url=base_url or self.base_url,
             timeout=self.timeout if isinstance(timeout, NotGiven) else timeout,
             http_client=http_client,
             max_retries=max_retries if is_given(max_retries) else self.max_retries,
tests/test_client.py
@@ -2,10 +2,12 @@
 
 from __future__ import annotations
 
+import gc
 import os
 import json
 import asyncio
 import inspect
+import tracemalloc
 from typing import Any, Union, cast
 from unittest import mock
 
@@ -195,6 +197,67 @@ class TestOpenAI:
             copy_param = copy_signature.parameters.get(name)
             assert copy_param is not None, f"copy() signature is missing the {name} param"
 
+    def test_copy_build_request(self) -> None:
+        options = FinalRequestOptions(method="get", url="/foo")
+
+        def build_request(options: FinalRequestOptions) -> None:
+            client = self.client.copy()  # a new copy on every call; the local reference is dropped on return
+            client._build_request(options)
+
+        # ensure that the machinery is warmed up before tracing starts.
+        build_request(options)
+        gc.collect()
+
+        tracemalloc.start(1000)  # store up to 1000 frames per allocation traceback
+
+        snapshot_before = tracemalloc.take_snapshot()
+
+        ITERATIONS = 10
+        for _ in range(ITERATIONS):
+            build_request(options)
+            gc.collect()  # run the collector so collectable garbage is not counted in the later snapshot
+
+        snapshot_after = tracemalloc.take_snapshot()
+
+        tracemalloc.stop()
+
+        def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None:
+            if diff.count == 0:
+                # No blocks from this traceback remain in the later snapshot, so nothing persisted.
+                return
+
+            if diff.count % ITERATIONS != 0:
+                # A per-call leak leaves a block count that is a multiple of ITERATIONS; skip anything else.
+                return
+
+            for frame in diff.traceback:
+                if any(
+                    frame.filename.endswith(fragment)
+                    for fragment in [
+                        # to_raw_response_wrapper leaks through the @functools.wraps() decorator.
+                        #
+                        # removing the decorator fixes the leak for reasons we don't understand.
+                        "openai/_response.py",
+                        # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason.
+                        "openai/_compat.py",
+                        # Standard library leaks we don't care about.
+                        "/logging/__init__.py",
+                    ]
+                ):
+                    return
+
+            leaks.append(diff)  # passed every filter above, so report it
+
+        leaks: list[tracemalloc.StatisticDiff] = []
+        for diff in snapshot_after.compare_to(snapshot_before, "traceback"):
+            add_leak(leaks, diff)
+        if leaks:
+            for leak in leaks:
+                print("MEMORY LEAK:", leak)
+                for frame in leak.traceback:
+                    print(frame)
+            raise AssertionError()  # no message here; the leak details were printed just above
+
+
     def test_request_timeout(self) -> None:
         request = self.client._build_request(FinalRequestOptions(method="get", url="/foo"))
         timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
@@ -858,6 +921,67 @@ class TestAsyncOpenAI:
             copy_param = copy_signature.parameters.get(name)
             assert copy_param is not None, f"copy() signature is missing the {name} param"
 
+    def test_copy_build_request(self) -> None:
+        options = FinalRequestOptions(method="get", url="/foo")
+
+        def build_request(options: FinalRequestOptions) -> None:
+            client = self.client.copy()  # a new copy on every call; the local reference is dropped on return
+            client._build_request(options)
+
+        # ensure that the machinery is warmed up before tracing starts.
+        build_request(options)
+        gc.collect()
+
+        tracemalloc.start(1000)  # store up to 1000 frames per allocation traceback
+
+        snapshot_before = tracemalloc.take_snapshot()
+
+        ITERATIONS = 10
+        for _ in range(ITERATIONS):
+            build_request(options)
+            gc.collect()  # run the collector so collectable garbage is not counted in the later snapshot
+
+        snapshot_after = tracemalloc.take_snapshot()
+
+        tracemalloc.stop()
+
+        def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None:
+            if diff.count == 0:
+                # No blocks from this traceback remain in the later snapshot, so nothing persisted.
+                return
+
+            if diff.count % ITERATIONS != 0:
+                # A per-call leak leaves a block count that is a multiple of ITERATIONS; skip anything else.
+                return
+
+            for frame in diff.traceback:
+                if any(
+                    frame.filename.endswith(fragment)
+                    for fragment in [
+                        # to_raw_response_wrapper leaks through the @functools.wraps() decorator.
+                        #
+                        # removing the decorator fixes the leak for reasons we don't understand.
+                        "openai/_response.py",
+                        # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason.
+                        "openai/_compat.py",
+                        # Standard library leaks we don't care about.
+                        "/logging/__init__.py",
+                    ]
+                ):
+                    return
+
+            leaks.append(diff)  # passed every filter above, so report it
+
+        leaks: list[tracemalloc.StatisticDiff] = []
+        for diff in snapshot_after.compare_to(snapshot_before, "traceback"):
+            add_leak(leaks, diff)
+        if leaks:
+            for leak in leaks:
+                print("MEMORY LEAK:", leak)
+                for frame in leak.traceback:
+                    print(frame)
+            raise AssertionError()  # no message here; the leak details were printed just above
+
+
     async def test_request_timeout(self) -> None:
         request = self.client._build_request(FinalRequestOptions(method="get", url="/foo"))
         timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
pyproject.toml
@@ -152,8 +152,6 @@ select = [
   "T203",
 ]
 ignore = [
-  # lru_cache in methods, will be fixed separately
-  "B019",
   # mutable defaults
   "B006",
 ]