main
1from __future__ import annotations
2
3import sys
4import logging
5import argparse
6from typing import Any, List, Type, Optional
7from typing_extensions import ClassVar
8
9import httpx
10import pydantic
11
12import openai
13
14from . import _tools
15from .. import _ApiType, __version__
16from ._api import register_commands
17from ._utils import can_use_http2
18from ._errors import CLIError, display_error
19from .._compat import PYDANTIC_V1, ConfigDict, model_parse
20from .._models import BaseModel
21from .._exceptions import APIError
22
# Module import installs a stderr handler with a timestamped format on the
# *root* logger (getLogger() with no name), so every logger in the process
# that propagates to root will emit through it.
formatter = logging.Formatter("[%(asctime)s] %(message)s")
handler = logging.StreamHandler(sys.stderr)
handler.setFormatter(formatter)
logger = logging.getLogger()
logger.addHandler(handler)
28
29
class Arguments(BaseModel):
    """Global options of the top-level ``openai`` command line.

    Instances are built by ``_parse_args`` with
    ``model_parse(Arguments, vars(parsed))``. The argparse namespace also
    carries sub-command options that are not declared here; those keys are
    ignored instead of rejected, thanks to the ``extra="ignore"`` config below.
    """

    if PYDANTIC_V1:

        class Config(pydantic.BaseConfig):  # type: ignore
            # Pydantic v1 form of "ignore undeclared fields".
            extra: Any = pydantic.Extra.ignore  # type: ignore
    else:
        # Pydantic v2 form of the same setting.
        model_config: ClassVar[ConfigDict] = ConfigDict(
            extra="ignore",
        )

    # Number of times -v/--verbose was given (argparse "count" action, default 0).
    verbosity: int
    version: Optional[str] = None

    # Values of -k/--api-key, -b/--api-base and -o/--organization.
    # NOTE(review): these fields have no default, so they rely on always being
    # present in the argparse namespace — confirm before constructing directly.
    api_key: Optional[str]
    api_base: Optional[str]
    organization: Optional[str]
    # One or more proxy URLs from -p/--proxy (nargs="+").
    proxy: Optional[List[str]]
    # Backend selected with -t/--api-type; the parser restricts it to "openai" or "azure".
    api_type: Optional[_ApiType] = None
    api_version: Optional[str] = None

    # azure
    azure_endpoint: Optional[str] = None
    azure_ad_token: Optional[str] = None

    # internal, set by subparsers to parse their specific args
    args_model: Optional[Type[BaseModel]] = None

    # internal, used so that subparsers can forward unknown arguments
    unknown_args: List[str] = []
    allow_unknown_args: bool = False
60
61
def _build_parser() -> argparse.ArgumentParser:
    """Create the top-level ``openai`` argument parser.

    The parser carries the global connection/authentication flags, the Azure
    flags, ``-V/--version`` and two sub-command groups: ``api`` (direct API
    calls) and ``tools`` (client-side helpers). When no sub-command is chosen,
    the default ``func`` prints the parser's help text.
    """
    root = argparse.ArgumentParser(description=None, prog="openai")

    # (flags, add_argument keyword options) in registration order; the order
    # determines how the options are listed in --help.
    global_options = (
        (
            ("-v", "--verbose"),
            {"action": "count", "dest": "verbosity", "default": 0, "help": "Set verbosity."},
        ),
        (("-b", "--api-base"), {"help": "What API base url to use."}),
        (("-k", "--api-key"), {"help": "What API key to use."}),
        (("-p", "--proxy"), {"nargs": "+", "help": "What proxy to use."}),
        (
            ("-o", "--organization"),
            {"help": "Which organization to run as (will use your default organization if not specified)"},
        ),
        (
            ("-t", "--api-type"),
            {
                "type": str,
                "choices": ("openai", "azure"),
                "help": "The backend API to call, must be `openai` or `azure`",
            },
        ),
        (
            ("--api-version",),
            {
                "help": "The Azure API version, e.g. 'https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#rest-api-versioning'"
            },
        ),
        # azure
        (
            ("--azure-endpoint",),
            {"help": "The Azure endpoint, e.g. 'https://endpoint.openai.azure.com'"},
        ),
        (
            ("--azure-ad-token",),
            {
                "help": "A token from Azure Active Directory, https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id"
            },
        ),
        # prints the package version
        (("-V", "--version"), {"action": "version", "version": "%(prog)s " + __version__}),
    )
    for flags, options in global_options:
        root.add_argument(*flags, **options)

    # Without a sub-command, `_main` calls `parsed.func()` -> print the help text.
    root.set_defaults(func=root.print_help)

    commands = root.add_subparsers()
    register_commands(commands.add_parser("api", help="Direct API calls"))
    _tools.register_commands(
        commands.add_parser("tools", help="Client side tools for convenience"),
        commands,
    )

    return root
124
125
def main() -> int:
    """Run the CLI and convert its outcome into a process exit status.

    Returns:
        ``1`` if the command failed with an ``APIError``, ``CLIError`` or
        ``pydantic.ValidationError`` (reported through ``display_error``), or
        was interrupted by ``KeyboardInterrupt`` (a newline is written to
        stderr); ``0`` otherwise.
    """
    exit_status = 0
    try:
        _main()
    except (APIError, CLIError, pydantic.ValidationError) as err:
        display_error(err)
        exit_status = 1
    except KeyboardInterrupt:
        # Finish the current terminal line so the shell prompt starts cleanly.
        sys.stderr.write("\n")
        exit_status = 1
    return exit_status
136
137
def _parse_args(
    parser: argparse.ArgumentParser,
    argv: Optional[List[str]] = None,
) -> tuple[argparse.Namespace, Arguments, list[str]]:
    """Parse the command line into the raw namespace, global options and extras.

    Args:
        parser: The parser built by ``_build_parser``.
        argv: Arguments to parse, without the program name. Defaults to
            ``sys.argv[1:]``. Passing it explicitly avoids reading the global
            ``sys.argv``.

    Returns:
        ``(parsed, args, unknown)``:
        - ``parsed``: the argparse namespace.
        - ``args``: the validated global ``Arguments``.
        - ``unknown``: the arguments argparse did not recognise, followed
          verbatim by everything from a literal ``--`` onward (the ``--``
          included).

    Raises:
        SystemExit: raised by argparse's standard "unrecognized arguments"
            handling when unknown arguments remain and the selected
            sub-command does not set ``allow_unknown_args``.
    """
    if argv is None:
        argv = sys.argv[1:]

    # argparse by default will strip out the `--` but we want to keep it for unknown arguments
    if "--" in argv:
        idx = argv.index("--")
        known_args = argv[:idx]
        unknown_args = argv[idx:]
    else:
        known_args = list(argv)
        unknown_args = []

    parsed, remaining_unknown = parser.parse_known_args(known_args)

    # append any remaining unknown arguments from the initial parsing
    remaining_unknown.extend(unknown_args)

    args = model_parse(Arguments, vars(parsed))
    if remaining_unknown and not args.allow_unknown_args:
        # Parse the full argument list a second time so that argparse raises its
        # standard error for the unknown arguments. When nothing is unknown this
        # second parse cannot fail: argv then equals known_args, which parsed
        # cleanly. So it is skipped, avoiding a second run of argparse actions
        # and type converters.
        parser.parse_args(argv)

    return parsed, args, remaining_unknown
160
161
def _build_proxy_mounts(proxy_urls: Optional[List[str]]) -> dict[str, httpx.BaseTransport]:
    """Turn ``-p/--proxy`` URLs into httpx transports keyed by mount prefix.

    URLs that start with ``https`` are mounted on ``"https://"``; every other URL
    is mounted on ``"http://"``. All URLs are checked for clashing mount keys
    before any transport is built. A rejected command line therefore never
    leaves already-built transports unclosed.

    Raises:
        CLIError: if two proxies map to the same mount key.
    """
    if not proxy_urls:
        return {}

    url_for_key: dict[str, str] = {}
    for proxy in proxy_urls:
        key = "https://" if proxy.startswith("https") else "http://"
        if key in url_for_key:
            raise CLIError(f"Multiple {key} proxies given - only the last one would be used")
        url_for_key[key] = proxy

    return {key: httpx.HTTPTransport(proxy=httpx.Proxy(httpx.URL(url))) for key, url in url_for_key.items()}


def _main() -> None:
    """Parse the command line, configure the ``openai`` module and run the command.

    Steps:
    - Build an ``httpx.Client``, optionally routed through the ``-p/--proxy``
      transports, and install it as ``openai.http_client``.
    - Copy each global option onto the matching ``openai`` module attribute,
      but only if that option was given.
    - Invoke the selected sub-command's ``func``.

    The client is closed afterwards on a best-effort basis, even if the
    command fails.

    Raises:
        CLIError: for conflicting proxy options, plus whatever the sub-command raises.
    """
    parser = _build_parser()
    parsed, args, unknown = _parse_args(parser)

    if args.verbosity != 0:
        # The flag is -v/--verbose (dest="verbosity"); there is no --verbosity option.
        sys.stderr.write("Warning: --verbose isn't supported yet\n")

    http_client = httpx.Client(
        mounts=_build_proxy_mounts(args.proxy) or None,
        http2=can_use_http2(),
    )
    # Everything after the client exists runs under the finally block, so the
    # client is always closed.
    try:
        openai.http_client = http_client

        # Empty strings are treated like "not given" for these three options.
        if args.organization:
            openai.organization = args.organization

        if args.api_key:
            openai.api_key = args.api_key

        if args.api_base:
            openai.base_url = args.api_base

        # azure
        if args.api_type is not None:
            openai.api_type = args.api_type

        if args.azure_endpoint is not None:
            openai.azure_endpoint = args.azure_endpoint

        if args.api_version is not None:
            openai.api_version = args.api_version

        if args.azure_ad_token is not None:
            openai.azure_ad_token = args.azure_ad_token

        if args.args_model:
            # we omit None values so that they can be defaulted to `NotGiven`
            # and we'll strip it from the API request
            command_options = {key: value for key, value in vars(parsed).items() if value is not None}
            command_options["unknown_args"] = unknown
            parsed.func(model_parse(args.args_model, command_options))
        else:
            parsed.func()
    finally:
        try:
            http_client.close()
        except Exception:
            # Deliberate best-effort cleanup: a failure while closing must not
            # mask the command's own result or exception.
            pass
230
231
if __name__ == "__main__":
    # Equivalent to sys.exit(main()): the exit status returned by main() is
    # propagated to the process.
    raise SystemExit(main())