Commit 62b51ca0

Boris Dayma <boris.dayma@gmail.com>
2022-02-02 03:58:29
feat: openai wandb sync (#64)
* feat: log fine_tune with wandb
* feat: ensure we are logged in
* feat: cli wandb namespace
* feat: add fine_tuned_model to summary
* feat: log training & validation files
* feat: re-log if was not successful or force
* doc: add docstring
* feat: set wandb api only when needed
* fix: train/validation files are inputs
* feat: rename artifact type
* feat: improve config logging
* feat: log all jobs by default
* feat: log job details
* feat: log -> sync
* feat: cli wandb log -> sync
* fix: validation_files not always present
* feat: format created_at + style
* feat: log number of training/validation samples
* feat(wandb): avoid download if file already synced
* feat(wandb): add number of items to metadata
* fix(wandb): allow force sync
* feat(wandb): job -> fine-tune
* refactor(wandb): use show_individual_warnings
* feat(wandb): Logger -> WandbLogger
* feat(wandb): retrieve number of items from artifact
* doc(wandb): add link to documentation
1 parent f288b00
openai/_openai_scripts.py
@@ -4,7 +4,7 @@ import logging
 import sys
 
 import openai
-from openai.cli import api_register, display_error, tools_register
+from openai.cli import api_register, display_error, tools_register, wandb_register
 
 logger = logging.getLogger()
 formatter = logging.Formatter("[%(asctime)s] %(message)s")
@@ -39,9 +39,11 @@ def main():
     subparsers = parser.add_subparsers()
     sub_api = subparsers.add_parser("api", help="Direct API calls")
     sub_tools = subparsers.add_parser("tools", help="Client side tools for convenience")
+    sub_wandb = subparsers.add_parser("wandb", help="Logging with Weights & Biases")
 
     api_register(sub_api)
     tools_register(sub_tools)
+    wandb_register(sub_wandb)
 
     args = parser.parse_args()
     if args.verbosity == 1:
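The hunk above follows the same argparse registration pattern as the existing `api` and `tools` namespaces: each `*_register` function attaches subcommands and binds a handler through `set_defaults(func=...)`, and `main()` dispatches on the parsed `args.func`. A minimal, self-contained sketch of that dispatch pattern (hypothetical handler bodies, not the actual CLI):

```python
import argparse

def wandb_register(parser):
    subparsers = parser.add_subparsers(title="wandb")
    # fall back to printing help when no subcommand is given
    parser.set_defaults(func=lambda args: parser.print_help())
    sub = subparsers.add_parser("sync")
    sub.add_argument("-i", "--id", help="fine-tune id (optional)")
    sub.set_defaults(func=lambda args: print(f"would sync {args.id or 'all'}"))

def main():
    parser = argparse.ArgumentParser(prog="openai")
    parser.set_defaults(func=lambda args: parser.print_help())
    subparsers = parser.add_subparsers()
    sub_wandb = subparsers.add_parser("wandb", help="Logging with Weights & Biases")
    wandb_register(sub_wandb)
    args = parser.parse_args()
    args.func(args)  # dispatch to whichever handler argparse bound

if __name__ == "__main__":
    main()
```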
openai/cli.py
@@ -19,6 +19,7 @@ from openai.validators import (
     write_out_file,
     write_out_search_file,
 )
+import openai.wandb_logger
 
 
 class bcolors:
@@ -535,6 +536,19 @@ class FineTune:
         )
 
 
+class WandbLogger:
+    @classmethod
+    def sync(cls, args):
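+        # thin adapter: forward the parsed argparse flags to the module-level logger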
+        resp = openai.wandb_logger.WandbLogger.sync(
+            id=args.id,
+            n_fine_tunes=args.n_fine_tunes,
+            project=args.project,
+            entity=args.entity,
+            force=args.force,
+        )
+        print(resp)
+
+
 def tools_register(parser):
     subparsers = parser.add_subparsers(
         title="Tools", help="Convenience client side tools"
@@ -954,3 +968,40 @@ Mutually exclusive with `top_p`.""",
     sub = subparsers.add_parser("fine_tunes.cancel")
     sub.add_argument("-i", "--id", required=True, help="The id of the fine-tune job")
     sub.set_defaults(func=FineTune.cancel)
+
+
+def wandb_register(parser):
+    subparsers = parser.add_subparsers(
+        title="wandb", help="Logging with Weights & Biases"
+    )
+
+    def help(args):
+        parser.print_help()
+
+    parser.set_defaults(func=help)
+
+    sub = subparsers.add_parser("sync")
+    sub.add_argument("-i", "--id", help="The id of the fine-tune job (optional)")
+    sub.add_argument(
+        "-n",
+        "--n_fine_tunes",
+        type=int,
+        default=None,
+        help="Number of most recent fine-tunes to log when an id is not provided. By default, every fine-tune is synced.",
+    )
+    sub.add_argument(
+        "--project",
+        default="GPT-3",
+        help="""Name of the project where you're sending runs. By default, it is "GPT-3".""",
+    )
+    sub.add_argument(
+        "--entity",
+        help="Username or team name where you're sending runs. By default, your default entity is used, which is usually your username.",
+    )
+    sub.add_argument(
+        "--force",
+        action="store_true",
+        help="Forces logging and overwrite existing wandb run of the same fine-tune.",
+    )
+    sub.set_defaults(force=False)
+    sub.set_defaults(func=WandbLogger.sync)
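Each flag registered above maps one-to-one onto a keyword argument of `WandbLogger.sync` in `openai/wandb_logger.py` below, so the CLI and the Python API stay interchangeable. A sketch of the equivalence (example values; assumes wandb credentials are configured):

```python
import openai.wandb_logger

# same effect as: openai wandb sync -n 5 --project my-project --force
openai.wandb_logger.WandbLogger.sync(
    id=None,               # -i / --id: sync one specific fine-tune when given
    n_fine_tunes=5,        # -n: only the 5 most recent fine-tunes
    project="my-project",  # --project (defaults to "GPT-3")
    entity=None,           # --entity: wandb username or team
    force=True,            # --force: overwrite a previously synced run
)
```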
openai/wandb_logger.py
@@ -0,0 +1,290 @@
+try:
+    import wandb
+
+    WANDB_AVAILABLE = True
+except Exception:
+    WANDB_AVAILABLE = False
+
+
+if WANDB_AVAILABLE:
+    import datetime
+    import io
+    import json
+    from pathlib import Path
+
+    import numpy as np
+    import pandas as pd
+
+    from openai import File, FineTune
+
+
+class WandbLogger:
+    """
+    Log fine-tunes to [Weights & Biases](https://wandb.me/openai-docs)
+    """
+
+    if not WANDB_AVAILABLE:
+        print("Logging requires wandb to be installed. Run `pip install wandb`.")
+    else:
+        _wandb_api = None
+        _logged_in = False
+
+    @classmethod
+    def sync(
+        cls,
+        id=None,
+        n_fine_tunes=None,
+        project="GPT-3",
+        entity=None,
+        force=False,
+        **kwargs_wandb_init,
+    ):
+        """
+        Sync fine-tunes to Weights & Biases.
+        :param id: The id of the fine-tune (optional)
+        :param n_fine_tunes: Number of most recent fine-tunes to log when an id is not provided. By default, every fine-tune is synced.
+        :param project: Name of the project where you're sending runs. By default, it is "GPT-3".
+        :param entity: Username or team name where you're sending runs. By default, your default entity is used, which is usually your username.
+        :param force: Forces logging and overwrites an existing wandb run of the same fine-tune.
+        """
+
+        if not WANDB_AVAILABLE:
+            return
+
+        if id:
+            fine_tune = FineTune.retrieve(id=id)
+            fine_tune.pop("events", None)
+            fine_tunes = [fine_tune]
+
+        else:
+            # get the list of fine-tunes to log
+            fine_tunes = FineTune.list()
+            if not fine_tunes or fine_tunes.get("data") is None:
+                print("No fine-tune has been retrieved")
+                return
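+            # keep only the n most recent fine-tunes (the list is ordered oldest first)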
+            fine_tunes = fine_tunes["data"][
+                -n_fine_tunes if n_fine_tunes is not None else None :
+            ]
+
+        # log starting from the oldest fine-tune
+        show_individual_warnings = id is not None or n_fine_tunes is not None
+        fine_tune_logged = [
+            cls._log_fine_tune(
+                fine_tune,
+                project,
+                entity,
+                force,
+                show_individual_warnings,
+                **kwargs_wandb_init,
+            )
+            for fine_tune in fine_tunes
+        ]
+
+        if not show_individual_warnings and not any(fine_tune_logged):
+            print("No new successful fine-tunes were found")
+
+        return "🎉 wandb sync completed successfully"
+
+    @classmethod
+    def _log_fine_tune(
+        cls,
+        fine_tune,
+        project,
+        entity,
+        force,
+        show_individual_warnings,
+        **kwargs_wandb_init,
+    ):
+        fine_tune_id = fine_tune.get("id")
+        status = fine_tune.get("status")
+
+        # check run completed successfully
+        if show_individual_warnings and status != "succeeded":
+            print(
+                f'Fine-tune {fine_tune_id} has the status "{status}" and will not be logged'
+            )
+            return
+
+        # check run has not been logged already
+        run_path = f"{project}/{fine_tune_id}"
+        if entity is not None:
+            run_path = f"{entity}/{run_path}"
+        wandb_run = cls._get_wandb_run(run_path)
+        if wandb_run:
+            wandb_status = wandb_run.summary.get("status")
+            if show_individual_warnings:
+                if wandb_status == "succeeded":
+                    print(
+                        f"Fine-tune {fine_tune_id} has already been logged successfully at {wandb_run.url}"
+                    )
+                    if not force:
+                        print(
+                            'Use "--force" in the CLI or "force=True" in python if you want to overwrite previous run'
+                        )
+                else:
+                    print(
+                        f"A run for fine-tune {fine_tune_id} was previously created but didn't end successfully"
+                    )
+                if wandb_status != "succeeded" or force:
+                    print(
+                        f"A new wandb run will be created for fine-tune {fine_tune_id} and previous run will be overwritten"
+                    )
+            if wandb_status == "succeeded" and not force:
+                return
+
+        # retrieve results
+        results_id = fine_tune["result_files"][0]["id"]
+        results = File.download(id=results_id).decode("utf-8")
+
+        # start a wandb run
+        wandb.init(
+            job_type="fine-tune",
+            config=cls._get_config(fine_tune),
+            project=project,
+            entity=entity,
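+            # reuse the fine-tune id as run name/id so a forced re-sync
+            # targets the same wandb run rather than creating a new one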
+            name=fine_tune_id,
+            id=fine_tune_id,
+            **kwargs_wandb_init,
+        )
+
+        # log results
+        df_results = pd.read_csv(io.StringIO(results))
+        for _, row in df_results.iterrows():
+            metrics = {k: v for k, v in row.items() if not np.isnan(v)}
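+            # use the results CSV "step" column as the wandb step so metrics align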
+            step = metrics.pop("step")
+            if step is not None:
+                step = int(step)
+            wandb.log(metrics, step=step)
+        fine_tuned_model = fine_tune.get("fine_tuned_model")
+        if fine_tuned_model is not None:
+            wandb.summary["fine_tuned_model"] = fine_tuned_model
+
+        # training/validation files and fine-tune details
+        cls._log_artifacts(fine_tune, project, entity)
+
+        # mark run as complete
+        wandb.summary["status"] = "succeeded"
+
+        wandb.finish()
+        return True
+
+    @classmethod
+    def _ensure_logged_in(cls):
+        if not cls._logged_in:
+            if wandb.login():
+                cls._logged_in = True
+            else:
+                raise Exception("You need to log in to wandb")
+
+    @classmethod
+    def _get_wandb_run(cls, run_path):
+        cls._ensure_logged_in()
+        try:
+            if cls._wandb_api is None:
+                cls._wandb_api = wandb.Api()
+            return cls._wandb_api.run(run_path)
+        except Exception:
+            return None
+
+    @classmethod
+    def _get_wandb_artifact(cls, artifact_path):
+        cls._ensure_logged_in()
+        try:
+            if cls._wandb_api is None:
+                cls._wandb_api = wandb.Api()
+            return cls._wandb_api.artifact(artifact_path)
+        except Exception:
+            return None
+
+    @classmethod
+    def _get_config(cls, fine_tune):
+        config = dict(fine_tune)
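+        # unwrap single-element file lists and convert the timestamp so the
+        # values display cleanly in the wandb run config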
+        for key in ("training_files", "validation_files", "result_files"):
+            if config.get(key) and len(config[key]):
+                config[key] = config[key][0]
+        if config.get("created_at"):
+            config["created_at"] = datetime.datetime.fromtimestamp(config["created_at"])
+        return config
+
+    @classmethod
+    def _log_artifacts(cls, fine_tune, project, entity):
+        # training/validation files
+        training_file = (
+            fine_tune["training_files"][0]
+            if fine_tune.get("training_files") and len(fine_tune["training_files"])
+            else None
+        )
+        validation_file = (
+            fine_tune["validation_files"][0]
+            if fine_tune.get("validation_files") and len(fine_tune["validation_files"])
+            else None
+        )
+        for file, prefix, artifact_type in (
+            (training_file, "train", "training_files"),
+            (validation_file, "valid", "validation_files"),
+        ):
+            if file is not None:
+                cls._log_artifact_inputs(file, prefix, artifact_type, project, entity)
+
+        # fine-tune details
+        fine_tune_id = fine_tune.get("id")
+        artifact = wandb.Artifact(
+            "fine_tune_details",
+            type="fine_tune_details",
+            metadata=fine_tune,
+        )
+        with artifact.new_file("fine_tune_details.json") as f:
+            json.dump(fine_tune, f, indent=2)
+        wandb.run.log_artifact(
+            artifact,
+            aliases=["latest", fine_tune_id],
+        )
+
+    @classmethod
+    def _log_artifact_inputs(cls, file, prefix, artifact_type, project, entity):
+        file_id = file["id"]
+        filename = Path(file["filename"]).name
+        stem = Path(file["filename"]).stem
+
+        # get input artifact
+        artifact_name = f"{prefix}-{filename}"
+        artifact_alias = file_id
+        artifact_path = f"{project}/{artifact_name}:{artifact_alias}"
+        if entity is not None:
+            artifact_path = f"{entity}/{artifact_path}"
+        artifact = cls._get_wandb_artifact(artifact_path)
+
+        # create artifact if file not already logged previously
+        if artifact is None:
+            # get file content
+            try:
+                file_content = File.download(id=file_id).decode("utf-8")
+            except Exception:
+                print(
+                    f"File {file_id} could not be retrieved. Make sure you are allowed to download training/validation files"
+                )
+                return
+            artifact = wandb.Artifact(artifact_name, type=artifact_type, metadata=file)
+            with artifact.new_file(filename, mode="w") as f:
+                f.write(file_content)
+
+            # create a Table
+            try:
+                table, n_items = cls._make_table(file_content)
+                artifact.add(table, stem)
+                wandb.config.update({f"n_{prefix}": n_items})
+                artifact.metadata["items"] = n_items
+            except Exception:
+                print(f"File {file_id} could not be read as a valid JSON file")
+        else:
+            # log number of items
+            wandb.config.update({f"n_{prefix}": artifact.metadata.get("items")})
+
+        wandb.run.use_artifact(artifact, aliases=["latest", artifact_alias])
+
+    @classmethod
+    def _make_table(cls, file_content):
+        df = pd.read_json(io.StringIO(file_content), orient="records", lines=True)
+        return wandb.Table(dataframe=df), len(df)
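For reference, a minimal usage sketch of the module above when called from Python rather than the CLI (assumes `wandb` is installed and OpenAI/wandb credentials are configured; the API key is a placeholder):

```python
import openai
from openai.wandb_logger import WandbLogger

openai.api_key = "sk-..."  # placeholder; or set the OPENAI_API_KEY env var

# sync the two most recent fine-tunes to the default "GPT-3" project
WandbLogger.sync(n_fine_tunes=2)
```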
README.md
@@ -76,6 +76,7 @@ search = openai.Engine(id="deployment-namme").search(documents=["White House", "
 # print the search
 print(search)
 ```
+
 Please note that for the moment, the Microsoft Azure endpoints can only be used for completion and search operations.
 
 ### Command-line interface
@@ -142,6 +143,12 @@ Examples of fine tuning are shared in the following Jupyter notebooks:
   - [Step 2: Creating a synthetic Q&A dataset](https://github.com/openai/openai-python/blob/main/examples/finetuning/olympics-2-create-qa.ipynb)
   - [Step 3: Train a fine-tuning model specialized for Q&A](https://github.com/openai/openai-python/blob/main/examples/finetuning/olympics-3-train-qa.ipynb)
 
+Sync your fine-tunes to [Weights & Biases](https://wandb.me/openai-docs) to track experiments, models, and datasets in your central dashboard with:
+
+```bash
+openai wandb sync
+```
+
 For more information on fine tuning, read the [fine-tuning guide](https://beta.openai.com/docs/guides/fine-tuning) in the OpenAI documentation.
 
 ## Requirements