Commit 0733934f

Stainless Bot <107565488+stainless-bot@users.noreply.github.com>
2023-11-16 00:54:17
feat(client): support reading the base url from an env variable (#829)
1 parent 71d3172
Changed files (4)
src/openai/_client.py
@@ -99,6 +99,8 @@ class OpenAI(SyncAPIClient):
             organization = os.environ.get("OPENAI_ORG_ID")
         self.organization = organization
 
+        if base_url is None:
+            base_url = os.environ.get("OPENAI_BASE_URL")
         if base_url is None:
             base_url = f"https://api.openai.com/v1"
 
@@ -307,6 +309,8 @@ class AsyncOpenAI(AsyncAPIClient):
             organization = os.environ.get("OPENAI_ORG_ID")
         self.organization = organization
 
+        if base_url is None:
+            base_url = os.environ.get("OPENAI_BASE_URL")
         if base_url is None:
             base_url = f"https://api.openai.com/v1"
 
tests/test_client.py
@@ -26,6 +26,8 @@ from openai._base_client import (
     make_request_options,
 )
 
+from .utils import update_env
+
 base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
 api_key = "My API Key"
 
@@ -399,6 +401,11 @@ class TestOpenAI:
         assert isinstance(response, Model1)
         assert response.foo == 1
 
+    def test_base_url_env(self) -> None:
+        with update_env(OPENAI_BASE_URL="http://localhost:5000/from/env"):
+            client = OpenAI(api_key=api_key, _strict_response_validation=True)
+            assert client.base_url == "http://localhost:5000/from/env/"
+
     @pytest.mark.parametrize(
         "client",
         [
@@ -932,6 +939,11 @@ class TestAsyncOpenAI:
         assert isinstance(response, Model1)
         assert response.foo == 1
 
+    def test_base_url_env(self) -> None:
+        with update_env(OPENAI_BASE_URL="http://localhost:5000/from/env"):
+            client = AsyncOpenAI(api_key=api_key, _strict_response_validation=True)
+            assert client.base_url == "http://localhost:5000/from/env/"
+
     @pytest.mark.parametrize(
         "client",
         [
tests/utils.py
@@ -1,7 +1,9 @@
 from __future__ import annotations
 
+import os
 import traceback
-from typing import Any, TypeVar, cast
+import contextlib
+from typing import Any, TypeVar, Iterator, cast
 from datetime import date, datetime
 from typing_extensions import Literal, get_args, get_origin, assert_type
 
@@ -103,3 +105,16 @@ def _assert_list_type(type_: type[object], value: object) -> None:
     inner_type = get_args(type_)[0]
     for entry in value:
         assert_type(inner_type, entry)  # type: ignore
+
+
+@contextlib.contextmanager
+def update_env(**new_env: str) -> Iterator[None]:
+    old = os.environ.copy()
+
+    try:
+        os.environ.update(new_env)
+
+        yield None
+    finally:
+        os.environ.clear()
+        os.environ.update(old)
README.md
@@ -437,6 +437,7 @@ import httpx
 from openai import OpenAI
 
 client = OpenAI(
+    # Or use the `OPENAI_BASE_URL` env var
     base_url="http://my.test.server.example.com:8083",
     http_client=httpx.Client(
         proxies="http://my.test.proxy.example.com",