main
  1from __future__ import annotations
  2
  3import sys
  4import logging
  5import argparse
  6from typing import Any, List, Type, Optional
  7from typing_extensions import ClassVar
  8
  9import httpx
 10import pydantic
 11
 12import openai
 13
 14from . import _tools
 15from .. import _ApiType, __version__
 16from ._api import register_commands
 17from ._utils import can_use_http2
 18from ._errors import CLIError, display_error
 19from .._compat import PYDANTIC_V1, ConfigDict, model_parse
 20from .._models import BaseModel
 21from .._exceptions import APIError
 22
 23logger = logging.getLogger()
 24formatter = logging.Formatter("[%(asctime)s] %(message)s")
 25handler = logging.StreamHandler(sys.stderr)
 26handler.setFormatter(formatter)
 27logger.addHandler(handler)
 28
 29
class Arguments(BaseModel):
    """Global CLI options, validated from the parsed argparse namespace.

    ``_parse_args`` builds this from ``vars(parsed)``, which also contains
    subcommand-specific entries such as ``func``. Extra keys are therefore
    ignored rather than rejected.
    """

    # Ignore unknown fields under both pydantic major versions.
    if PYDANTIC_V1:

        class Config(pydantic.BaseConfig):  # type: ignore
            extra: Any = pydantic.Extra.ignore  # type: ignore
    else:
        model_config: ClassVar[ConfigDict] = ConfigDict(
            extra="ignore",
        )

    # number of -v/--verbose flags given (argparse `count` action, default 0)
    verbosity: int
    # `-V/--version` uses argparse's `version` action, which prints and exits;
    # presumably this therefore stays None in practice.
    version: Optional[str] = None

    # Declared without defaults; argparse sets each of these on the namespace
    # (None when the corresponding flag is omitted).
    api_key: Optional[str]
    api_base: Optional[str]
    organization: Optional[str]
    # raw proxy URLs from `-p/--proxy` (nargs="+")
    proxy: Optional[List[str]]
    api_type: Optional[_ApiType] = None
    api_version: Optional[str] = None

    # azure
    azure_endpoint: Optional[str] = None
    azure_ad_token: Optional[str] = None

    # internal, set by subparsers to parse their specific args
    args_model: Optional[Type[BaseModel]] = None

    # internal, used so that subparsers can forward unknown arguments
    unknown_args: List[str] = []
    # when False, `_parse_args` re-parses strictly so unknown args are an error
    allow_unknown_args: bool = False
 60
 61
 62def _build_parser() -> argparse.ArgumentParser:
 63    parser = argparse.ArgumentParser(description=None, prog="openai")
 64    parser.add_argument(
 65        "-v",
 66        "--verbose",
 67        action="count",
 68        dest="verbosity",
 69        default=0,
 70        help="Set verbosity.",
 71    )
 72    parser.add_argument("-b", "--api-base", help="What API base url to use.")
 73    parser.add_argument("-k", "--api-key", help="What API key to use.")
 74    parser.add_argument("-p", "--proxy", nargs="+", help="What proxy to use.")
 75    parser.add_argument(
 76        "-o",
 77        "--organization",
 78        help="Which organization to run as (will use your default organization if not specified)",
 79    )
 80    parser.add_argument(
 81        "-t",
 82        "--api-type",
 83        type=str,
 84        choices=("openai", "azure"),
 85        help="The backend API to call, must be `openai` or `azure`",
 86    )
 87    parser.add_argument(
 88        "--api-version",
 89        help="The Azure API version, e.g. 'https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#rest-api-versioning'",
 90    )
 91
 92    # azure
 93    parser.add_argument(
 94        "--azure-endpoint",
 95        help="The Azure endpoint, e.g. 'https://endpoint.openai.azure.com'",
 96    )
 97    parser.add_argument(
 98        "--azure-ad-token",
 99        help="A token from Azure Active Directory, https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id",
100    )
101
102    # prints the package version
103    parser.add_argument(
104        "-V",
105        "--version",
106        action="version",
107        version="%(prog)s " + __version__,
108    )
109
110    def help() -> None:
111        parser.print_help()
112
113    parser.set_defaults(func=help)
114
115    subparsers = parser.add_subparsers()
116    sub_api = subparsers.add_parser("api", help="Direct API calls")
117
118    register_commands(sub_api)
119
120    sub_tools = subparsers.add_parser("tools", help="Client side tools for convenience")
121    _tools.register_commands(sub_tools, subparsers)
122
123    return parser
124
125
def main() -> int:
    """Run the CLI and return the process exit code.

    ``APIError``, ``CLIError`` and ``pydantic.ValidationError`` are rendered
    with ``display_error`` and yield exit code 1. Ctrl-C
    (``KeyboardInterrupt``) writes a newline to stderr and also yields 1.
    Any other exception propagates. A clean run returns 0.
    """
    exit_code = 0
    try:
        _main()
    except KeyboardInterrupt:
        # finish the current terminal line before exiting
        sys.stderr.write("\n")
        exit_code = 1
    except (APIError, CLIError, pydantic.ValidationError) as err:
        display_error(err)
        exit_code = 1
    return exit_code
136
137
def _parse_args(
    parser: argparse.ArgumentParser,
    argv: Optional[List[str]] = None,
) -> tuple[argparse.Namespace, Arguments, list[str]]:
    """Parse the command line into a namespace, a validated model and leftovers.

    Args:
        parser: The top-level parser (see ``_build_parser``).
        argv: Arguments to parse, excluding the program name. Defaults to
            ``sys.argv[1:]`` when omitted, matching the previous behavior.

    Returns:
        ``(parsed, args, remaining_unknown)``, where:

        - ``parsed`` is the raw argparse namespace.
        - ``args`` is that namespace validated as ``Arguments``.
        - ``remaining_unknown`` holds arguments argparse did not recognize,
          followed by a ``--`` separator and everything after it, if present.

    If the selected command does not allow unknown arguments, the full argv
    is re-parsed strictly. argparse then reports any unrecognized argument
    as an error and exits.
    """
    if argv is None:
        argv = sys.argv[1:]

    # argparse by default will strip out the `--` but we want to keep it for unknown arguments
    if "--" in argv:
        idx = argv.index("--")
        known_args = argv[:idx]
        unknown_args = argv[idx:]
    else:
        known_args = list(argv)
        unknown_args = []

    parsed, remaining_unknown = parser.parse_known_args(known_args)

    # append any remaining unknown arguments from the initial parsing
    remaining_unknown.extend(unknown_args)

    args = model_parse(Arguments, vars(parsed))
    if not args.allow_unknown_args:
        # we have to parse twice to ensure any unknown arguments
        # result in an error if that behaviour is desired
        parser.parse_args(argv)

    return parsed, args, remaining_unknown
160
161
def _build_proxy_mounts(proxy_urls: Optional[List[str]]) -> dict[str, httpx.BaseTransport]:
    """Map ``--proxy`` URLs to httpx transports keyed by request-scheme prefix.

    A proxy URL starting with ``https`` is mounted for ``https://`` requests.
    Any other proxy URL is mounted for ``http://`` requests.

    Raises:
        CLIError: if two proxies map to the same scheme prefix.
    """
    mounts: dict[str, httpx.BaseTransport] = {}
    for proxy in proxy_urls or []:
        key = "https://" if proxy.startswith("https") else "http://"
        if key in mounts:
            raise CLIError(f"Multiple {key} proxies given - only the last one would be used")

        mounts[key] = httpx.HTTPTransport(proxy=httpx.Proxy(httpx.URL(proxy)))
    return mounts


def _main() -> None:
    """Parse the command line, configure the module-level ``openai`` client, and dispatch.

    Builds an ``httpx.Client`` (with any ``--proxy`` mounts) and installs it
    as ``openai.http_client``. It then copies the global CLI options onto the
    ``openai`` module and calls the selected command's ``func``.

    If the command declares an ``args_model``, ``func`` receives that model,
    validated from the namespace. Non-None values are passed, plus
    ``unknown_args``. Otherwise ``func`` is called with no arguments.

    The HTTP client is closed on every exit path once it has been created.
    """
    parser = _build_parser()
    parsed, args, unknown = _parse_args(parser)

    if args.verbosity != 0:
        # the flag is spelled -v/--verbose; `verbosity` is only its dest name
        sys.stderr.write("Warning: --verbose isn't supported yet\n")

    http_client = httpx.Client(
        mounts=_build_proxy_mounts(args.proxy) or None,
        http2=can_use_http2(),
    )
    try:
        openai.http_client = http_client

        if args.organization:
            openai.organization = args.organization

        if args.api_key:
            openai.api_key = args.api_key

        if args.api_base:
            openai.base_url = args.api_base

        # azure
        if args.api_type is not None:
            openai.api_type = args.api_type

        if args.azure_endpoint is not None:
            openai.azure_endpoint = args.azure_endpoint

        if args.api_version is not None:
            openai.api_version = args.api_version

        if args.azure_ad_token is not None:
            openai.azure_ad_token = args.azure_ad_token

        if args.args_model:
            # we omit None values so that they can be defaulted to `NotGiven`
            # and we'll strip it from the API request
            payload = {key: value for key, value in vars(parsed).items() if value is not None}
            payload["unknown_args"] = unknown
            parsed.func(model_parse(args.args_model, payload))
        else:
            parsed.func()
    finally:
        try:
            http_client.close()
        except Exception:
            # Closing is best-effort and must not mask the command's own
            # outcome, but leave a trace for debugging instead of discarding it.
            logger.debug("failed to close HTTP client", exc_info=True)
230
231
232if __name__ == "__main__":
233    sys.exit(main())