main
 1from __future__ import annotations
 2
 3import inspect
 4from typing import Any, Iterable
 5from typing_extensions import TypeAlias
 6
 7import pytest
 8import pydantic
 9
10from ..utils import rich_print_str
11
# Return type of pydantic's `BaseModel.__repr_args__` as used by `print_obj`:
# an iterable of `(field name or None, value)` pairs. The annotation is kept
# as a string, so the `str | None` union is never evaluated at runtime.
ReprArgs: TypeAlias = "Iterable[tuple[str | None, Any]]"
13
14
def print_obj(obj: object, monkeypatch: pytest.MonkeyPatch) -> str:
    """Pretty print an object to a string.

    The output is made deterministic so it can be used for snapshot tests:

    * pydantic model fields are printed sorted by field name, and
    * every `<caller>.<locals>.` qualifier (where `<caller>` is the function
      that called `print_obj`) is stripped, so pydantic v1 and v2 produce the
      same text for generic models.

    Args:
        obj: The object to render.
        monkeypatch: Used to temporarily patch
            `pydantic.BaseModel.__repr_args__`. The patch is undone before
            this function returns.

    Returns:
        The string produced by `rich_print_str(obj)`, with the caller's
        `.<locals>.` qualifiers removed.
    """
    original_repr = pydantic.BaseModel.__repr_args__

    def __repr_args__(self: pydantic.BaseModel) -> ReprArgs:
        # Sort so that model fields are always printed in the same order.
        # The key is always a (bool, str) tuple, so it stays comparable.
        # Unnamed (`None`-keyed) entries sort first and keep their original
        # relative order, because `sorted` is stable.
        # Named entries follow, ordered alphabetically.
        # The previous key (`arg[0] or arg`) mixed `str` and `tuple` keys and
        # raised TypeError whenever both kinds of entry were present.
        return sorted(
            original_repr(self),
            key=lambda arg: (arg[0] is not None, arg[0] or ""),
        )

    with monkeypatch.context() as m:
        m.setattr(pydantic.BaseModel, "__repr_args__", __repr_args__)
        string = rich_print_str(obj)

    # We remove all `fn_name.<locals>.` occurrences so that the same snapshots
    # can be shared between pydantic v1 and v2, whose output for generic
    # models differs, e.g.
    #
    # v2: `ParsedChatCompletion[test_parse_pydantic_model.<locals>.Location]`
    # v1: `ParsedChatCompletion[Location]`
    #
    # stacklevel=2 targets the function that called `print_obj`.
    return clear_locals(string, stacklevel=2)
39
40
def get_caller_name(*, stacklevel: int = 1) -> str:
    """Return the function name of a frame further up the call stack.

    Args:
        stacklevel: How many frames to walk up from this function.
            ``0`` returns ``"get_caller_name"`` itself, and ``1`` (the default)
            returns the name of the function that called it, and so on.

    Returns:
        The ``co_name`` of the selected frame's code object.

    Raises:
        AssertionError: If frame introspection is unavailable, or if the stack
            is shallower than ``stacklevel``.
    """
    frame = inspect.currentframe()
    try:
        assert frame is not None, "inspect.currentframe() is not supported by this interpreter"

        for depth in range(1, stacklevel + 1):
            frame = frame.f_back
            assert frame is not None, f"no frame {depth} level(s) above get_caller_name"

        return frame.f_code.co_name
    finally:
        # Drop the frame reference explicitly. The inspect docs recommend this
        # so that frame destruction is deterministic rather than left to the
        # cycle collector.
        del frame
50
51
def clear_locals(string: str, *, stacklevel: int) -> str:
    """Strip ``<caller>.<locals>.`` qualifiers from ``string``.

    ``<caller>`` is the name of the function ``stacklevel`` frames above
    ``clear_locals``. For example, ``stacklevel=1`` is the direct caller.
    """
    # get_caller_name counts from its own frame, so add one level to step
    # over this function's frame as well.
    prefix = get_caller_name(stacklevel=stacklevel + 1) + ".<locals>."
    return string.replace(prefix, "")