main
  1from __future__ import annotations
  2
  3from typing import TYPE_CHECKING, Any, cast
  4from argparse import ArgumentParser
  5
  6from .._utils import get_client, print_model
  7from ..._types import Omit, Omittable, omit
  8from .._models import BaseModel
  9from .._progress import BufferReader
 10
 11if TYPE_CHECKING:
 12    from argparse import _SubParsersAction
 13
 14
 15def register(subparser: _SubParsersAction[ArgumentParser]) -> None:
 16    sub = subparser.add_parser("images.generate")
 17    sub.add_argument("-m", "--model", type=str)
 18    sub.add_argument("-p", "--prompt", type=str, required=True)
 19    sub.add_argument("-n", "--num-images", type=int, default=1)
 20    sub.add_argument("-s", "--size", type=str, default="1024x1024", help="Size of the output image")
 21    sub.add_argument("--response-format", type=str, default="url")
 22    sub.set_defaults(func=CLIImage.create, args_model=CLIImageCreateArgs)
 23
 24    sub = subparser.add_parser("images.edit")
 25    sub.add_argument("-m", "--model", type=str)
 26    sub.add_argument("-p", "--prompt", type=str, required=True)
 27    sub.add_argument("-n", "--num-images", type=int, default=1)
 28    sub.add_argument(
 29        "-I",
 30        "--image",
 31        type=str,
 32        required=True,
 33        help="Image to modify. Should be a local path and a PNG encoded image.",
 34    )
 35    sub.add_argument("-s", "--size", type=str, default="1024x1024", help="Size of the output image")
 36    sub.add_argument("--response-format", type=str, default="url")
 37    sub.add_argument(
 38        "-M",
 39        "--mask",
 40        type=str,
 41        required=False,
 42        help="Path to a mask image. It should be the same size as the image you're editing and a RGBA PNG image. The Alpha channel acts as the mask.",
 43    )
 44    sub.set_defaults(func=CLIImage.edit, args_model=CLIImageEditArgs)
 45
 46    sub = subparser.add_parser("images.create_variation")
 47    sub.add_argument("-m", "--model", type=str)
 48    sub.add_argument("-n", "--num-images", type=int, default=1)
 49    sub.add_argument(
 50        "-I",
 51        "--image",
 52        type=str,
 53        required=True,
 54        help="Image to modify. Should be a local path and a PNG encoded image.",
 55    )
 56    sub.add_argument("-s", "--size", type=str, default="1024x1024", help="Size of the output image")
 57    sub.add_argument("--response-format", type=str, default="url")
 58    sub.set_defaults(func=CLIImage.create_variation, args_model=CLIImageCreateVariationArgs)
 59
 60
class CLIImageCreateArgs(BaseModel):
    """Parsed arguments for the ``images.generate`` subcommand.

    Field names match the argparse destinations declared in ``register``
    (e.g. ``--num-images`` is stored as ``num_images``).
    """

    prompt: str  # text prompt (-p/--prompt, required on the CLI)
    num_images: int  # forwarded to the API as `n` (-n/--num-images, CLI default 1)
    size: str  # output size string (-s/--size, CLI default "1024x1024")
    response_format: str  # forwarded uncast-checked (--response-format, CLI default "url")
    model: Omittable[str] = omit  # -m/--model; falls back to `omit` when the field is absent
 67
 68
class CLIImageCreateVariationArgs(BaseModel):
    """Parsed arguments for the ``images.create_variation`` subcommand.

    Field names match the argparse destinations declared in ``register``.
    """

    image: str  # local path of the source image (-I/--image, required on the CLI)
    num_images: int  # forwarded to the API as `n` (-n/--num-images, CLI default 1)
    size: str  # output size string (-s/--size, CLI default "1024x1024")
    response_format: str  # --response-format, CLI default "url"
    model: Omittable[str] = omit  # -m/--model; falls back to `omit` when the field is absent
 75
 76
class CLIImageEditArgs(BaseModel):
    """Parsed arguments for the ``images.edit`` subcommand.

    Field names match the argparse destinations declared in ``register``.
    """

    image: str  # local path of the image to edit (-I/--image, required on the CLI)
    num_images: int  # forwarded to the API as `n` (-n/--num-images, CLI default 1)
    size: str  # output size string (-s/--size, CLI default "1024x1024")
    response_format: str  # --response-format, CLI default "url"
    prompt: str  # edit instruction (-p/--prompt, required on the CLI)
    mask: Omittable[str] = omit  # optional local path of a mask image (-M/--mask)
    model: Omittable[str] = omit  # -m/--model; falls back to `omit` when the field is absent
 85
 86
 87class CLIImage:
 88    @staticmethod
 89    def create(args: CLIImageCreateArgs) -> None:
 90        image = get_client().images.generate(
 91            model=args.model,
 92            prompt=args.prompt,
 93            n=args.num_images,
 94            # casts required because the API is typed for enums
 95            # but we don't want to validate that here for forwards-compat
 96            size=cast(Any, args.size),
 97            response_format=cast(Any, args.response_format),
 98        )
 99        print_model(image)
100
101    @staticmethod
102    def create_variation(args: CLIImageCreateVariationArgs) -> None:
103        with open(args.image, "rb") as file_reader:
104            buffer_reader = BufferReader(file_reader.read(), desc="Upload progress")
105
106        image = get_client().images.create_variation(
107            model=args.model,
108            image=("image", buffer_reader),
109            n=args.num_images,
110            # casts required because the API is typed for enums
111            # but we don't want to validate that here for forwards-compat
112            size=cast(Any, args.size),
113            response_format=cast(Any, args.response_format),
114        )
115        print_model(image)
116
117    @staticmethod
118    def edit(args: CLIImageEditArgs) -> None:
119        with open(args.image, "rb") as file_reader:
120            buffer_reader = BufferReader(file_reader.read(), desc="Image upload progress")
121
122        if isinstance(args.mask, Omit):
123            mask: Omittable[BufferReader] = omit
124        else:
125            with open(args.mask, "rb") as file_reader:
126                mask = BufferReader(file_reader.read(), desc="Mask progress")
127
128        image = get_client().images.edit(
129            model=args.model,
130            prompt=args.prompt,
131            image=("image", buffer_reader),
132            n=args.num_images,
133            mask=("mask", mask) if not isinstance(mask, Omit) else mask,
134            # casts required because the API is typed for enums
135            # but we don't want to validate that here for forwards-compat
136            size=cast(Any, args.size),
137            response_format=cast(Any, args.response_format),
138        )
139        print_model(image)