main
  1from __future__ import annotations
  2
  3import os
  4import inspect
  5import logging
  6import datetime
  7import functools
  8from typing import (
  9    TYPE_CHECKING,
 10    Any,
 11    Union,
 12    Generic,
 13    TypeVar,
 14    Callable,
 15    Iterator,
 16    AsyncIterator,
 17    cast,
 18    overload,
 19)
 20from typing_extensions import Awaitable, ParamSpec, override, deprecated, get_origin
 21
 22import anyio
 23import httpx
 24import pydantic
 25
 26from ._types import NoneType
 27from ._utils import is_given, extract_type_arg, is_annotated_type, is_type_alias_type
 28from ._models import BaseModel, is_basemodel, add_request_id
 29from ._constants import RAW_RESPONSE_HEADER
 30from ._streaming import Stream, AsyncStream, is_stream_class_type, extract_stream_chunk_type
 31from ._exceptions import APIResponseValidationError
 32
 33if TYPE_CHECKING:
 34    from ._models import FinalRequestOptions
 35    from ._base_client import BaseClient
 36
 37
# Captures the full parameter list of an API method so that the raw-response
# wrappers in this module can re-expose the same call signature.
P = ParamSpec("P")
# The type a response is parsed into; parameterises `LegacyAPIResponse[R]`.
R = TypeVar("R")
# A caller-chosen override type, used by `LegacyAPIResponse.parse(to=...)`.
_T = TypeVar("_T")

# Module-level logger, named after this module.
log: logging.Logger = logging.getLogger(__name__)
 43
 44
 45class LegacyAPIResponse(Generic[R]):
 46    """This is a legacy class as it will be replaced by `APIResponse`
 47    and `AsyncAPIResponse` in the `_response.py` file in the next major
 48    release.
 49
 50    For the sync client this will mostly be the same with the exception
 51    of `content` & `text` will be methods instead of properties. In the
 52    async client, all methods will be async.
 53
 54    A migration script will be provided & the migration in general should
 55    be smooth.
 56    """
 57
 58    _cast_to: type[R]
 59    _client: BaseClient[Any, Any]
 60    _parsed_by_type: dict[type[Any], Any]
 61    _stream: bool
 62    _stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None
 63    _options: FinalRequestOptions
 64
 65    http_response: httpx.Response
 66
 67    retries_taken: int
 68    """The number of retries made. If no retries happened this will be `0`"""
 69
 70    def __init__(
 71        self,
 72        *,
 73        raw: httpx.Response,
 74        cast_to: type[R],
 75        client: BaseClient[Any, Any],
 76        stream: bool,
 77        stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None,
 78        options: FinalRequestOptions,
 79        retries_taken: int = 0,
 80    ) -> None:
 81        self._cast_to = cast_to
 82        self._client = client
 83        self._parsed_by_type = {}
 84        self._stream = stream
 85        self._stream_cls = stream_cls
 86        self._options = options
 87        self.http_response = raw
 88        self.retries_taken = retries_taken
 89
 90    @property
 91    def request_id(self) -> str | None:
 92        return self.http_response.headers.get("x-request-id")  # type: ignore[no-any-return]
 93
 94    @overload
 95    def parse(self, *, to: type[_T]) -> _T: ...
 96
 97    @overload
 98    def parse(self) -> R: ...
 99
100    def parse(self, *, to: type[_T] | None = None) -> R | _T:
101        """Returns the rich python representation of this response's data.
102
103        NOTE: For the async client: this will become a coroutine in the next major version.
104
105        For lower-level control, see `.read()`, `.json()`, `.iter_bytes()`.
106
107        You can customise the type that the response is parsed into through
108        the `to` argument, e.g.
109
110        ```py
111        from openai import BaseModel
112
113
114        class MyModel(BaseModel):
115            foo: str
116
117
118        obj = response.parse(to=MyModel)
119        print(obj.foo)
120        ```
121
122        We support parsing:
123          - `BaseModel`
124          - `dict`
125          - `list`
126          - `Union`
127          - `str`
128          - `int`
129          - `float`
130          - `httpx.Response`
131        """
132        cache_key = to if to is not None else self._cast_to
133        cached = self._parsed_by_type.get(cache_key)
134        if cached is not None:
135            return cached  # type: ignore[no-any-return]
136
137        parsed = self._parse(to=to)
138        if is_given(self._options.post_parser):
139            parsed = self._options.post_parser(parsed)
140
141        if isinstance(parsed, BaseModel):
142            add_request_id(parsed, self.request_id)
143
144        self._parsed_by_type[cache_key] = parsed
145        return cast(R, parsed)
146
147    @property
148    def headers(self) -> httpx.Headers:
149        return self.http_response.headers
150
151    @property
152    def http_request(self) -> httpx.Request:
153        return self.http_response.request
154
155    @property
156    def status_code(self) -> int:
157        return self.http_response.status_code
158
159    @property
160    def url(self) -> httpx.URL:
161        return self.http_response.url
162
163    @property
164    def method(self) -> str:
165        return self.http_request.method
166
167    @property
168    def content(self) -> bytes:
169        """Return the binary response content.
170
171        NOTE: this will be removed in favour of `.read()` in the
172        next major version.
173        """
174        return self.http_response.content
175
176    @property
177    def text(self) -> str:
178        """Return the decoded response content.
179
180        NOTE: this will be turned into a method in the next major version.
181        """
182        return self.http_response.text
183
184    @property
185    def http_version(self) -> str:
186        return self.http_response.http_version
187
188    @property
189    def is_closed(self) -> bool:
190        return self.http_response.is_closed
191
192    @property
193    def elapsed(self) -> datetime.timedelta:
194        """The time taken for the complete request/response cycle to complete."""
195        return self.http_response.elapsed
196
197    def _parse(self, *, to: type[_T] | None = None) -> R | _T:
198        cast_to = to if to is not None else self._cast_to
199
200        # unwrap `TypeAlias('Name', T)` -> `T`
201        if is_type_alias_type(cast_to):
202            cast_to = cast_to.__value__  # type: ignore[unreachable]
203
204        # unwrap `Annotated[T, ...]` -> `T`
205        if cast_to and is_annotated_type(cast_to):
206            cast_to = extract_type_arg(cast_to, 0)
207
208        origin = get_origin(cast_to) or cast_to
209
210        if self._stream:
211            if to:
212                if not is_stream_class_type(to):
213                    raise TypeError(f"Expected custom parse type to be a subclass of {Stream} or {AsyncStream}")
214
215                return cast(
216                    _T,
217                    to(
218                        cast_to=extract_stream_chunk_type(
219                            to,
220                            failure_message="Expected custom stream type to be passed with a type argument, e.g. Stream[ChunkType]",
221                        ),
222                        response=self.http_response,
223                        client=cast(Any, self._client),
224                    ),
225                )
226
227            if self._stream_cls:
228                return cast(
229                    R,
230                    self._stream_cls(
231                        cast_to=extract_stream_chunk_type(self._stream_cls),
232                        response=self.http_response,
233                        client=cast(Any, self._client),
234                    ),
235                )
236
237            stream_cls = cast("type[Stream[Any]] | type[AsyncStream[Any]] | None", self._client._default_stream_cls)
238            if stream_cls is None:
239                raise MissingStreamClassError()
240
241            return cast(
242                R,
243                stream_cls(
244                    cast_to=cast_to,
245                    response=self.http_response,
246                    client=cast(Any, self._client),
247                ),
248            )
249
250        if cast_to is NoneType:
251            return cast(R, None)
252
253        response = self.http_response
254        if cast_to == str:
255            return cast(R, response.text)
256
257        if cast_to == int:
258            return cast(R, int(response.text))
259
260        if cast_to == float:
261            return cast(R, float(response.text))
262
263        if cast_to == bool:
264            return cast(R, response.text.lower() == "true")
265
266        if inspect.isclass(origin) and issubclass(origin, HttpxBinaryResponseContent):
267            return cast(R, cast_to(response))  # type: ignore
268
269        if origin == LegacyAPIResponse:
270            raise RuntimeError("Unexpected state - cast_to is `APIResponse`")
271
272        if inspect.isclass(
273            origin  # pyright: ignore[reportUnknownArgumentType]
274        ) and issubclass(origin, httpx.Response):
275            # Because of the invariance of our ResponseT TypeVar, users can subclass httpx.Response
276            # and pass that class to our request functions. We cannot change the variance to be either
277            # covariant or contravariant as that makes our usage of ResponseT illegal. We could construct
278            # the response class ourselves but that is something that should be supported directly in httpx
279            # as it would be easy to incorrectly construct the Response object due to the multitude of arguments.
280            if cast_to != httpx.Response:
281                raise ValueError(f"Subclasses of httpx.Response cannot be passed to `cast_to`")
282            return cast(R, response)
283
284        if (
285            inspect.isclass(
286                origin  # pyright: ignore[reportUnknownArgumentType]
287            )
288            and not issubclass(origin, BaseModel)
289            and issubclass(origin, pydantic.BaseModel)
290        ):
291            raise TypeError("Pydantic models must subclass our base model type, e.g. `from openai import BaseModel`")
292
293        if (
294            cast_to is not object
295            and not origin is list
296            and not origin is dict
297            and not origin is Union
298            and not issubclass(origin, BaseModel)
299        ):
300            raise RuntimeError(
301                f"Unsupported type, expected {cast_to} to be a subclass of {BaseModel}, {dict}, {list}, {Union}, {NoneType}, {str} or {httpx.Response}."
302            )
303
304        # split is required to handle cases where additional information is included
305        # in the response, e.g. application/json; charset=utf-8
306        content_type, *_ = response.headers.get("content-type", "*").split(";")
307        if not content_type.endswith("json"):
308            if is_basemodel(cast_to):
309                try:
310                    data = response.json()
311                except Exception as exc:
312                    log.debug("Could not read JSON from response data due to %s - %s", type(exc), exc)
313                else:
314                    return self._client._process_response_data(
315                        data=data,
316                        cast_to=cast_to,  # type: ignore
317                        response=response,
318                    )
319
320            if self._client._strict_response_validation:
321                raise APIResponseValidationError(
322                    response=response,
323                    message=f"Expected Content-Type response header to be `application/json` but received `{content_type}` instead.",
324                    body=response.text,
325                )
326
327            # If the API responds with content that isn't JSON then we just return
328            # the (decoded) text without performing any parsing so that you can still
329            # handle the response however you need to.
330            return response.text  # type: ignore
331
332        data = response.json()
333
334        return self._client._process_response_data(
335            data=data,
336            cast_to=cast_to,  # type: ignore
337            response=response,
338        )
339
340    @override
341    def __repr__(self) -> str:
342        return f"<APIResponse [{self.status_code} {self.http_response.reason_phrase}] type={self._cast_to}>"
343
344
class MissingStreamClassError(TypeError):
    """Raised when a streaming response must be built but no stream class is known.

    The error always carries the same fixed message and takes no arguments.
    """

    def __init__(self) -> None:
        message = "The `stream` argument was set to `True` but the `stream_cls` argument was not given. See `openai._streaming` for reference"
        super().__init__(message)
350
351
def to_raw_response_wrapper(func: Callable[P, R]) -> Callable[P, LegacyAPIResponse[R]]:
    """Wrap a bound API method so that its result is typed as `LegacyAPIResponse`.

    The returned callable forwards all arguments to `func` unchanged, except
    that `extra_headers` is replaced with a fresh dict containing the caller's
    headers (if any) plus `RAW_RESPONSE_HEADER` set to `"true"`. The mapping
    passed in by the caller is copied, never mutated.

    NOTE(review): the header is presumably what tells the client to return
    the raw response object; that handling is not part of this function.
    """

    @functools.wraps(func)
    def raw_call(*args: P.args, **kwargs: P.kwargs) -> LegacyAPIResponse[R]:
        caller_headers = cast(Any, kwargs.get("extra_headers"))
        # The raw-response marker is listed last so it overrides any value the
        # caller supplied under the same key.
        merged_headers: dict[str, str] = {**(caller_headers or {}), RAW_RESPONSE_HEADER: "true"}
        kwargs["extra_headers"] = merged_headers

        result = func(*args, **kwargs)
        return cast(LegacyAPIResponse[R], result)

    return raw_call
367
368
def async_to_raw_response_wrapper(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[LegacyAPIResponse[R]]]:
    """Wrap a bound async API method so that its awaited result is typed as `LegacyAPIResponse`.

    The returned coroutine function forwards all arguments to `func`
    unchanged, except that `extra_headers` is replaced with a fresh dict
    containing the caller's headers (if any) plus `RAW_RESPONSE_HEADER` set to
    `"true"`. The mapping passed in by the caller is copied, never mutated.

    NOTE(review): the header is presumably what tells the client to return
    the raw response object; that handling is not part of this function.
    """

    @functools.wraps(func)
    async def raw_call(*args: P.args, **kwargs: P.kwargs) -> LegacyAPIResponse[R]:
        caller_headers = cast(Any, kwargs.get("extra_headers"))
        # The raw-response marker is listed last so it overrides any value the
        # caller supplied under the same key.
        merged_headers: dict[str, str] = {**(caller_headers or {}), RAW_RESPONSE_HEADER: "true"}
        kwargs["extra_headers"] = merged_headers

        result = await func(*args, **kwargs)
        return cast(LegacyAPIResponse[R], result)

    return raw_call
384
385
class HttpxBinaryResponseContent:
    """Thin wrapper around an `httpx.Response` that exposes its content accessors.

    Every property and method forwards directly to the wrapped `response`;
    the only additions are the file-writing helpers.
    """

    response: httpx.Response

    def __init__(self, response: httpx.Response) -> None:
        """Keep a reference to the response whose content is exposed."""
        self.response = response

    @property
    def content(self) -> bytes:
        """The response body as bytes (`response.content`)."""
        return self.response.content

    @property
    def text(self) -> str:
        """The response body as text (`response.text`)."""
        return self.response.text

    @property
    def encoding(self) -> str | None:
        """The encoding reported by the wrapped response (`response.encoding`)."""
        return self.response.encoding

    @property
    def charset_encoding(self) -> str | None:
        """The charset reported by the wrapped response (`response.charset_encoding`)."""
        return self.response.charset_encoding

    def json(self, **kwargs: Any) -> Any:
        """Decode the body via `response.json`, passing `kwargs` through."""
        return self.response.json(**kwargs)

    def read(self) -> bytes:
        """Return the result of `response.read()`."""
        return self.response.read()

    def iter_bytes(self, chunk_size: int | None = None) -> Iterator[bytes]:
        """Return the iterator produced by `response.iter_bytes(chunk_size)`."""
        return self.response.iter_bytes(chunk_size)

    def iter_text(self, chunk_size: int | None = None) -> Iterator[str]:
        """Return the iterator produced by `response.iter_text(chunk_size)`."""
        return self.response.iter_text(chunk_size)

    def iter_lines(self) -> Iterator[str]:
        """Return the iterator produced by `response.iter_lines()`."""
        return self.response.iter_lines()

    def iter_raw(self, chunk_size: int | None = None) -> Iterator[bytes]:
        """Return the iterator produced by `response.iter_raw(chunk_size)`."""
        return self.response.iter_raw(chunk_size)

    def write_to_file(
        self,
        file: str | os.PathLike[str],
    ) -> None:
        """Write the response body to `file`, opened in binary write mode.

        `file` may be a filename or any path-like object, e.g. pathlib.Path

        Note: to stream the data to the file rather than writing it all at
        once, use `.with_streaming_response` when making the API request,
        e.g. `client.with_streaming_response.foo().stream_to_file('my_filename.txt')`
        """
        with open(file, mode="wb") as out:
            out.writelines(self.response.iter_bytes())

    @deprecated(
        "Due to a bug, this method doesn't actually stream the response content, `.with_streaming_response.method()` should be used instead"
    )
    def stream_to_file(
        self,
        file: str | os.PathLike[str],
        *,
        chunk_size: int | None = None,
    ) -> None:
        """Write the body to `file` in chunks of `chunk_size` (deprecated).

        Same as `write_to_file` apart from the configurable chunk size.
        """
        with open(file, mode="wb") as out:
            out.writelines(self.response.iter_bytes(chunk_size))

    def close(self) -> None:
        """Close the wrapped response."""
        return self.response.close()

    async def aread(self) -> bytes:
        """Await and return `response.aread()`."""
        return await self.response.aread()

    async def aiter_bytes(self, chunk_size: int | None = None) -> AsyncIterator[bytes]:
        """Return the async iterator produced by `response.aiter_bytes(chunk_size)`.

        This is a coroutine function, so the call itself must be awaited
        before the returned iterator can be consumed.
        """
        return self.response.aiter_bytes(chunk_size)

    async def aiter_text(self, chunk_size: int | None = None) -> AsyncIterator[str]:
        """Return the async iterator produced by `response.aiter_text(chunk_size)`.

        This is a coroutine function, so the call itself must be awaited
        before the returned iterator can be consumed.
        """
        return self.response.aiter_text(chunk_size)

    async def aiter_lines(self) -> AsyncIterator[str]:
        """Return the async iterator produced by `response.aiter_lines()`.

        This is a coroutine function, so the call itself must be awaited
        before the returned iterator can be consumed.
        """
        return self.response.aiter_lines()

    async def aiter_raw(self, chunk_size: int | None = None) -> AsyncIterator[bytes]:
        """Return the async iterator produced by `response.aiter_raw(chunk_size)`.

        This is a coroutine function, so the call itself must be awaited
        before the returned iterator can be consumed.
        """
        return self.response.aiter_raw(chunk_size)

    @deprecated(
        "Due to a bug, this method doesn't actually stream the response content, `.with_streaming_response.method()` should be used instead"
    )
    async def astream_to_file(
        self,
        file: str | os.PathLike[str],
        *,
        chunk_size: int | None = None,
    ) -> None:
        """Asynchronously write the body to `file` in chunks of `chunk_size` (deprecated)."""
        target = anyio.Path(file)
        async with await target.open(mode="wb") as out:
            async for chunk in self.response.aiter_bytes(chunk_size):
                await out.write(chunk)

    async def aclose(self) -> None:
        """Asynchronously close the wrapped response."""
        return await self.response.aclose()