main
  1# pyright: basic
  2from __future__ import annotations
  3
  4import os
  5import sys
  6from typing import Any, TypeVar, Callable, Optional, NamedTuple
  7from typing_extensions import TypeAlias
  8
  9from .._extras import pandas as pd
 10
 11
class Remediation(NamedTuple):
    """
    Result of one validator: messages to show the user and fixes to offer.

    A validator fills in only the fields that apply to it; all others stay None.
    """

    name: str  # validator identifier; used in the "ERROR in {name} validator" message
    immediate_msg: Optional[str] = None  # written to stdout by apply_necessary_remediation
    necessary_msg: Optional[str] = None  # listed as "- [Necessary] ..." by apply_optional_remediation
    necessary_fn: Optional[Callable[[Any], Any]] = None  # mandatory fix: receives the dataframe, returns the new one
    optional_msg: Optional[str] = None  # offered as "- [Recommended] ... [Y/n]" by apply_optional_remediation
    optional_fn: Optional[Callable[[Any], Any]] = None  # suggested fix, applied only when the user accepts
    error_msg: Optional[str] = None  # fatal: written to stderr, after which the process exits with status 1
 20
 21
# Type variable bound to "DataFrame or None", so that apply_necessary_remediation
# is typed as returning the same (possibly None) dataframe type it was given.
OptionalDataFrameT = TypeVar("OptionalDataFrameT", bound="Optional[pd.DataFrame]")
 23
 24
 25def num_examples_validator(df: pd.DataFrame) -> Remediation:
 26    """
 27    This validator will only print out the number of examples and recommend to the user to increase the number of examples if less than 100.
 28    """
 29    MIN_EXAMPLES = 100
 30    optional_suggestion = (
 31        ""
 32        if len(df) >= MIN_EXAMPLES
 33        else ". In general, we recommend having at least a few hundred examples. We've found that performance tends to linearly increase for every doubling of the number of examples"
 34    )
 35    immediate_msg = f"\n- Your file contains {len(df)} prompt-completion pairs{optional_suggestion}"
 36    return Remediation(name="num_examples", immediate_msg=immediate_msg)
 37
 38
 39def necessary_column_validator(df: pd.DataFrame, necessary_column: str) -> Remediation:
 40    """
 41    This validator will ensure that the necessary column is present in the dataframe.
 42    """
 43
 44    def lower_case_column(df: pd.DataFrame, column: Any) -> pd.DataFrame:
 45        cols = [c for c in df.columns if str(c).lower() == column]
 46        df.rename(columns={cols[0]: column.lower()}, inplace=True)
 47        return df
 48
 49    immediate_msg = None
 50    necessary_fn = None
 51    necessary_msg = None
 52    error_msg = None
 53
 54    if necessary_column not in df.columns:
 55        if necessary_column in [str(c).lower() for c in df.columns]:
 56
 57            def lower_case_column_creator(df: pd.DataFrame) -> pd.DataFrame:
 58                return lower_case_column(df, necessary_column)
 59
 60            necessary_fn = lower_case_column_creator
 61            immediate_msg = f"\n- The `{necessary_column}` column/key should be lowercase"
 62            necessary_msg = f"Lower case column name to `{necessary_column}`"
 63        else:
 64            error_msg = f"`{necessary_column}` column/key is missing. Please make sure you name your columns/keys appropriately, then retry"
 65
 66    return Remediation(
 67        name="necessary_column",
 68        immediate_msg=immediate_msg,
 69        necessary_msg=necessary_msg,
 70        necessary_fn=necessary_fn,
 71        error_msg=error_msg,
 72    )
 73
 74
 75def additional_column_validator(df: pd.DataFrame, fields: list[str] = ["prompt", "completion"]) -> Remediation:
 76    """
 77    This validator will remove additional columns from the dataframe.
 78    """
 79    additional_columns = []
 80    necessary_msg = None
 81    immediate_msg = None
 82    necessary_fn = None  # type: ignore
 83
 84    if len(df.columns) > 2:
 85        additional_columns = [c for c in df.columns if c not in fields]
 86        warn_message = ""
 87        for ac in additional_columns:
 88            dups = [c for c in additional_columns if ac in c]
 89            if len(dups) > 0:
 90                warn_message += f"\n  WARNING: Some of the additional columns/keys contain `{ac}` in their name. These will be ignored, and the column/key `{ac}` will be used instead. This could also result from a duplicate column/key in the provided file."
 91        immediate_msg = f"\n- The input file should contain exactly two columns/keys per row. Additional columns/keys present are: {additional_columns}{warn_message}"
 92        necessary_msg = f"Remove additional columns/keys: {additional_columns}"
 93
 94        def necessary_fn(x: Any) -> Any:
 95            return x[fields]
 96
 97    return Remediation(
 98        name="additional_column",
 99        immediate_msg=immediate_msg,
100        necessary_msg=necessary_msg,
101        necessary_fn=necessary_fn,
102    )
103
104
def non_empty_field_validator(df: pd.DataFrame, field: str = "completion") -> Remediation:
    """
    Ensure no row has an empty `field` (empty string or null).

    Offending rows are reported by 0-based position, and a necessary fix is
    offered that removes them.
    """
    column = df[field]
    is_empty = (column == "") | column.isnull()
    if not is_empty.any():
        return Remediation(name=f"empty_{field}")

    empty_positions = df.reset_index().index[is_empty].tolist()

    def drop_empty_rows(frame: Any) -> Any:
        return frame[frame[field] != ""].dropna(subset=[field])

    return Remediation(
        name=f"empty_{field}",
        immediate_msg=f"\n- `{field}` column/key should not contain empty strings. These are rows: {empty_positions}",
        necessary_msg=f"Remove {len(empty_positions)} rows with empty {field}s",
        necessary_fn=drop_empty_rows,
    )
129
130
def duplicated_rows_validator(df: pd.DataFrame, fields: list[str] = ["prompt", "completion"]) -> Remediation:
    """
    Suggest removing rows whose `fields` values repeat an earlier row.

    Duplicates are reported by 0-based position. The optional fix keeps the
    first occurrence of each duplicated combination.
    """
    dup_mask = df.duplicated(subset=fields)
    dup_positions = df.reset_index().index[dup_mask].tolist()
    if not dup_positions:
        return Remediation(name="duplicated_rows")

    def drop_repeats(frame: Any) -> Any:
        return frame.drop_duplicates(subset=fields)

    count = len(dup_positions)
    return Remediation(
        name="duplicated_rows",
        immediate_msg=f"\n- There are {count} duplicated {'-'.join(fields)} sets. These are rows: {dup_positions}",
        optional_msg=f"Remove {count} duplicate rows",
        optional_fn=drop_repeats,
    )
154
155
def long_examples_validator(df: pd.DataFrame) -> Remediation:
    """
    Suggest removing examples whose prompt plus completion exceeds 10,000 characters.

    Skipped for open-ended generation. Long examples are reported by their
    0-based row position. The optional fix re-detects long examples on the
    dataframe it receives and drops them by index label. This removes the
    right rows even after earlier remediations have dropped rows and left
    gaps in the index.
    """
    MAX_EXAMPLE_CHARS = 10000
    immediate_msg = None
    optional_msg = None
    optional_fn = None  # type: ignore

    ft_type = infer_task_type(df)
    if ft_type != "open-ended generation":

        def get_long_indexes(d: pd.DataFrame) -> Any:
            # 0-based row positions (not index labels) of the over-long examples.
            # Vectorised, so an empty frame simply yields an empty list.
            too_long = (d.prompt.str.len() + d.completion.str.len()) > MAX_EXAMPLE_CHARS
            return [pos for pos, is_long in enumerate(too_long) if is_long]

        long_indexes = get_long_indexes(df)

        if len(long_indexes) > 0:
            immediate_msg = f"\n- There are {len(long_indexes)} examples that are very long. These are rows: {long_indexes}\nFor conditional generation, and for classification the examples shouldn't be longer than 2048 tokens."
            optional_msg = f"Remove {len(long_indexes)} long examples"

            def optional_fn(x: Any) -> Any:
                long_indexes_to_drop = get_long_indexes(x)
                if long_indexes != long_indexes_to_drop:
                    sys.stdout.write(
                        f"The indices of the long examples has changed as a result of a previously applied recommendation.\nThe {len(long_indexes_to_drop)} long examples to be dropped are now at the following indices: {long_indexes_to_drop}\n"
                    )
                # `drop` works on index labels. Translate positions to labels so
                # the correct rows go even when the index is no longer 0..n-1.
                return x.drop(x.index[long_indexes_to_drop])

    return Remediation(
        name="long_examples",
        immediate_msg=immediate_msg,
        optional_msg=optional_msg,
        optional_fn=optional_fn,
    )
191
192
def common_prompt_suffix_validator(df: pd.DataFrame) -> Remediation:
    """
    Check that all prompts end with a common separator, and suggest one if not.

    Skipped for open-ended generation. Returns an error if every prompt is
    identical. If the prompts share a suffix, it is reported, with extra notes
    when it is longer than 10 characters or also occurs earlier in some prompt.
    If they share none, the optional fix appends a separator to every prompt:
    the first candidate from a fixed list that occurs in no prompt.
    """
    # Pick a separator that does not already occur in any prompt. " ->" is only
    # considered when no prompt contains a newline. Fall back to "\n\n### =>\n\n".
    suggested_suffix = "\n\n### =>\n\n"
    suffix_options = [
        " ->",
        "\n\n###\n\n",
        "\n\n===\n\n",
        "\n\n---\n\n",
        "\n\n===>\n\n",
        "\n\n--->\n\n",
    ]
    for suffix_option in suffix_options:
        if suffix_option == " ->":
            if df.prompt.str.contains("\n").any():
                continue
        if df.prompt.str.contains(suffix_option, regex=False).any():
            continue
        suggested_suffix = suffix_option
        break
    display_suggested_suffix = suggested_suffix.replace("\n", "\\n")

    ft_type = infer_task_type(df)
    if ft_type == "open-ended generation":
        return Remediation(name="common_suffix")

    def add_suffix(x: Any, suffix: Any) -> Any:
        x["prompt"] += suffix
        return x

    common_suffix = get_common_xfix(df.prompt, xfix="suffix")
    if (df.prompt == common_suffix).all():
        error_msg = f"All prompts are identical: `{common_suffix}`\nConsider leaving the prompts blank if you want to do open-ended generation, otherwise ensure prompts are different"
        return Remediation(name="common_suffix", error_msg=error_msg)

    optional_msg = None
    optional_fn = None  # type: ignore

    if common_suffix != "":
        common_suffix_new_line_handled = common_suffix.replace("\n", "\\n")
        immediate_msg = f"\n- All prompts end with suffix `{common_suffix_new_line_handled}`"
        if len(common_suffix) > 10:
            immediate_msg += f". This suffix seems very long. Consider replacing with a shorter suffix, such as `{display_suggested_suffix}`"
        # Strip the final occurrence and look for the suffix in what remains.
        if df.prompt.str[: -len(common_suffix)].str.contains(common_suffix, regex=False).any():
            immediate_msg += f"\n  WARNING: Some of your prompts contain the suffix `{common_suffix}` more than once. We strongly suggest that you review your prompts and add a unique suffix"
    else:
        immediate_msg = "\n- Your data does not contain a common separator at the end of your prompts. Having a separator string appended to the end of the prompt makes it clearer to the fine-tuned model where the completion should begin. See https://platform.openai.com/docs/guides/fine-tuning/preparing-your-dataset for more detail and examples. If you intend to do open-ended generation, then you should leave the prompts empty"
        optional_msg = f"Add a suffix separator `{display_suggested_suffix}` to all prompts"

        def optional_fn(x: Any) -> Any:
            return add_suffix(x, suggested_suffix)

    # Named after this (prompt) check rather than the completion-suffix check.
    return Remediation(
        name="common_prompt_suffix",
        immediate_msg=immediate_msg,
        optional_msg=optional_msg,
        optional_fn=optional_fn,
    )
259
260
def common_prompt_prefix_validator(df: pd.DataFrame) -> Remediation:
    """
    Report a prefix shared by all prompts, and suggest removing it if long.

    A shared prefix longer than 12 characters usually means every prompt
    repeats a task instruction or few-shot scenario. In that case an optional
    fix strips the prefix from every prompt.
    """
    MAX_PREFIX_LEN = 12

    prefix = get_common_xfix(df.prompt, xfix="prefix")
    # Nothing shared, or every prompt equals the prefix (identical prompts are
    # reported by the common suffix validator): nothing to do here.
    if prefix == "" or (df.prompt == prefix).all():
        return Remediation(name="common_prefix")

    message = f"\n- All prompts start with prefix `{prefix}`"
    if len(prefix) <= MAX_PREFIX_LEN:
        return Remediation(name="common_prompt_prefix", immediate_msg=message)

    def strip_prefix(frame: Any) -> Any:
        frame["prompt"] = frame["prompt"].str[len(prefix) :]
        return frame

    return Remediation(
        name="common_prompt_prefix",
        immediate_msg=message
        + ". Fine-tuning doesn't require the instruction specifying the task, or a few-shot example scenario. Most of the time you should only add the input data into the prompt, and the desired output into the completion",
        optional_msg=f"Remove prefix `{prefix}` from all prompts",
        optional_fn=strip_prefix,
    )
298
299
def common_completion_prefix_validator(df: pd.DataFrame) -> Remediation:
    """
    Suggest removing a prefix of 5 or more characters shared by all completions.

    If the shared prefix starts with a space, the optional fix keeps a single
    leading space on each completion after stripping the prefix.
    """
    MAX_PREFIX_LEN = 5

    common_prefix = get_common_xfix(df.completion, xfix="prefix")
    ws_prefix = len(common_prefix) > 0 and common_prefix[0] == " "
    if len(common_prefix) < MAX_PREFIX_LEN:
        return Remediation(name="common_prefix")

    def remove_common_prefix(x: Any, prefix: Any, ws_prefix: Any) -> Any:
        x["completion"] = x["completion"].str[len(prefix) :]
        if ws_prefix:
            # Keep the single whitespace as prefix, added element-wise.
            # Interpolating the Series into an f-string would instead replace
            # every completion with the Series' printed representation.
            x["completion"] = " " + x["completion"]
        return x

    if (df.completion == common_prefix).all():
        # already handled by common_suffix_validator
        return Remediation(name="common_prefix")

    immediate_msg = f"\n- All completions start with prefix `{common_prefix}`. Most of the time you should only add the output data into the completion, without any prefix"
    optional_msg = f"Remove prefix `{common_prefix}` from all completions"

    def optional_fn(x: Any) -> Any:
        return remove_common_prefix(x, common_prefix, ws_prefix)

    return Remediation(
        name="common_completion_prefix",
        immediate_msg=immediate_msg,
        optional_msg=optional_msg,
        optional_fn=optional_fn,
    )
334
335
def common_completion_suffix_validator(df: pd.DataFrame) -> Remediation:
    """
    Check that completions share a common ending, and suggest one if they do not.

    Open-ended generation and classification are skipped. Identical
    completions are an error. A shared ending is reported, with extra notes
    when it is longer than 10 characters or also appears earlier in some
    completion. Without one, the optional fix appends the first candidate
    ending that occurs in no completion.
    """
    if infer_task_type(df) in ("open-ended generation", "classification"):
        return Remediation(name="common_suffix")

    common_suffix = get_common_xfix(df.completion, xfix="suffix")
    if (df.completion == common_suffix).all():
        return Remediation(
            name="common_suffix",
            error_msg=f"All completions are identical: `{common_suffix}`\nEnsure completions are different, otherwise the model will just repeat `{common_suffix}`",
        )

    # First candidate ending that never occurs inside a completion, else " [END]".
    candidates = ["\n", ".", " END", "***", "+++", "&&&", "$$$", "@@@", "%%%"]
    suggested_suffix = next(
        (cand for cand in candidates if not df.completion.str.contains(cand, regex=False).any()),
        " [END]",
    )
    display_suggested_suffix = suggested_suffix.replace("\n", "\\n")

    if common_suffix == "":

        def append_suffix(frame: Any) -> Any:
            frame["completion"] += suggested_suffix
            return frame

        return Remediation(
            name="common_completion_suffix",
            immediate_msg="\n- Your data does not contain a common ending at the end of your completions. Having a common ending string appended to the end of the completion makes it clearer to the fine-tuned model where the completion should end. See https://platform.openai.com/docs/guides/fine-tuning/preparing-your-dataset for more detail and examples.",
            optional_msg=f"Add a suffix ending `{display_suggested_suffix}` to all completions",
            optional_fn=append_suffix,
        )

    shown_common = common_suffix.replace("\n", "\\n")
    message = f"\n- All completions end with suffix `{shown_common}`"
    if len(common_suffix) > 10:
        message += f". This suffix seems very long. Consider replacing with a shorter suffix, such as `{display_suggested_suffix}`"
    # Drop the final occurrence, then look for the ending in what is left.
    if df.completion.str[: -len(common_suffix)].str.contains(common_suffix, regex=False).any():
        message += f"\n  WARNING: Some of your completions contain the suffix `{common_suffix}` more than once. We suggest that you review your completions and add a unique ending"
    return Remediation(name="common_completion_suffix", immediate_msg=message)
402
403
def completions_space_start_validator(df: pd.DataFrame) -> Remediation:
    """
    Suggest starting every completion with a space, which helps tokenization.

    Triggered when at least one completion does not start with " ". The
    optional fix prepends a space to each completion that lacks one. An empty
    dataframe is not flagged.
    """

    def add_space_start(x: Any) -> Any:
        x["completion"] = x["completion"].apply(lambda s: ("" if s.startswith(" ") else " ") + s)
        return x

    optional_msg = None
    optional_fn = None
    immediate_msg = None

    # Check every completion directly. The previous check read
    # `completion.values[0][0]`, which raised IndexError when all completions
    # were empty strings.
    if not df.completion.str.startswith(" ").all():
        immediate_msg = "\n- The completion should start with a whitespace character (` `). This tends to produce better results due to the tokenization we use. See https://platform.openai.com/docs/guides/fine-tuning/preparing-your-dataset for more details"
        optional_msg = "Add a whitespace character to the beginning of the completion"
        optional_fn = add_space_start
    return Remediation(
        name="completion_space_start",
        immediate_msg=immediate_msg,
        optional_msg=optional_msg,
        optional_fn=optional_fn,
    )
427
428
def lower_case_validator(df: pd.DataFrame, column: Any) -> Remediation | None:
    """
    Suggest lowercasing `column` when more than a third of its letters are uppercase.

    Only alphabetic characters are counted. Returns None when no change is
    suggested. Otherwise the optional fix lowercases every value in `column`.
    """

    def count_letters(predicate: Any) -> Any:
        # Per-row count of alphabetic characters satisfying `predicate`, summed over rows.
        return df[column].apply(lambda text: sum(ch.isalpha() and predicate(ch) for ch in text)).sum()

    n_upper = count_letters(str.isupper)
    n_lower = count_letters(str.islower)

    # upper / (upper + lower) > 1/3  <=>  2 * upper > lower
    if not n_upper * 2 > n_lower:
        return None

    def lower_case(frame: Any) -> Any:
        frame[column] = frame[column].str.lower()
        return frame

    return Remediation(
        name="lower_case",
        immediate_msg=f"\n- More than a third of your `{column}` column/key is uppercase. Uppercase {column}s tends to perform worse than a mixture of case encountered in normal language. We recommend to lower case the data if that makes sense in your domain. See https://platform.openai.com/docs/guides/fine-tuning/preparing-your-dataset for more details",
        optional_msg=f"Lowercase all your data in column/key `{column}`",
        optional_fn=lower_case,
    )
449
450
def read_any_format(
    fname: str, fields: list[str] = ["prompt", "completion"]
) -> tuple[pd.DataFrame | None, Remediation]:
    """
    Read a .csv, .json, .jsonl, .txt, .xlsx or .tsv file into a DataFrame using pandas.

     - for .xlsx it will read the first sheet
     - for .txt it will assume completions and split on newline (one row per
       line, with an empty prompt); the file is decoded as UTF-8, the same
       default pandas uses for the other text formats
     - for .jsonl / .json it also detects content written in the other format

    All values are read as strings, and missing values become "". `fields`
    (only read, never mutated) names the two columns built for .txt input.

    Returns:
        A ``(df, remediation)`` pair. `df` is None when the file does not
        exist, has an unsupported extension, or cannot be parsed; in those
        cases ``remediation.error_msg`` says why. Otherwise the remediation
        may carry a message about converting the file to JSONL.
    """
    necessary_msg = None
    immediate_msg = None
    error_msg = None
    df = None

    if os.path.isfile(fname):
        lower_fname = fname.lower()
        try:
            if lower_fname.endswith(".csv") or lower_fname.endswith(".tsv"):
                file_extension_str, separator = ("CSV", ",") if lower_fname.endswith(".csv") else ("TSV", "\t")
                immediate_msg = (
                    f"\n- Based on your file extension, your file is formatted as a {file_extension_str} file"
                )
                necessary_msg = f"Your format `{file_extension_str}` will be converted to `JSONL`"
                df = pd.read_csv(fname, sep=separator, dtype=str).fillna("")
            elif lower_fname.endswith(".xlsx"):
                immediate_msg = "\n- Based on your file extension, your file is formatted as an Excel file"
                necessary_msg = "Your format `XLSX` will be converted to `JSONL`"
                # Open the workbook once, read the first sheet from it, and close
                # it afterwards (the ExcelFile handle used to be left open).
                with pd.ExcelFile(fname) as xls:
                    if len(xls.sheet_names) > 1:
                        immediate_msg += "\n- Your Excel file contains more than one sheet. Please either save as csv or ensure all data is present in the first sheet. WARNING: Reading only the first sheet..."
                    df = pd.read_excel(xls, dtype=str).fillna("")
            elif lower_fname.endswith(".txt"):
                immediate_msg = "\n- Based on your file extension, you provided a text file"
                necessary_msg = "Your format `TXT` will be converted to `JSONL`"
                with open(fname, "r", encoding="utf-8") as f:
                    content = f.read()
                df = pd.DataFrame(
                    [["", line] for line in content.split("\n")],
                    columns=fields,
                    dtype=str,
                ).fillna("")
            elif lower_fname.endswith(".jsonl"):
                df = pd.read_json(fname, lines=True, dtype=str).fillna("")  # type: ignore
                if len(df) == 1:  # type: ignore
                    # A single "line" may really be a whole JSON document (e.g. an
                    # array of records on one line), so try reading it as JSON.
                    try:
                        json_df = pd.read_json(fname, dtype=str).fillna("")  # type: ignore
                    except ValueError:
                        # A genuine one-record JSONL file (a lone object of scalars)
                        # is rejected by the plain JSON reader; keep the JSONL parse
                        # instead of reporting a valid file as malformed.
                        pass
                    else:
                        # this is NOT what we expect for a .jsonl file
                        immediate_msg = "\n- Your JSONL file appears to be in a JSON format. Your file will be converted to JSONL format"
                        necessary_msg = "Your format `JSON` will be converted to `JSONL`"
                        df = json_df
                # otherwise this is what we expect for a .jsonl file
            elif lower_fname.endswith(".json"):
                try:
                    # to handle case where .json file is actually a .jsonl file
                    df = pd.read_json(fname, lines=True, dtype=str).fillna("")  # type: ignore
                    if len(df) == 1:  # type: ignore
                        # this code path corresponds to a .json file that has one line
                        df = pd.read_json(fname, dtype=str).fillna("")  # type: ignore
                    else:
                        # this is NOT what we expect for a .json file
                        immediate_msg = "\n- Your JSON file appears to be in a JSONL format. Your file will be converted to JSONL format"
                        necessary_msg = "Your format `JSON` will be converted to `JSONL`"
                except ValueError:
                    # this code path corresponds to a .json file that has multiple lines (i.e. it is indented)
                    df = pd.read_json(fname, dtype=str).fillna("")  # type: ignore
            else:
                error_msg = (
                    "Your file must have one of the following extensions: .CSV, .TSV, .XLSX, .TXT, .JSON or .JSONL"
                )
                if "." in fname:
                    error_msg += f" Your file `{fname}` ends with the extension `.{fname.split('.')[-1]}` which is not supported."
                else:
                    error_msg += f" Your file `{fname}` is missing a file extension."

        except (ValueError, TypeError):
            # Do not hand back a partially parsed frame alongside an error.
            df = None
            file_extension_str = fname.split(".")[-1].upper()
            error_msg = f"Your file `{fname}` does not appear to be in valid {file_extension_str} format. Please ensure your file is formatted as a valid {file_extension_str} file."

    else:
        error_msg = f"File {fname} does not exist."

    remediation = Remediation(
        name="read_any_format",
        necessary_msg=necessary_msg,
        immediate_msg=immediate_msg,
        error_msg=error_msg,
    )
    return df, remediation
538
539
def format_inferrer_validator(df: pd.DataFrame) -> Remediation:
    """
    Tell the user when the data looks like a classification task.

    For classification, the message also recommends a faster, cheaper model
    such as `ada` and keeping a held-out set to estimate performance. Other
    task types produce no message.
    """
    ft_type = infer_task_type(df)
    immediate_msg = None
    if ft_type == "classification":
        immediate_msg = f"\n- Based on your data it seems like you're trying to fine-tune a model for {ft_type}\n- For classification, we recommend you try one of the faster and cheaper models, such as `ada`\n- For classification, you can estimate the expected model performance by keeping a held out dataset, which is not used for training"
    # Report under this validator's own name (it previously reused "num_examples").
    return Remediation(name="format_inferrer", immediate_msg=immediate_msg)
551
def apply_necessary_remediation(df: OptionalDataFrameT, remediation: Remediation) -> OptionalDataFrameT:
    """
    Apply the mandatory part of `remediation` to `df`.

    If the remediation carries an error message, it is written to stderr and
    the process exits with status 1. Otherwise any immediate message is
    written to stdout, and the necessary fix, if present, is applied.

    Returns:
        The dataframe returned by the necessary fix, or `df` unchanged.
    """
    error = remediation.error_msg
    if error is not None:
        sys.stderr.write(f"\n\nERROR in {remediation.name} validator: {error}\n\nAborting...")
        sys.exit(1)

    message = remediation.immediate_msg
    if message is not None:
        sys.stdout.write(message)

    fix = remediation.necessary_fn
    return df if fix is None else fix(df)
564
565
def accept_suggestion(input_text: str, auto_accept: bool) -> bool:
    """
    Show `input_text` and return whether the suggestion was accepted.

    With `auto_accept`, "Y" is echoed and True is returned without reading
    input. Otherwise one line is read, and any answer except "n"/"N"
    (including an empty line) counts as acceptance.
    """
    sys.stdout.write(input_text)
    if not auto_accept:
        return input().lower() != "n"
    sys.stdout.write("Y\n")
    return True
572
573
def apply_optional_remediation(
    df: pd.DataFrame, remediation: Remediation, auto_accept: bool
) -> tuple[pd.DataFrame, bool]:
    """
    Offer the optional fix in `remediation` and apply it if the user accepts.

    Any necessary-fix description is also listed as "- [Necessary] ...".

    Returns:
        ``(df, applied)``: the possibly transformed dataframe, and whether the
        optional fix was applied.
    """
    applied = False
    suggestion = remediation.optional_msg
    if suggestion is not None and accept_suggestion(f"- [Recommended] {suggestion} [Y/n]: ", auto_accept):
        assert remediation.optional_fn is not None
        df = remediation.optional_fn(df)
        applied = True

    if remediation.necessary_msg is not None:
        sys.stdout.write(f"- [Necessary] {remediation.necessary_msg}\n")
    return df, applied
590
591
def estimate_fine_tuning_time(df: pd.DataFrame) -> None:
    """
    Print a rough estimate of how long fine-tuning the dataset will take.

    Classification is estimated at 1.44 s per row. Other task types are
    estimated at 0.0515 s per unit of the dataframe's memory usage (index
    included). A fixed 140 s is added before the value is formatted.
    """
    if infer_task_type(df) == "classification":
        raw_seconds = len(df) * 1.44
    else:
        raw_seconds = df.memory_usage(index=True).sum() * 0.0515

    def format_time(seconds: float) -> str:
        # (upper bound, divisor, unit) for each unit below days.
        for limit, divisor, unit in ((60, 1, "seconds"), (3600, 60, "minutes"), (86400, 3600, "hours")):
            if seconds < limit:
                return f"{round(seconds / divisor, 2)} {unit}"
        return f"{round(seconds / 86400, 2)} days"

    time_string = format_time(raw_seconds + 140)
    sys.stdout.write(
        f"Once your model starts training, it'll approximately take {time_string} to train a `curie` model, and less for `ada` and `babbage`. Queue will approximately take half an hour per job ahead of you.\n"
    )
619
620
def get_outfnames(fname: str, split: bool) -> list[str]:
    """
    Return output paths for the prepared JSONL file(s) that do not exist yet.

    Paths have the form ``<stem>_prepared[_train|_valid][ (i)].jsonl``, where
    <stem> is `fname` without its extension. The first counter value is tried
    without a " (i)" tag; the counter then increases until none of the
    candidate paths exists.
    """
    stem = os.path.splitext(fname)[0]
    variants = ("_train", "_valid") if split else ("",)
    attempt = 0
    while True:
        tag = "" if attempt == 0 else f" ({attempt})"
        names = [f"{stem}_prepared{variant}{tag}.jsonl" for variant in variants]
        if all(not os.path.isfile(name) for name in names):
            return names
        attempt += 1
630
631
def get_classification_hyperparams(df: pd.DataFrame) -> tuple[int, object]:
    """
    Return ``(n_classes, pos_class)`` for a classification dataset.

    `n_classes` is the number of distinct completions. `pos_class` is the
    most frequent completion when there are exactly two classes, else None.
    """
    n_classes = df.completion.nunique()
    if n_classes != 2:
        return n_classes, None
    return n_classes, df.completion.value_counts().index[0]
638
639
def write_out_file(df: pd.DataFrame, fname: str, any_remediations: bool, auto_accept: bool) -> None:
    """
    Optionally write the prepared dataframe to JSONL and print the fine-tuning command.

    If nothing was changed and no split was requested, the original `fname` is
    suggested for fine-tuning as-is. Otherwise, once the user confirms, the
    prompt/completion columns are written to new JSONL file(s) named by
    `get_outfnames`. For classification, the user is first offered a
    train/validation split. Validation then gets about 20% of the rows, capped
    at 1000, sampled with a fixed seed. The suggested command then includes
    the validation file and classification-metric flags.
    """
    ft_format = infer_task_type(df)
    common_prompt_suffix = get_common_xfix(df.prompt, xfix="suffix")
    common_completion_suffix = get_common_xfix(df.completion, xfix="suffix")

    split = False
    input_text = "- [Recommended] Would you like to split into training and validation set? [Y/n]: "
    if ft_format == "classification":
        if accept_suggestion(input_text, auto_accept):
            split = True

    additional_params = ""
    common_prompt_suffix_new_line_handled = common_prompt_suffix.replace("\n", "\\n")
    common_completion_suffix_new_line_handled = common_completion_suffix.replace("\n", "\\n")
    optional_ending_string = (
        f' Make sure to include `stop=["{common_completion_suffix_new_line_handled}"]` so that the generated texts ends at the expected place.'
        if len(common_completion_suffix_new_line_handled) > 0
        else ""
    )
    # Remind the user about the prompt separator only when there is one, in
    # both branches below; otherwise the reminder would cite an empty string.
    separator_reminder = (
        ""
        if len(common_prompt_suffix_new_line_handled) == 0
        else f"After you’ve fine-tuned a model, remember that your prompt has to end with the indicator string `{common_prompt_suffix_new_line_handled}` for the model to start generating completions, rather than continuing with the prompt."
    )

    input_text = "\n\nYour data will be written to a new JSONL file. Proceed [Y/n]: "

    if not any_remediations and not split:
        sys.stdout.write(
            f'\nYou can use your file for fine-tuning:\n> openai api fine_tunes.create -t "{fname}"{additional_params}\n\n{separator_reminder}{optional_ending_string}\n'
        )
        estimate_fine_tuning_time(df)

    elif accept_suggestion(input_text, auto_accept):
        fnames = get_outfnames(fname, split)
        if split:
            assert len(fnames) == 2 and "train" in fnames[0] and "valid" in fnames[1]
            MAX_VALID_EXAMPLES = 1000
            n_train = max(len(df) - MAX_VALID_EXAMPLES, int(len(df) * 0.8))
            df_train = df.sample(n=n_train, random_state=42)
            df_valid = df.drop(df_train.index)
            df_train[["prompt", "completion"]].to_json(  # type: ignore
                fnames[0], lines=True, orient="records", force_ascii=False, indent=None
            )
            df_valid[["prompt", "completion"]].to_json(
                fnames[1], lines=True, orient="records", force_ascii=False, indent=None
            )

            n_classes, pos_class = get_classification_hyperparams(df)
            additional_params += " --compute_classification_metrics"
            if n_classes == 2:
                additional_params += f' --classification_positive_class "{pos_class}"'
            else:
                additional_params += f" --classification_n_classes {n_classes}"
        else:
            assert len(fnames) == 1
            df[["prompt", "completion"]].to_json(
                fnames[0], lines=True, orient="records", force_ascii=False, indent=None
            )

        # Add -v VALID_FILE if we split the file into train / valid
        files_string = ("s" if split else "") + " to `" + ("` and `".join(fnames))
        valid_string = f' -v "{fnames[1]}"' if split else ""
        sys.stdout.write(
            f'\nWrote modified file{files_string}`\nFeel free to take a look!\n\nNow use that file when fine-tuning:\n> openai api fine_tunes.create -t "{fnames[0]}"{valid_string}{additional_params}\n\n{separator_reminder}{optional_ending_string}\n'
        )
        estimate_fine_tuning_time(df)
    else:
        sys.stdout.write("Aborting... did not write the file\n")
713
714
def infer_task_type(df: pd.DataFrame) -> str:
    """
    Guess which kind of fine-tuning task the dataset represents.

    Returns one of three labels:
      * "open-ended generation" -- the prompt lengths sum to zero (every
        prompt is an empty string, or there are no rows);
      * "classification" -- the number of distinct completions is below one
        third of the number of rows;
      * "conditional generation" -- anything else.
    """
    CLASSIFICATION_THRESHOLD = 3  # each class should appear at least this many times on average

    prompt_characters = sum(df.prompt.str.len())
    if prompt_characters == 0:
        return "open-ended generation"

    distinct_completions = len(df.completion.unique())
    is_classification = distinct_completions < len(df) / CLASSIFICATION_THRESHOLD
    return "classification" if is_classification else "conditional generation"
727
728
def get_common_xfix(series: Any, xfix: str = "suffix") -> str:
    """
    Finds the longest common suffix or prefix of all the values in a series.

    Args:
        series: a pandas Series (or any iterable) of values. Only ``str``
            entries are compared; missing or non-string entries (e.g. NaN,
            None) are skipped rather than compared.
        xfix: ``"suffix"`` computes the common suffix; any other value
            computes the common prefix.

    Returns:
        The longest string that every string entry ends with (or starts
        with). Returns ``""`` when there are no string entries or when the
        entries share nothing.
    """
    # Skip missing / non-string entries so a NaN in any row (including the
    # first) neither participates in nor breaks the comparison.
    values = [value for value in series if isinstance(value, str)]
    if xfix == "suffix":
        # A common suffix is the reversed common prefix of the reversed strings.
        values = [value[::-1] for value in values]
    # os.path.commonprefix compares character by character (not path
    # components) and returns "" for an empty list. It only needs the min and
    # max strings, which avoids re-slicing the whole series once per matched
    # character.
    common = os.path.commonprefix(values)
    return common[::-1] if xfix == "suffix" else common
745
746
# Signature shared by the entries of get_validators(): a validator inspects a
# DataFrame and returns a Remediation to report/apply, or None when it has
# nothing to say. The alias value is a string, so the `Remediation | None`
# union is not evaluated at runtime.
Validator: TypeAlias = "Callable[[pd.DataFrame], Remediation | None]"
748
749
def get_validators() -> list[Validator]:
    """
    Build the ordered list of dataset validators.

    Order is significant when the list is consumed by ``apply_validators``,
    which runs the validators first to last and applies each necessary
    remediation before the next validator runs. Column-specific checks are
    wrapped in lambdas so every entry can be called with just the DataFrame.
    """
    required_columns: list[Validator] = [
        lambda frame: necessary_column_validator(frame, "prompt"),
        lambda frame: necessary_column_validator(frame, "completion"),
    ]
    content_checks: list[Validator] = [
        additional_column_validator,
        non_empty_field_validator,
        format_inferrer_validator,
        duplicated_rows_validator,
        long_examples_validator,
    ]
    casing_checks: list[Validator] = [
        lambda frame: lower_case_validator(frame, "prompt"),
        lambda frame: lower_case_validator(frame, "completion"),
    ]
    affix_checks: list[Validator] = [
        common_prompt_suffix_validator,
        common_prompt_prefix_validator,
        common_completion_prefix_validator,
        common_completion_suffix_validator,
        completions_space_start_validator,
    ]
    return [
        num_examples_validator,
        *required_columns,
        *content_checks,
        *casing_checks,
        *affix_checks,
    ]
768
769
def apply_validators(
    df: pd.DataFrame,
    fname: str,
    remediation: Remediation | None,
    validators: list[Validator],
    auto_accept: bool,
    write_out_file_func: Callable[..., Any],
) -> None:
    """
    Run every validator over ``df``, offer the collected remediations, and
    hand the resulting frame to ``write_out_file_func``.

    Args:
        df: the dataset to validate. It is replaced by the frame returned from
            each remediation helper, so later validators see earlier fixes.
        fname: path of the input file, forwarded to ``write_out_file_func``.
        remediation: an optional remediation produced before validation
            (presumably by the file reader -- confirm at the call site). It is
            reported and offered like the others, but
            ``apply_necessary_remediation`` is not called on it here.
        validators: callables run in order; each returns a ``Remediation``
            or ``None``.
        auto_accept: forwarded to ``apply_optional_remediation`` and to
            ``write_out_file_func``.
        write_out_file_func: called once at the end as
            ``write_out_file_func(df, fname, any_applied, auto_accept)``.
    """
    remediations: list[Remediation] = []
    if remediation is not None:
        remediations.append(remediation)

    for validator in validators:
        found = validator(df)
        if found is not None:
            remediations.append(found)
            # Necessary fixes are applied immediately so subsequent validators
            # inspect the corrected frame.
            df = apply_necessary_remediation(df, found)

    # Test the message fields explicitly instead of relying on the truthiness
    # of the Remediation objects collected by a list comprehension.
    any_optional_or_necessary_remediations = any(
        item.optional_msg is not None or item.necessary_msg is not None for item in remediations
    )
    any_necessary_applied = any(item.necessary_msg is not None for item in remediations)
    any_optional_applied = False

    if any_optional_or_necessary_remediations:
        sys.stdout.write("\n\nBased on the analysis we will perform the following actions:\n")
        for item in remediations:
            df, optional_applied = apply_optional_remediation(df, item, auto_accept)
            any_optional_applied = any_optional_applied or optional_applied
    else:
        sys.stdout.write("\n\nNo remediations found.\n")

    any_optional_or_necessary_applied = any_optional_applied or any_necessary_applied

    write_out_file_func(df, fname, any_optional_or_necessary_applied, auto_accept)
809    write_out_file_func(df, fname, any_optional_or_necessary_applied, auto_accept)