Commit a0caa090

Stainless Bot <107565488+stainless-bot@users.noreply.github.com>
2024-03-04 21:39:15
chore(internal): support more input types (#1211)
1 parent 31bfc12
src/openai/_utils/_transform.py
@@ -1,9 +1,13 @@
 from __future__ import annotations
 
+import io
+import base64
+import pathlib
 from typing import Any, Mapping, TypeVar, cast
 from datetime import date, datetime
 from typing_extensions import Literal, get_args, override, get_type_hints
 
+import anyio
 import pydantic
 
 from ._utils import (
@@ -11,6 +15,7 @@ from ._utils import (
     is_mapping,
     is_iterable,
 )
+from .._files import is_base64_file_input
 from ._typing import (
     is_list_type,
     is_union_type,
@@ -29,7 +34,7 @@ _T = TypeVar("_T")
 # TODO: ensure works correctly with forward references in all cases
 
 
-PropertyFormat = Literal["iso8601", "custom"]
+PropertyFormat = Literal["iso8601", "base64", "custom"]
 
 
 class PropertyInfo:
@@ -201,6 +206,22 @@ def _format_data(data: object, format_: PropertyFormat, format_template: str | N
         if format_ == "custom" and format_template is not None:
             return data.strftime(format_template)
 
+    if format_ == "base64" and is_base64_file_input(data):
+        binary: str | bytes | None = None
+
+        if isinstance(data, pathlib.Path):
+            binary = data.read_bytes()
+        elif isinstance(data, io.IOBase):
+            binary = data.read()
+
+            if isinstance(binary, str):  # type: ignore[unreachable]
+                binary = binary.encode()
+
+        if not isinstance(binary, bytes):
+            raise RuntimeError(f"Could not read bytes from {data}; Received {type(binary)}")
+
+        return base64.b64encode(binary).decode("ascii")
+
     return data
 
 
@@ -323,6 +344,22 @@ async def _async_format_data(data: object, format_: PropertyFormat, format_templ
         if format_ == "custom" and format_template is not None:
             return data.strftime(format_template)
 
+    if format_ == "base64" and is_base64_file_input(data):
+        binary: str | bytes | None = None
+
+        if isinstance(data, pathlib.Path):
+            binary = await anyio.Path(data).read_bytes()
+        elif isinstance(data, io.IOBase):
+            binary = data.read()
+
+            if isinstance(binary, str):  # type: ignore[unreachable]
+                binary = binary.encode()
+
+        if not isinstance(binary, bytes):
+            raise RuntimeError(f"Could not read bytes from {data}; Received {type(binary)}")
+
+        return base64.b64encode(binary).decode("ascii")
+
     return data
 
 
src/openai/_files.py
@@ -13,12 +13,17 @@ from ._types import (
     FileContent,
     RequestFiles,
     HttpxFileTypes,
+    Base64FileInput,
     HttpxFileContent,
     HttpxRequestFiles,
 )
 from ._utils import is_tuple_t, is_mapping_t, is_sequence_t
 
 
+def is_base64_file_input(obj: object) -> TypeGuard[Base64FileInput]:
+    return isinstance(obj, io.IOBase) or isinstance(obj, os.PathLike)
+
+
 def is_file_content(obj: object) -> TypeGuard[FileContent]:
     return (
         isinstance(obj, bytes) or isinstance(obj, tuple) or isinstance(obj, io.IOBase) or isinstance(obj, os.PathLike)
src/openai/_types.py
@@ -41,8 +41,10 @@ _T = TypeVar("_T")
 ProxiesDict = Dict["str | URL", Union[None, str, URL, Proxy]]
 ProxiesTypes = Union[str, Proxy, ProxiesDict]
 if TYPE_CHECKING:
+    Base64FileInput = Union[IO[bytes], PathLike[str]]
     FileContent = Union[IO[bytes], bytes, PathLike[str]]
 else:
+    Base64FileInput = Union[IO[bytes], PathLike]
     FileContent = Union[IO[bytes], bytes, PathLike]  # PathLike is not subscriptable in Python 3.8.
 FileTypes = Union[
     # file (or bytes)
tests/sample_file.txt
@@ -0,0 +1,1 @@
+Hello, world!
tests/test_transform.py
@@ -1,11 +1,14 @@
 from __future__ import annotations
 
+import io
+import pathlib
 from typing import Any, List, Union, TypeVar, Iterable, Optional, cast
 from datetime import date, datetime
 from typing_extensions import Required, Annotated, TypedDict
 
 import pytest
 
+from openai._types import Base64FileInput
 from openai._utils import (
     PropertyInfo,
     transform as _transform,
@@ -17,6 +20,8 @@ from openai._models import BaseModel
 
 _T = TypeVar("_T")
 
+SAMPLE_FILE_PATH = pathlib.Path(__file__).parent.joinpath("sample_file.txt")
+
 
 async def transform(
     data: _T,
@@ -377,3 +382,27 @@ async def test_iterable_union_str(use_async: bool) -> None:
     assert cast(Any, await transform(iter([{"foo_baz": "bar"}]), Union[str, Iterable[Baz8]], use_async)) == [
         {"fooBaz": "bar"}
     ]
+
+
+class TypedDictBase64Input(TypedDict):
+    foo: Annotated[Union[str, Base64FileInput], PropertyInfo(format="base64")]
+
+
+@parametrize
+@pytest.mark.asyncio
+async def test_base64_file_input(use_async: bool) -> None:
+    # strings are left as-is
+    assert await transform({"foo": "bar"}, TypedDictBase64Input, use_async) == {"foo": "bar"}
+
+    # pathlib.Path is automatically converted to base64
+    assert await transform({"foo": SAMPLE_FILE_PATH}, TypedDictBase64Input, use_async) == {
+        "foo": "SGVsbG8sIHdvcmxkIQo="
+    }  # type: ignore[comparison-overlap]
+
+    # io instances are automatically converted to base64
+    assert await transform({"foo": io.StringIO("Hello, world!")}, TypedDictBase64Input, use_async) == {
+        "foo": "SGVsbG8sIHdvcmxkIQ=="
+    }  # type: ignore[comparison-overlap]
+    assert await transform({"foo": io.BytesIO(b"Hello, world!")}, TypedDictBase64Input, use_async) == {
+        "foo": "SGVsbG8sIHdvcmxkIQ=="
+    }  # type: ignore[comparison-overlap]