main
1from __future__ import annotations
2
3import inspect
4from typing import Any, Iterable
5from typing_extensions import TypeAlias
6
7import pytest
8import pydantic
9
10from ..utils import rich_print_str
11
12ReprArgs: TypeAlias = "Iterable[tuple[str | None, Any]]"
13
14
def print_obj(obj: object, monkeypatch: pytest.MonkeyPatch) -> str:
    """Pretty print an object to a string, normalised for snapshot tests.

    Pydantic models are rendered with their repr arguments in a deterministic
    order, and ``<calling function>.<locals>.`` qualifiers are stripped from
    the output.

    Args:
        obj: The object to render via ``rich_print_str``.
        monkeypatch: pytest ``MonkeyPatch`` used to temporarily replace
            ``pydantic.BaseModel.__repr_args__``. The replacement is scoped
            to a ``monkeypatch.context()`` block and undone before returning.

    Returns:
        The rendered string.
    """

    # monkeypatch pydantic model printing so that model fields
    # are always printed in the same order so we can reliably
    # use this for snapshot tests
    original_repr = pydantic.BaseModel.__repr_args__

    def _sort_key(arg: tuple[str | None, Any]) -> tuple[bool, str]:
        # Unnamed (``None``-keyed) args sort first and keep their original
        # relative order (``sorted`` is stable). Named args sort
        # alphabetically. Falling back to the whole ``(None, value)`` tuple
        # as the key would make ``sorted`` compare a tuple against a str, or
        # compare arbitrary values, and raise ``TypeError``.
        name = arg[0]
        return (name is not None, name or "")

    def __repr_args__(self: pydantic.BaseModel) -> ReprArgs:
        return sorted(original_repr(self), key=_sort_key)

    with monkeypatch.context() as m:
        m.setattr(pydantic.BaseModel, "__repr_args__", __repr_args__)

        # Render while the patch is active; it is reverted when the
        # context manager exits.
        string = rich_print_str(obj)

    # we remove all `fn_name.<locals>.` occurrences
    # so that we can share the same snapshots between
    # pydantic v1 and pydantic v2 as their output for
    # generic models differs, e.g.
    #
    # v2: `ParsedChatCompletion[test_parse_pydantic_model.<locals>.Location]`
    # v1: `ParsedChatCompletion[Location]`
    #
    # stacklevel=2 makes clear_locals target the caller of print_obj.
    return clear_locals(string, stacklevel=2)
39
40
def get_caller_name(*, stacklevel: int = 1) -> str:
    """Return the code-object name of a frame further up the call stack.

    Starting from this function's own frame, walk ``stacklevel`` frames
    outward. ``stacklevel=1`` names the direct caller of ``get_caller_name``,
    and ``stacklevel=0`` names ``get_caller_name`` itself.

    Raises:
        AssertionError: if frame introspection is unavailable, or if the
            stack is shallower than ``stacklevel``.
    """
    current = inspect.currentframe()
    assert current is not None

    depth = 0
    while depth < stacklevel:
        parent = current.f_back
        assert parent is not None, f"no {depth}th frame"
        current = parent
        depth += 1

    return current.f_code.co_name
50
51
def clear_locals(string: str, *, stacklevel: int) -> str:
    """Remove ``<caller>.<locals>.`` qualifiers from *string*.

    ``<caller>`` is the name of the function ``stacklevel`` frames above
    ``clear_locals``. The extra ``+ 1`` passed to ``get_caller_name`` accounts
    for this function's own frame.
    """
    needle = f"{get_caller_name(stacklevel=stacklevel + 1)}.<locals>."
    return string.replace(needle, "")