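"""CLI handler for the `completions.create` subcommand."""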
from __future__ import annotations

import sys
from typing import TYPE_CHECKING, Optional, cast
from argparse import ArgumentParser
from functools import partial

from openai.types.completion import Completion

from .._utils import get_client
from ..._types import Omittable, omit
from ..._utils import is_given
from .._errors import CLIError
from .._models import BaseModel
from ..._streaming import Stream

if TYPE_CHECKING:
    from argparse import _SubParsersAction


def register(subparser: _SubParsersAction[ArgumentParser]) -> None:
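    """Register the `completions.create` subcommand and its arguments on the CLI subparser."""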
    sub = subparser.add_parser("completions.create")

    # Required
    sub.add_argument(
        "-m",
        "--model",
        help="The model to use",
        required=True,
    )

    # Optional
    sub.add_argument("-p", "--prompt", help="An optional prompt to complete from")
    sub.add_argument("--stream", help="Stream tokens as they're ready.", action="store_true")
    sub.add_argument("-M", "--max-tokens", help="The maximum number of tokens to generate", type=int)
    sub.add_argument(
        "-t",
        "--temperature",
        help="""What sampling temperature to use. Higher values mean the model will take more risks. Try 0.9 for more creative applications, and 0 (argmax sampling) for ones with a well-defined answer.

Mutually exclusive with `top_p`.""",
        type=float,
    )
    sub.add_argument(
        "-P",
        "--top_p",
        help="""An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10%% probability mass are considered.

Mutually exclusive with `temperature`.""",
        type=float,
    )
    sub.add_argument(
        "-n",
        "--n",
        help="How many sub-completions to generate for each prompt.",
        type=int,
    )
    sub.add_argument(
        "--logprobs",
        help="Include the log probabilities on the `logprobs` most likely tokens, as well as the chosen tokens. So for example, if `logprobs` is 10, the API will return a list of the 10 most likely tokens. If `logprobs` is 0, only the chosen tokens will have logprobs returned.",
        type=int,
    )
    sub.add_argument(
        "--best_of",
        help="Generates `best_of` completions server-side and returns the 'best' (the one with the highest log probability per token). Results cannot be streamed.",
        type=int,
    )
    sub.add_argument(
        "--echo",
        help="Echo back the prompt in addition to the completion",
        action="store_true",
    )
    sub.add_argument(
        "--frequency_penalty",
        help="Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.",
        type=float,
    )
    sub.add_argument(
        "--presence_penalty",
        help="Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.",
        type=float,
    )
    sub.add_argument("--suffix", help="The suffix that comes after a completion of inserted text.")
    sub.add_argument("--stop", help="A stop sequence at which to stop generating tokens.")
    sub.add_argument(
        "--user",
        help="A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.",
    )
    # TODO: add support for logit_bias
    sub.set_defaults(func=CLICompletions.create, args_model=CLICompletionCreateArgs)


class CLICompletionCreateArgs(BaseModel):
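    """Arguments parsed for `completions.create`; fields defaulting to `omit` are treated as not provided."""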
    model: str
    stream: bool = False

    prompt: Optional[str] = None
    n: Omittable[int] = omit
    stop: Omittable[str] = omit
    user: Omittable[str] = omit
    echo: Omittable[bool] = omit
    suffix: Omittable[str] = omit
    best_of: Omittable[int] = omit
    top_p: Omittable[float] = omit
    logprobs: Omittable[int] = omit
    max_tokens: Omittable[int] = omit
    temperature: Omittable[float] = omit
    presence_penalty: Omittable[float] = omit
    frequency_penalty: Omittable[float] = omit


class CLICompletions:
    @staticmethod
    def create(args: CLICompletionCreateArgs) -> None:
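        """Issue the `completions.create` request and print the result (streamed or not)."""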
        if is_given(args.n) and args.n > 1 and args.stream:
            raise CLIError("Can't stream completions with n>1 with the current CLI")

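        # Bind every request parameter once so the streaming and non-streaming
        # branches below differ only in whether `stream=True` is passed.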
        make_request = partial(
            get_client().completions.create,
            n=args.n,
            echo=args.echo,
            stop=args.stop,
            user=args.user,
            model=args.model,
            top_p=args.top_p,
            prompt=args.prompt,
            suffix=args.suffix,
            best_of=args.best_of,
            logprobs=args.logprobs,
            max_tokens=args.max_tokens,
            temperature=args.temperature,
            presence_penalty=args.presence_penalty,
            frequency_penalty=args.frequency_penalty,
        )

        if args.stream:
            return CLICompletions._stream_create(
                # mypy doesn't understand the `partial` function but pyright does
                cast(Stream[Completion], make_request(stream=True))  # pyright: ignore[reportUnnecessaryCast]
            )

        return CLICompletions._create(make_request())

    @staticmethod
    def _create(completion: Completion) -> None:
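        """Print a non-streamed completion, with a header per choice when more than one was requested."""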
        should_print_header = len(completion.choices) > 1
        for choice in completion.choices:
            if should_print_header:
                sys.stdout.write("===== Completion {} =====\n".format(choice.index))

            sys.stdout.write(choice.text)

            if should_print_header or not choice.text.endswith("\n"):
                sys.stdout.write("\n")

            sys.stdout.flush()

    @staticmethod
    def _stream_create(stream: Stream[Completion]) -> None:
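        """Print completion chunks as they arrive, flushing stdout so tokens appear incrementally."""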
        for completion in stream:
            should_print_header = len(completion.choices) > 1
            for choice in sorted(completion.choices, key=lambda c: c.index):
                if should_print_header:
                    sys.stdout.write("===== Completion {} =====\n".format(choice.index))

                sys.stdout.write(choice.text)

                if should_print_header:
                    sys.stdout.write("\n")

                sys.stdout.flush()

        sys.stdout.write("\n")