main
  1from __future__ import annotations
  2
  3import sys
  4from typing import TYPE_CHECKING, Optional, cast
  5from argparse import ArgumentParser
  6from functools import partial
  7
  8from openai.types.completion import Completion
  9
 10from .._utils import get_client
 11from ..._types import Omittable, omit
 12from ..._utils import is_given
 13from .._errors import CLIError
 14from .._models import BaseModel
 15from ..._streaming import Stream
 16
 17if TYPE_CHECKING:
 18    from argparse import _SubParsersAction
 19
 20
 21def register(subparser: _SubParsersAction[ArgumentParser]) -> None:
 22    sub = subparser.add_parser("completions.create")
 23
 24    # Required
 25    sub.add_argument(
 26        "-m",
 27        "--model",
 28        help="The model to use",
 29        required=True,
 30    )
 31
 32    # Optional
 33    sub.add_argument("-p", "--prompt", help="An optional prompt to complete from")
 34    sub.add_argument("--stream", help="Stream tokens as they're ready.", action="store_true")
 35    sub.add_argument("-M", "--max-tokens", help="The maximum number of tokens to generate", type=int)
 36    sub.add_argument(
 37        "-t",
 38        "--temperature",
 39        help="""What sampling temperature to use. Higher values means the model will take more risks. Try 0.9 for more creative applications, and 0 (argmax sampling) for ones with a well-defined answer.
 40
 41Mutually exclusive with `top_p`.""",
 42        type=float,
 43    )
 44    sub.add_argument(
 45        "-P",
 46        "--top_p",
 47        help="""An alternative to sampling with temperature, called nucleus sampling, where the considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10%% probability mass are considered.
 48
 49            Mutually exclusive with `temperature`.""",
 50        type=float,
 51    )
 52    sub.add_argument(
 53        "-n",
 54        "--n",
 55        help="How many sub-completions to generate for each prompt.",
 56        type=int,
 57    )
 58    sub.add_argument(
 59        "--logprobs",
 60        help="Include the log probabilities on the `logprobs` most likely tokens, as well the chosen tokens. So for example, if `logprobs` is 10, the API will return a list of the 10 most likely tokens. If `logprobs` is 0, only the chosen tokens will have logprobs returned.",
 61        type=int,
 62    )
 63    sub.add_argument(
 64        "--best_of",
 65        help="Generates `best_of` completions server-side and returns the 'best' (the one with the highest log probability per token). Results cannot be streamed.",
 66        type=int,
 67    )
 68    sub.add_argument(
 69        "--echo",
 70        help="Echo back the prompt in addition to the completion",
 71        action="store_true",
 72    )
 73    sub.add_argument(
 74        "--frequency_penalty",
 75        help="Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.",
 76        type=float,
 77    )
 78    sub.add_argument(
 79        "--presence_penalty",
 80        help="Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.",
 81        type=float,
 82    )
 83    sub.add_argument("--suffix", help="The suffix that comes after a completion of inserted text.")
 84    sub.add_argument("--stop", help="A stop sequence at which to stop generating tokens.")
 85    sub.add_argument(
 86        "--user",
 87        help="A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.",
 88    )
 89    # TODO: add support for logit_bias
 90    sub.set_defaults(func=CLICompletions.create, args_model=CLICompletionCreateArgs)
 91
 92
 93class CLICompletionCreateArgs(BaseModel):
 94    model: str
 95    stream: bool = False
 96
 97    prompt: Optional[str] = None
 98    n: Omittable[int] = omit
 99    stop: Omittable[str] = omit
100    user: Omittable[str] = omit
101    echo: Omittable[bool] = omit
102    suffix: Omittable[str] = omit
103    best_of: Omittable[int] = omit
104    top_p: Omittable[float] = omit
105    logprobs: Omittable[int] = omit
106    max_tokens: Omittable[int] = omit
107    temperature: Omittable[float] = omit
108    presence_penalty: Omittable[float] = omit
109    frequency_penalty: Omittable[float] = omit
110
111
class CLICompletions:
    """Handlers for the ``completions.create`` CLI sub-command."""

    @staticmethod
    def create(args: CLICompletionCreateArgs) -> None:
        """Send a completion request built from ``args`` and print the result.

        Args:
            args: Parsed CLI arguments. Fields the user did not supply keep their
                ``omit`` default and are forwarded unchanged to the client.

        Raises:
            CLIError: if streaming is requested together with ``n > 1``.
        """
        # Interleaved multi-choice streaming output is not supported here.
        if is_given(args.n) and args.n > 1 and args.stream:
            raise CLIError("Can't stream completions with n>1 with the current CLI")

        make_request = partial(
            get_client().completions.create,
            n=args.n,
            echo=args.echo,
            stop=args.stop,
            user=args.user,
            model=args.model,
            top_p=args.top_p,
            prompt=args.prompt,
            suffix=args.suffix,
            best_of=args.best_of,
            logprobs=args.logprobs,
            max_tokens=args.max_tokens,
            temperature=args.temperature,
            presence_penalty=args.presence_penalty,
            frequency_penalty=args.frequency_penalty,
        )

        if args.stream:
            return CLICompletions._stream_create(
                # mypy doesn't understand the `partial` function but pyright does
                cast(Stream[Completion], make_request(stream=True))  # pyright: ignore[reportUnnecessaryCast]
            )

        return CLICompletions._create(make_request())

    @staticmethod
    def _write_header(index: int) -> None:
        """Write the banner printed before a choice when several choices are shown."""
        sys.stdout.write("===== Completion {} =====\n".format(index))

    @staticmethod
    def _create(completion: Completion) -> None:
        """Print every choice of a non-streamed completion to stdout.

        A header precedes each choice only when there is more than one choice.
        Each choice is followed by a newline unless its text already ends with
        one; when headers are shown, a newline is always added as a separator.
        """
        should_print_header = len(completion.choices) > 1
        for choice in completion.choices:
            if should_print_header:
                CLICompletions._write_header(choice.index)

            sys.stdout.write(choice.text)

            if should_print_header or not choice.text.endswith("\n"):
                sys.stdout.write("\n")

            sys.stdout.flush()

    @staticmethod
    def _stream_create(stream: Stream[Completion]) -> None:
        """Print completion chunks from ``stream`` as they arrive.

        Within each chunk, choices are printed in ``index`` order, using the same
        header format as ``_create`` when a chunk carries more than one choice.
        Output is flushed after every choice, and a final newline terminates the
        streamed text.
        """
        for completion in stream:
            should_print_header = len(completion.choices) > 1
            for choice in sorted(completion.choices, key=lambda c: c.index):
                if should_print_header:
                    CLICompletions._write_header(choice.index)

                sys.stdout.write(choice.text)

                if should_print_header:
                    sys.stdout.write("\n")

                sys.stdout.flush()

        sys.stdout.write("\n")
        # Make the terminating newline visible immediately, like the chunks above.
        sys.stdout.flush()