main
  1from __future__ import annotations
  2
  3from typing import Any, List, Tuple, Union, Mapping, TypeVar
  4from urllib.parse import parse_qs, urlencode
  5from typing_extensions import Literal, get_args
  6
  7from ._types import NotGiven, not_given
  8from ._utils import flatten
  9
 10_T = TypeVar("_T")
 11
 12
 13ArrayFormat = Literal["comma", "repeat", "indices", "brackets"]
 14NestedFormat = Literal["dots", "brackets"]
 15
 16PrimitiveData = Union[str, int, float, bool, None]
 17# this should be Data = Union[PrimitiveData, "List[Data]", "Tuple[Data]", "Mapping[str, Data]"]
 18# https://github.com/microsoft/pyright/issues/3555
 19Data = Union[PrimitiveData, List[Any], Tuple[Any], "Mapping[str, Any]"]
 20Params = Mapping[str, Data]
 21
 22
 23class Querystring:
 24    array_format: ArrayFormat
 25    nested_format: NestedFormat
 26
 27    def __init__(
 28        self,
 29        *,
 30        array_format: ArrayFormat = "repeat",
 31        nested_format: NestedFormat = "brackets",
 32    ) -> None:
 33        self.array_format = array_format
 34        self.nested_format = nested_format
 35
 36    def parse(self, query: str) -> Mapping[str, object]:
 37        # Note: custom format syntax is not supported yet
 38        return parse_qs(query)
 39
 40    def stringify(
 41        self,
 42        params: Params,
 43        *,
 44        array_format: ArrayFormat | NotGiven = not_given,
 45        nested_format: NestedFormat | NotGiven = not_given,
 46    ) -> str:
 47        return urlencode(
 48            self.stringify_items(
 49                params,
 50                array_format=array_format,
 51                nested_format=nested_format,
 52            )
 53        )
 54
 55    def stringify_items(
 56        self,
 57        params: Params,
 58        *,
 59        array_format: ArrayFormat | NotGiven = not_given,
 60        nested_format: NestedFormat | NotGiven = not_given,
 61    ) -> list[tuple[str, str]]:
 62        opts = Options(
 63            qs=self,
 64            array_format=array_format,
 65            nested_format=nested_format,
 66        )
 67        return flatten([self._stringify_item(key, value, opts) for key, value in params.items()])
 68
 69    def _stringify_item(
 70        self,
 71        key: str,
 72        value: Data,
 73        opts: Options,
 74    ) -> list[tuple[str, str]]:
 75        if isinstance(value, Mapping):
 76            items: list[tuple[str, str]] = []
 77            nested_format = opts.nested_format
 78            for subkey, subvalue in value.items():
 79                items.extend(
 80                    self._stringify_item(
 81                        # TODO: error if unknown format
 82                        f"{key}.{subkey}" if nested_format == "dots" else f"{key}[{subkey}]",
 83                        subvalue,
 84                        opts,
 85                    )
 86                )
 87            return items
 88
 89        if isinstance(value, (list, tuple)):
 90            array_format = opts.array_format
 91            if array_format == "comma":
 92                return [
 93                    (
 94                        key,
 95                        ",".join(self._primitive_value_to_str(item) for item in value if item is not None),
 96                    ),
 97                ]
 98            elif array_format == "repeat":
 99                items = []
100                for item in value:
101                    items.extend(self._stringify_item(key, item, opts))
102                return items
103            elif array_format == "indices":
104                raise NotImplementedError("The array indices format is not supported yet")
105            elif array_format == "brackets":
106                items = []
107                key = key + "[]"
108                for item in value:
109                    items.extend(self._stringify_item(key, item, opts))
110                return items
111            else:
112                raise NotImplementedError(
113                    f"Unknown array_format value: {array_format}, choose from {', '.join(get_args(ArrayFormat))}"
114                )
115
116        serialised = self._primitive_value_to_str(value)
117        if not serialised:
118            return []
119        return [(key, serialised)]
120
121    def _primitive_value_to_str(self, value: PrimitiveData) -> str:
122        # copied from httpx
123        if value is True:
124            return "true"
125        elif value is False:
126            return "false"
127        elif value is None:
128            return ""
129        return str(value)
130
131
# Shared default serialiser. Its bound methods are re-exported as module-level
# functions, so callers can use `parse(...)` / `stringify(...)` directly.
_qs = Querystring()
parse, stringify, stringify_items = _qs.parse, _qs.stringify, _qs.stringify_items
136
137
class Options:
    """Per-call serialisation settings, resolved against a ``Querystring``'s defaults."""

    # Effective list/tuple encoding for this call.
    array_format: ArrayFormat
    # Effective mapping encoding for this call.
    nested_format: NestedFormat

    def __init__(
        self,
        qs: Querystring = _qs,
        *,
        array_format: ArrayFormat | NotGiven = not_given,
        nested_format: NestedFormat | NotGiven = not_given,
    ) -> None:
        """Resolve each format, falling back to ``qs`` when given a ``NotGiven`` sentinel.

        Args:
            qs: Serialiser whose configured formats act as fallbacks. Defaults to
                the module-level default instance.
            array_format: Override for ``qs.array_format``.
            nested_format: Override for ``qs.nested_format``.
        """
        if isinstance(array_format, NotGiven):
            self.array_format = qs.array_format
        else:
            self.array_format = array_format

        if isinstance(nested_format, NotGiven):
            self.nested_format = qs.nested_format
        else:
            self.nested_format = nested_format