main
1from __future__ import annotations
2
3from typing import TYPE_CHECKING, Any, cast
4from argparse import ArgumentParser
5
6from .._utils import get_client, print_model
7from ..._types import Omit, Omittable, omit
8from .._models import BaseModel
9from .._progress import BufferReader
10
11if TYPE_CHECKING:
12 from argparse import _SubParsersAction
13
14
def register(subparser: _SubParsersAction[ArgumentParser]) -> None:
    """Attach the ``images.*`` commands to the CLI's sub-parser collection.

    Three parsers are added: ``images.generate``, ``images.edit`` and
    ``images.create_variation``. Each one records, via ``set_defaults``, the
    ``CLIImage`` handler to run (``func``) and the argument model class used
    for its parsed arguments (``args_model``).
    """
    image_help = "Image to modify. Should be a local path and a PNG encoded image."

    def add_model(parser: ArgumentParser) -> None:
        # Optional model name; no default is given here.
        parser.add_argument("-m", "--model", type=str)

    def add_count(parser: ArgumentParser) -> None:
        parser.add_argument("-n", "--num-images", type=int, default=1)

    def add_output_options(parser: ArgumentParser) -> None:
        # --size and --response-format are registered identically on every command.
        parser.add_argument("-s", "--size", type=str, default="1024x1024", help="Size of the output image")
        parser.add_argument("--response-format", type=str, default="url")

    # images.generate: requires --prompt, no input image.
    generate = subparser.add_parser("images.generate")
    add_model(generate)
    generate.add_argument("-p", "--prompt", type=str, required=True)
    add_count(generate)
    add_output_options(generate)
    generate.set_defaults(func=CLIImage.create, args_model=CLIImageCreateArgs)

    # images.edit: requires --prompt and --image; --mask is optional.
    edit = subparser.add_parser("images.edit")
    add_model(edit)
    edit.add_argument("-p", "--prompt", type=str, required=True)
    add_count(edit)
    edit.add_argument("-I", "--image", type=str, required=True, help=image_help)
    add_output_options(edit)
    edit.add_argument(
        "-M",
        "--mask",
        type=str,
        required=False,
        help="Path to a mask image. It should be the same size as the image you're editing and a RGBA PNG image. The Alpha channel acts as the mask.",
    )
    edit.set_defaults(func=CLIImage.edit, args_model=CLIImageEditArgs)

    # images.create_variation: requires --image, takes no prompt.
    variation = subparser.add_parser("images.create_variation")
    add_model(variation)
    add_count(variation)
    variation.add_argument("-I", "--image", type=str, required=True, help=image_help)
    add_output_options(variation)
    variation.set_defaults(func=CLIImage.create_variation, args_model=CLIImageCreateVariationArgs)
59
60
class CLIImageCreateArgs(BaseModel):
    """Arguments for the ``images.generate`` command, consumed by ``CLIImage.create``."""

    prompt: str  # forwarded as ``prompt``
    num_images: int  # forwarded to the API call as ``n``
    size: str  # forwarded as ``size`` via cast(Any, ...); deliberately not validated here
    response_format: str  # forwarded as ``response_format`` via cast(Any, ...)
    model: Omittable[str] = omit  # forwarded as ``model``; defaults to ``omit``
67
68
class CLIImageCreateVariationArgs(BaseModel):
    """Arguments for ``images.create_variation``, consumed by ``CLIImage.create_variation``."""

    image: str  # local file path; read in full and uploaded as ("image", reader)
    num_images: int  # forwarded to the API call as ``n``
    size: str  # forwarded as ``size`` via cast(Any, ...); deliberately not validated here
    response_format: str  # forwarded as ``response_format`` via cast(Any, ...)
    model: Omittable[str] = omit  # forwarded as ``model``; defaults to ``omit``
75
76
class CLIImageEditArgs(BaseModel):
    """Arguments for the ``images.edit`` command, consumed by ``CLIImage.edit``."""

    image: str  # local file path; read in full and uploaded as ("image", reader)
    num_images: int  # forwarded to the API call as ``n``
    size: str  # forwarded as ``size`` via cast(Any, ...); deliberately not validated here
    response_format: str  # forwarded as ``response_format`` via cast(Any, ...)
    prompt: str  # forwarded as ``prompt``
    # Optional mask file path. When this is an ``Omit`` instance, CLIImage.edit
    # sends ``omit`` instead of uploading a mask.
    mask: Omittable[str] = omit
    model: Omittable[str] = omit  # forwarded as ``model``; defaults to ``omit``
85
86
class CLIImage:
    """Handlers for the ``images.*`` CLI commands.

    Each handler receives its argument model, calls the matching ``images``
    method on the client returned by ``get_client()``, and prints the response
    with ``print_model``.
    """

    @staticmethod
    def _read_upload(path: str, desc: str) -> BufferReader:
        """Read the file at *path* fully and wrap its bytes in a ``BufferReader``.

        *desc* is passed straight through to ``BufferReader`` (the callers use
        it as a progress label). The file is closed before this returns.
        """
        with open(path, "rb") as file_reader:
            return BufferReader(file_reader.read(), desc=desc)

    @staticmethod
    def create(args: CLIImageCreateArgs) -> None:
        """Call ``images.generate`` with the prompt and options in *args* and print the result."""
        image = get_client().images.generate(
            model=args.model,
            prompt=args.prompt,
            n=args.num_images,
            # casts required because the API is typed for enums
            # but we don't want to validate that here for forwards-compat
            size=cast(Any, args.size),
            response_format=cast(Any, args.response_format),
        )
        print_model(image)

    @staticmethod
    def create_variation(args: CLIImageCreateVariationArgs) -> None:
        """Upload ``args.image`` to ``images.create_variation`` and print the result."""
        buffer_reader = CLIImage._read_upload(args.image, desc="Upload progress")

        image = get_client().images.create_variation(
            model=args.model,
            image=("image", buffer_reader),
            n=args.num_images,
            # casts required because the API is typed for enums
            # but we don't want to validate that here for forwards-compat
            size=cast(Any, args.size),
            response_format=cast(Any, args.response_format),
        )
        print_model(image)

    @staticmethod
    def edit(args: CLIImageEditArgs) -> None:
        """Upload ``args.image`` (and ``args.mask`` if given) to ``images.edit`` and print the result."""
        buffer_reader = CLIImage._read_upload(args.image, desc="Image upload progress")

        # The mask is optional. Forward ``omit`` unchanged when it is absent;
        # otherwise build the ("mask", reader) file tuple here, so no second
        # Omit check is needed at the call site.
        if isinstance(args.mask, Omit):
            mask: Omittable[tuple[str, BufferReader]] = omit
        else:
            mask = ("mask", CLIImage._read_upload(args.mask, desc="Mask progress"))

        image = get_client().images.edit(
            model=args.model,
            prompt=args.prompt,
            image=("image", buffer_reader),
            n=args.num_images,
            mask=mask,
            # casts required because the API is typed for enums
            # but we don't want to validate that here for forwards-compat
            size=cast(Any, args.size),
            response_format=cast(Any, args.response_format),
        )
        print_model(image)