main
1from __future__ import annotations
2
3import sys
4from typing import TYPE_CHECKING, List, Optional, cast
5from argparse import ArgumentParser
6from typing_extensions import Literal, NamedTuple
7
8from ..._utils import get_client
9from ..._models import BaseModel
10from ...._streaming import Stream
11from ....types.chat import (
12 ChatCompletionRole,
13 ChatCompletionChunk,
14 CompletionCreateParams,
15)
16from ....types.chat.completion_create_params import (
17 CompletionCreateParamsStreaming,
18 CompletionCreateParamsNonStreaming,
19)
20
21if TYPE_CHECKING:
22 from argparse import _SubParsersAction
23
24
def register(subparser: _SubParsersAction[ArgumentParser]) -> None:
    """Register the `chat.completions.create` subcommand and its CLI arguments.

    Called by the CLI bootstrap with the top-level subparsers action; wires the
    parsed namespace to `CLIChatCompletion.create` via `set_defaults`.
    """
    sub = subparser.add_parser("chat.completions.create")

    # Drop argparse's default "positional/optional" groups so that required and
    # optional flags are listed under our own, clearer headings.
    sub._action_groups.pop()
    req = sub.add_argument_group("required arguments")
    opt = sub.add_argument_group("optional arguments")

    req.add_argument(
        "-g",
        "--message",
        action="append",
        nargs=2,
        metavar=("ROLE", "CONTENT"),
        help="A message in `{role} {content}` format. Use this argument multiple times to add multiple messages.",
        required=True,
    )
    req.add_argument(
        "-m",
        "--model",
        help="The model to use.",
        required=True,
    )

    opt.add_argument(
        "-n",
        "--n",
        help="How many completions to generate for the conversation.",
        type=int,
    )
    opt.add_argument("-M", "--max-tokens", help="The maximum number of tokens to generate.", type=int)
    opt.add_argument(
        "-t",
        "--temperature",
        help="""What sampling temperature to use. Higher values mean the model will take more risks. Try 0.9 for more creative applications, and 0 (argmax sampling) for ones with a well-defined answer.

Mutually exclusive with `top_p`.""",
        type=float,
    )
    opt.add_argument(
        "-P",
        "--top_p",
        # NOTE: `%%` is required because argparse %-formats help strings.
        help="""An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10%% probability mass are considered.

Mutually exclusive with `temperature`.""",
        type=float,
    )
    opt.add_argument(
        "--stop",
        help="A stop sequence at which to stop generating tokens for the message.",
    )
    opt.add_argument("--stream", help="Stream messages as they're ready.", action="store_true")
    sub.set_defaults(func=CLIChatCompletion.create, args_model=CLIChatCompletionCreateArgs)
77
78
class CLIMessage(NamedTuple):
    """A (role, content) pair parsed from one `--message ROLE CONTENT` CLI argument."""

    # Chat role as given on the command line; cast to a Literal at the use site
    # in `CLIChatCompletion.create`.
    role: ChatCompletionRole
    # The message text.
    content: str
82
83
class CLIChatCompletionCreateArgs(BaseModel):
    """Parsed CLI arguments for `chat.completions.create`.

    Field names mirror the flags registered in `register`; optional request
    parameters default to None and are only forwarded when set.
    """

    message: List[CLIMessage]  # one entry per repeated `--message ROLE CONTENT`
    model: str
    n: Optional[int] = None
    max_tokens: Optional[int] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    stop: Optional[str] = None
    stream: bool = False  # set by the `--stream` store_true flag
93
94
class CLIChatCompletion:
    """CLI handlers for the `chat.completions.create` subcommand."""

    @staticmethod
    def create(args: CLIChatCompletionCreateArgs) -> None:
        """Assemble request params from parsed CLI args and dispatch to the
        streaming or non-streaming handler."""
        params: CompletionCreateParams = {
            "model": args.model,
            "messages": [
                {"role": cast(Literal["user"], msg.role), "content": msg.content} for msg in args.message
            ],
            # type checkers are not good at inferring union types so we have to set stream afterwards
            "stream": False,
        }
        # Only forward the optional request parameters the user actually supplied.
        maybe_params = {
            "temperature": args.temperature,
            "stop": args.stop,
            "top_p": args.top_p,
            "n": args.n,
            "max_tokens": args.max_tokens,
        }
        for name, value in maybe_params.items():
            if value is not None:
                params[name] = value  # type: ignore[literal-required]
        if args.stream:
            params["stream"] = args.stream  # type: ignore

        if args.stream:
            return CLIChatCompletion._stream_create(cast(CompletionCreateParamsStreaming, params))
        return CLIChatCompletion._create(cast(CompletionCreateParamsNonStreaming, params))

    @staticmethod
    def _create(params: CompletionCreateParamsNonStreaming) -> None:
        """Run a blocking request and print every returned choice to stdout."""
        completion = get_client().chat.completions.create(**params)
        multiple_choices = len(completion.choices) > 1
        for choice in completion.choices:
            # Only label each completion when there is more than one to show.
            if multiple_choices:
                sys.stdout.write(f"===== Chat Completion {choice.index} =====\n")

            text = "None" if choice.message.content is None else choice.message.content
            sys.stdout.write(text)

            if multiple_choices or not text.endswith("\n"):
                sys.stdout.write("\n")

            sys.stdout.flush()

    @staticmethod
    def _stream_create(params: CompletionCreateParamsStreaming) -> None:
        """Run a streaming request, printing each delta as it arrives."""
        # cast is required for mypy
        stream = cast(  # pyright: ignore[reportUnnecessaryCast]
            Stream[ChatCompletionChunk], get_client().chat.completions.create(**params)
        )
        for chunk in stream:
            multiple_choices = len(chunk.choices) > 1
            for choice in chunk.choices:
                # Only label each completion when there is more than one to show.
                if multiple_choices:
                    sys.stdout.write(f"===== Chat Completion {choice.index} =====\n")

                sys.stdout.write(choice.delta.content or "")

                if multiple_choices:
                    sys.stdout.write("\n")

                sys.stdout.flush()

        sys.stdout.write("\n")