Commit d59672b3

hallacy <hallacy@openai.com>
2022-10-22 01:29:27
Add Dalle Support (#147) (#131) tag: v0.24.0
* Add Dalle Support (#147)
1 parent d1769c1
Changed files (5)
openai/api_resources/__init__.py
@@ -2,8 +2,9 @@ from openai.api_resources.answer import Answer  # noqa: F401
 from openai.api_resources.classification import Classification  # noqa: F401
 from openai.api_resources.completion import Completion  # noqa: F401
 from openai.api_resources.customer import Customer  # noqa: F401
-from openai.api_resources.edit import Edit  # noqa: F401
+from openai.api_resources.dalle import DALLE  # noqa: F401
 from openai.api_resources.deployment import Deployment  # noqa: F401
+from openai.api_resources.edit import Edit  # noqa: F401
 from openai.api_resources.embedding import Embedding  # noqa: F401
 from openai.api_resources.engine import Engine  # noqa: F401
 from openai.api_resources.error_object import ErrorObject  # noqa: F401
openai/api_resources/dalle.py
@@ -0,0 +1,90 @@
+# WARNING: This interface is considered experimental and may changed in the future without warning.
+from typing import Any, List
+
+import openai
+from openai import api_requestor, util
+from openai.api_resources.abstract import APIResource
+
+
+class DALLE(APIResource):
+    OBJECT_NAME = "images"
+
+    @classmethod
+    def _get_url(cls, action):
+        return cls.class_url() + f"/{action}"
+
+    @classmethod
+    def generations(
+        cls,
+        **params,
+    ):
+        instance = cls()
+        return instance.request("post", cls._get_url("generations"), params)
+
+    @classmethod
+    def variations(
+        cls,
+        image,
+        api_key=None,
+        api_base=None,
+        api_type=None,
+        api_version=None,
+        organization=None,
+        **params,
+    ):
+        requestor = api_requestor.APIRequestor(
+            api_key,
+            api_base=api_base or openai.api_base,
+            api_type=api_type,
+            api_version=api_version,
+            organization=organization,
+        )
+        _, api_version = cls._get_api_type_and_version(api_type, api_version)
+
+        url = cls._get_url("variations")
+
+        files: List[Any] = []
+        for key, value in params.items():
+            files.append((key, (None, value)))
+        files.append(("image", ("image", image, "application/octet-stream")))
+
+        response, _, api_key = requestor.request("post", url, files=files)
+
+        return util.convert_to_openai_object(
+            response, api_key, api_version, organization
+        )
+
+    @classmethod
+    def edits(
+        cls,
+        image,
+        mask,
+        api_key=None,
+        api_base=None,
+        api_type=None,
+        api_version=None,
+        organization=None,
+        **params,
+    ):
+        requestor = api_requestor.APIRequestor(
+            api_key,
+            api_base=api_base or openai.api_base,
+            api_type=api_type,
+            api_version=api_version,
+            organization=organization,
+        )
+        _, api_version = cls._get_api_type_and_version(api_type, api_version)
+
+        url = cls._get_url("edits")
+
+        files: List[Any] = []
+        for key, value in params.items():
+            files.append((key, (None, value)))
+        files.append(("image", ("image", image, "application/octet-stream")))
+        files.append(("mask", ("mask", mask, "application/octet-stream")))
+
+        response, _, api_key = requestor.request("post", url, files=files)
+
+        return util.convert_to_openai_object(
+            response, api_key, api_version, organization
+        )
openai/__init__.py
@@ -6,6 +6,7 @@ import os
 from typing import Optional
 
 from openai.api_resources import (
+    DALLE,
     Answer,
     Classification,
     Completion,
@@ -50,6 +51,7 @@ __all__ = [
     "Completion",
     "Customer",
     "Edit",
+    "DALLE",
     "Deployment",
     "Embedding",
     "Engine",
openai/cli.py
@@ -229,6 +229,49 @@ class File:
         print(file)
 
 
+class DALLE:
+    @classmethod
+    def generations(cls, args):
+        resp = openai.DALLE.generations(
+            prompt=args.prompt,
+            model=args.model,
+            size=args.size,
+            num_images=args.num_images,
+            response_format=args.response_format,
+        )
+        print(resp)
+
+    @classmethod
+    def variations(cls, args):
+        with open(args.image, "rb") as file_reader:
+            buffer_reader = BufferReader(file_reader.read(), desc="Upload progress")
+        resp = openai.DALLE.variations(
+            image=buffer_reader,
+            model=args.model,
+            size=args.size,
+            num_images=args.num_images,
+            response_format=args.response_format,
+        )
+        print(resp)
+
+    @classmethod
+    def edits(cls, args):
+        with open(args.image, "rb") as file_reader:
+            image_reader = BufferReader(file_reader.read(), desc="Upload progress")
+        with open(args.mask, "rb") as file_reader:
+            mask_reader = BufferReader(file_reader.read(), desc="Upload progress")
+        resp = openai.DALLE.edits(
+            image=image_reader,
+            mask=mask_reader,
+            prompt=args.prompt,
+            model=args.model,
+            size=args.size,
+            num_images=args.num_images,
+            response_format=args.response_format,
+        )
+        print(resp)
+
+
 class Search:
     @classmethod
     def prepare_data(cls, args, purpose):
@@ -983,6 +1026,57 @@ Mutually exclusive with `top_p`.""",
     sub.add_argument("-i", "--id", required=True, help="The id of the fine-tune job")
     sub.set_defaults(func=FineTune.cancel)
 
+    # DALLE
+    sub = subparsers.add_parser("dalle.generations")
+    sub.add_argument("-m", "--model", type=str, default="image-alpha-001")
+    sub.add_argument("-p", "--prompt", type=str, required=True)
+    sub.add_argument("-n", "--num-images", type=int, default=1)
+    sub.add_argument(
+        "-s", "--size", type=str, default="1024x1024", help="Size of the output image"
+    )
+    sub.add_argument("--response-format", type=str, default="url")
+    sub.set_defaults(func=DALLE.generations)
+
+    sub = subparsers.add_parser("dalle.edits")
+    sub.add_argument("-m", "--model", type=str, default="image-alpha-001")
+    sub.add_argument("-p", "--prompt", type=str, required=True)
+    sub.add_argument("-n", "--num-images", type=int, default=1)
+    sub.add_argument(
+        "-I",
+        "--image",
+        type=str,
+        required=True,
+        help="Image to modify. Should be a local path and a PNG encoded image.",
+    )
+    sub.add_argument(
+        "-s", "--size", type=str, default="1024x1024", help="Size of the output image"
+    )
+    sub.add_argument("--response-format", type=str, default="url")
+    sub.add_argument(
+        "-M",
+        "--mask",
+        type=str,
+        required=True,
+        help="Path to a mask image. It should be the same size as the image you're editing and a RGBA PNG image. The Alpha channel acts as the mask.",
+    )
+    sub.set_defaults(func=DALLE.edits)
+
+    sub = subparsers.add_parser("dalle.variations")
+    sub.add_argument("-m", "--model", type=str, default="image-alpha-001")
+    sub.add_argument("-n", "--num-images", type=int, default=1)
+    sub.add_argument(
+        "-I",
+        "--image",
+        type=str,
+        required=True,
+        help="Image to modify. Should be a local path and a PNG encoded image.",
+    )
+    sub.add_argument(
+        "-s", "--size", type=str, default="1024x1024", help="Size of the output image"
+    )
+    sub.add_argument("--response-format", type=str, default="url")
+    sub.set_defaults(func=DALLE.variations)
+
 
 def wandb_register(parser):
     subparsers = parser.add_subparsers(
openai/version.py
@@ -1,1 +1,1 @@
-VERSION = "0.23.1"
+VERSION = "0.24.0"