from __future__ import annotations

import sys
from typing import TYPE_CHECKING, List, Optional, cast
from argparse import ArgumentParser
from typing_extensions import Literal, NamedTuple

from ..._utils import get_client
from ..._models import BaseModel
from ...._streaming import Stream
from ....types.chat import (
    ChatCompletionRole,
    ChatCompletionChunk,
    CompletionCreateParams,
)
from ....types.chat.completion_create_params import (
    CompletionCreateParamsStreaming,
    CompletionCreateParamsNonStreaming,
)

if TYPE_CHECKING:
    from argparse import _SubParsersAction


def register(subparser: _SubParsersAction[ArgumentParser]) -> None:
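    """Register the `chat.completions.create` subcommand and its arguments.

    Example invocation (assuming this is wired up under the standard
    `openai api` namespace; the exact entry point depends on the CLI setup):

        openai api chat.completions.create -m gpt-4 -g user "Say hello"
    """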
    sub = subparser.add_parser("chat.completions.create")

    # drop argparse's default "optional arguments" group so that only the
    # custom argument groups below are shown in --help
    sub._action_groups.pop()
    req = sub.add_argument_group("required arguments")
    opt = sub.add_argument_group("optional arguments")

    req.add_argument(
        "-g",
        "--message",
        action="append",
        nargs=2,
        metavar=("ROLE", "CONTENT"),
        help="A message in `{role} {content}` format. Use this argument multiple times to add multiple messages.",
        required=True,
    )
    req.add_argument(
        "-m",
        "--model",
        help="The model to use.",
        required=True,
    )

    opt.add_argument(
        "-n",
        "--n",
        help="How many completions to generate for the conversation.",
        type=int,
    )
    opt.add_argument("-M", "--max-tokens", help="The maximum number of tokens to generate.", type=int)
    opt.add_argument(
        "-t",
        "--temperature",
        help="""What sampling temperature to use. Higher values mean the model will take more risks. Try 0.9 for more creative applications, and 0 (argmax sampling) for ones with a well-defined answer.

Mutually exclusive with `top_p`.""",
        type=float,
    )
    opt.add_argument(
        "-P",
        "--top_p",
        help="""An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10%% probability mass are considered.

Mutually exclusive with `temperature`.""",
        type=float,
    )
    opt.add_argument(
        "--stop",
        help="A stop sequence at which to stop generating tokens for the message.",
    )
    opt.add_argument("--stream", help="Stream messages as they're ready.", action="store_true")
    sub.set_defaults(func=CLIChatCompletion.create, args_model=CLIChatCompletionCreateArgs)


class CLIMessage(NamedTuple):
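    """A single `(role, content)` pair parsed from a repeated `--message` flag."""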
    role: ChatCompletionRole
    content: str


class CLIChatCompletionCreateArgs(BaseModel):
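    """Parsed CLI arguments; field names mirror the flags registered above."""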
    message: List[CLIMessage]
    model: str
    n: Optional[int] = None
    max_tokens: Optional[int] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    stop: Optional[str] = None
    stream: bool = False


class CLIChatCompletion:
    @staticmethod
    def create(args: CLIChatCompletionCreateArgs) -> None:
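        """Build the request payload from the parsed arguments and dispatch to
        the streaming or non-streaming implementation."""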
        params: CompletionCreateParams = {
            "model": args.model,
            "messages": [
101                {"role": cast(Literal["user"], message.role), "content": message.content} for message in args.message
102            ],
103            # type checkers are not good at inferring union types so we have to set stream afterwards
104            "stream": False,
105        }
106        if args.temperature is not None:
107            params["temperature"] = args.temperature
108        if args.stop is not None:
109            params["stop"] = args.stop
110        if args.top_p is not None:
111            params["top_p"] = args.top_p
112        if args.n is not None:
113            params["n"] = args.n
114        if args.stream:
115            params["stream"] = args.stream  # type: ignore
116        if args.max_tokens is not None:
117            params["max_tokens"] = args.max_tokens
118
119        if args.stream:
120            return CLIChatCompletion._stream_create(cast(CompletionCreateParamsStreaming, params))
121
122        return CLIChatCompletion._create(cast(CompletionCreateParamsNonStreaming, params))
123
124    @staticmethod
125    def _create(params: CompletionCreateParamsNonStreaming) -> None:
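        """Print each returned choice, prefixing a numbered header when more
        than one completion was requested."""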
        completion = get_client().chat.completions.create(**params)
        should_print_header = len(completion.choices) > 1
        for choice in completion.choices:
            if should_print_header:
                sys.stdout.write("===== Chat Completion {} =====\n".format(choice.index))

            content = choice.message.content if choice.message.content is not None else "None"
            sys.stdout.write(content)

            if should_print_header or not content.endswith("\n"):
                sys.stdout.write("\n")

            sys.stdout.flush()

    @staticmethod
    def _stream_create(params: CompletionCreateParamsStreaming) -> None:
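        """Print content deltas as they arrive, flushing stdout after each
        chunk so output appears incrementally."""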
        # cast is required for mypy
        stream = cast(  # pyright: ignore[reportUnnecessaryCast]
            Stream[ChatCompletionChunk], get_client().chat.completions.create(**params)
        )
        for chunk in stream:
            should_print_header = len(chunk.choices) > 1
            for choice in chunk.choices:
                if should_print_header:
                    sys.stdout.write("===== Chat Completion {} =====\n".format(choice.index))

                content = choice.delta.content or ""
                sys.stdout.write(content)

                if should_print_header:
                    sys.stdout.write("\n")

                sys.stdout.flush()

        sys.stdout.write("\n")