Commit df8b0ba7

John Allard <john@jhallard.com>
2023-08-23 00:19:06
Updates to the fine tuning SDK + addition of pagination primitives (#582) tag: v0.27.9
* Add support for new fine_tuning SDK + pagination primitives * typo
1 parent b8dfa35
openai/api_resources/abstract/__init__.py
@@ -7,4 +7,7 @@ from openai.api_resources.abstract.listable_api_resource import ListableAPIResou
 from openai.api_resources.abstract.nested_resource_class_methods import (
     nested_resource_class_methods,
 )
+from openai.api_resources.abstract.paginatable_api_resource import (
+    PaginatableAPIResource,
+)
 from openai.api_resources.abstract.updateable_api_resource import UpdateableAPIResource
openai/api_resources/abstract/nested_resource_class_methods.py
@@ -124,6 +124,19 @@ def _nested_resource_class_methods(
                 list_method = "list_%s" % resource_plural
                 setattr(cls, list_method, classmethod(list_nested_resources))
 
+            elif operation == "paginated_list":
+
+                def paginated_list_nested_resources(
+                    cls, id, limit=None, cursor=None, **params
+                ):
+                    url = getattr(cls, resource_url_method)(id)
+                    return getattr(cls, resource_request_method)(
+                        "get", url, limit=limit, cursor=cursor, **params
+                    )
+
+                list_method = "list_%s" % resource_plural
+                setattr(cls, list_method, classmethod(paginated_list_nested_resources))
+
             else:
                 raise ValueError("Unknown operation: %s" % operation)
 
openai/api_resources/abstract/paginatable_api_resource.py
@@ -0,0 +1,125 @@
+from openai import api_requestor, error, util
+from openai.api_resources.abstract.listable_api_resource import ListableAPIResource
+from openai.util import ApiType
+
+
+class PaginatableAPIResource(ListableAPIResource):
+    @classmethod
+    def auto_paging_iter(cls, *args, **params):
+        next_cursor = None
+        has_more = True
+        if not params.get("limit"):
+            params["limit"] = 20
+        while has_more:
+            if next_cursor:
+                params["after"] = next_cursor
+            response = cls.list(*args, **params)
+
+            for item in response.data:
+                yield item
+
+            if response.data:
+                next_cursor = response.data[-1].id
+            has_more = response.has_more
+
+    @classmethod
+    def __prepare_list_requestor(
+        cls,
+        api_key=None,
+        api_version=None,
+        organization=None,
+        api_base=None,
+        api_type=None,
+    ):
+        requestor = api_requestor.APIRequestor(
+            api_key,
+            api_base=api_base or cls.api_base(),
+            api_version=api_version,
+            api_type=api_type,
+            organization=organization,
+        )
+
+        typed_api_type, api_version = cls._get_api_type_and_version(
+            api_type, api_version
+        )
+
+        if typed_api_type in (ApiType.AZURE, ApiType.AZURE_AD):
+            base = cls.class_url()
+            url = "/%s%s?api-version=%s" % (cls.azure_api_prefix, base, api_version)
+        elif typed_api_type == ApiType.OPEN_AI:
+            url = cls.class_url()
+        else:
+            raise error.InvalidAPIType("Unsupported API type %s" % api_type)
+        return requestor, url
+
+    @classmethod
+    def list(
+        cls,
+        limit=None,
+        starting_after=None,
+        api_key=None,
+        request_id=None,
+        api_version=None,
+        organization=None,
+        api_base=None,
+        api_type=None,
+        **params,
+    ):
+        requestor, url = cls.__prepare_list_requestor(
+            api_key,
+            api_version,
+            organization,
+            api_base,
+            api_type,
+        )
+
+        params = {
+            **params,
+            "limit": limit,
+            "starting_after": starting_after,
+        }
+
+        response, _, api_key = requestor.request(
+            "get", url, params, request_id=request_id
+        )
+        openai_object = util.convert_to_openai_object(
+            response, api_key, api_version, organization
+        )
+        openai_object._retrieve_params = params
+        return openai_object
+
+    @classmethod
+    async def alist(
+        cls,
+        limit=None,
+        starting_after=None,
+        api_key=None,
+        request_id=None,
+        api_version=None,
+        organization=None,
+        api_base=None,
+        api_type=None,
+        **params,
+    ):
+        requestor, url = cls.__prepare_list_requestor(
+            api_key,
+            api_version,
+            organization,
+            api_base,
+            api_type,
+        )
+
+        params = {
+            **params,
+            "limit": limit,
+            "starting_after": starting_after,
+        }
+
+        response, _, api_key = await requestor.arequest(
+            "get", url, params, request_id=request_id
+        )
+        openai_object = util.convert_to_openai_object(
+            response, api_key, api_version, organization
+        )
+        openai_object._retrieve_params = params
+        return openai_object
openai/api_resources/__init__.py
@@ -9,6 +9,7 @@ from openai.api_resources.engine import Engine  # noqa: F401
 from openai.api_resources.error_object import ErrorObject  # noqa: F401
 from openai.api_resources.file import File  # noqa: F401
 from openai.api_resources.fine_tune import FineTune  # noqa: F401
+from openai.api_resources.fine_tuning import FineTuningJob  # noqa: F401
 from openai.api_resources.image import Image  # noqa: F401
 from openai.api_resources.model import Model  # noqa: F401
 from openai.api_resources.moderation import Moderation  # noqa: F401
openai/api_resources/fine_tuning.py
@@ -0,0 +1,88 @@
+from urllib.parse import quote_plus
+
+from openai import error
+from openai.api_resources.abstract import (
+    CreateableAPIResource,
+    PaginatableAPIResource,
+    nested_resource_class_methods,
+)
+from openai.api_resources.abstract.deletable_api_resource import DeletableAPIResource
+from openai.util import ApiType
+
+
+@nested_resource_class_methods("event", operations=["paginated_list"])
+class FineTuningJob(
+    PaginatableAPIResource, CreateableAPIResource, DeletableAPIResource
+):
+    OBJECT_NAME = "fine_tuning.jobs"
+
+    @classmethod
+    def _prepare_cancel(
+        cls,
+        id,
+        api_key=None,
+        api_type=None,
+        request_id=None,
+        api_version=None,
+        **params,
+    ):
+        base = cls.class_url()
+        extn = quote_plus(id)
+
+        typed_api_type, api_version = cls._get_api_type_and_version(
+            api_type, api_version
+        )
+        if typed_api_type in (ApiType.AZURE, ApiType.AZURE_AD):
+            url = "/%s%s/%s/cancel?api-version=%s" % (
+                cls.azure_api_prefix,
+                base,
+                extn,
+                api_version,
+            )
+        elif typed_api_type == ApiType.OPEN_AI:
+            url = "%s/%s/cancel" % (base, extn)
+        else:
+            raise error.InvalidAPIType("Unsupported API type %s" % api_type)
+
+        instance = cls(id, api_key, **params)
+        return instance, url
+
+    @classmethod
+    def cancel(
+        cls,
+        id,
+        api_key=None,
+        api_type=None,
+        request_id=None,
+        api_version=None,
+        **params,
+    ):
+        instance, url = cls._prepare_cancel(
+            id,
+            api_key,
+            api_type,
+            request_id,
+            api_version,
+            **params,
+        )
+        return instance.request("post", url, request_id=request_id)
+
+    @classmethod
+    def acancel(
+        cls,
+        id,
+        api_key=None,
+        api_type=None,
+        request_id=None,
+        api_version=None,
+        **params,
+    ):
+        instance, url = cls._prepare_cancel(
+            id,
+            api_key,
+            api_type,
+            request_id,
+            api_version,
+            **params,
+        )
+        return instance.arequest("post", url, request_id=request_id)
openai/__init__.py
@@ -28,6 +28,7 @@ from openai.api_resources import (
     ErrorObject,
     File,
     FineTune,
+    FineTuningJob,
     Image,
     Model,
     Moderation,
@@ -84,6 +85,7 @@ __all__ = [
     "ErrorObject",
     "File",
     "FineTune",
+    "FineTuningJob",
     "InvalidRequestError",
     "Model",
     "Moderation",
openai/object_classes.py
@@ -8,4 +8,5 @@ OBJECT_CLASSES = {
     "fine-tune": api_resources.FineTune,
     "model": api_resources.Model,
     "deployment": api_resources.Deployment,
+    "fine_tuning.job": api_resources.FineTuningJob,
 }