Commit db3f3526

Liren Tu <tuliren@gmail.com>
2023-04-20 05:04:20
Pass in keyword arguments in embedding utility functions (#405)
* Pass in kwargs in embedding util functions
* Remove debug code
1 parent 794def3
Changed files (1)
openai/embeddings_utils.py
@@ -15,51 +15,51 @@ from openai.datalib.pandas_helper import pandas as pd
 
 
 @retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
-def get_embedding(text: str, engine="text-similarity-davinci-001") -> List[float]:
+def get_embedding(text: str, engine="text-similarity-davinci-001", **kwargs) -> List[float]:
 
     # replace newlines, which can negatively affect performance.
     text = text.replace("\n", " ")
 
-    return openai.Embedding.create(input=[text], engine=engine)["data"][0]["embedding"]
+    return openai.Embedding.create(input=[text], engine=engine, **kwargs)["data"][0]["embedding"]
 
 
 @retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
 async def aget_embedding(
-    text: str, engine="text-similarity-davinci-001"
+    text: str, engine="text-similarity-davinci-001", **kwargs
 ) -> List[float]:
 
     # replace newlines, which can negatively affect performance.
     text = text.replace("\n", " ")
 
-    return (await openai.Embedding.acreate(input=[text], engine=engine))["data"][0][
+    return (await openai.Embedding.acreate(input=[text], engine=engine, **kwargs))["data"][0][
         "embedding"
     ]
 
 
 @retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
 def get_embeddings(
-    list_of_text: List[str], engine="text-similarity-babbage-001"
+    list_of_text: List[str], engine="text-similarity-babbage-001", **kwargs
 ) -> List[List[float]]:
     assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."
 
     # replace newlines, which can negatively affect performance.
     list_of_text = [text.replace("\n", " ") for text in list_of_text]
 
-    data = openai.Embedding.create(input=list_of_text, engine=engine).data
+    data = openai.Embedding.create(input=list_of_text, engine=engine, **kwargs).data
     data = sorted(data, key=lambda x: x["index"])  # maintain the same order as input.
     return [d["embedding"] for d in data]
 
 
 @retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
 async def aget_embeddings(
-    list_of_text: List[str], engine="text-similarity-babbage-001"
+    list_of_text: List[str], engine="text-similarity-babbage-001", **kwargs
 ) -> List[List[float]]:
     assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."
 
     # replace newlines, which can negatively affect performance.
     list_of_text = [text.replace("\n", " ") for text in list_of_text]
 
-    data = (await openai.Embedding.acreate(input=list_of_text, engine=engine)).data
+    data = (await openai.Embedding.acreate(input=list_of_text, engine=engine, **kwargs)).data
     data = sorted(data, key=lambda x: x["index"])  # maintain the same order as input.
     return [d["embedding"] for d in data]
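For context (not part of the diff): after this change, any extra keyword arguments are forwarded unchanged to openai.Embedding.create / openai.Embedding.acreate. A minimal usage sketch, assuming the pre-1.0 openai package and a configured API key; the `user` value is just an illustrative request option that the Embeddings API accepts:

    import openai
    from openai.embeddings_utils import get_embedding, get_embeddings

    openai.api_key = "sk-..."  # assumes a valid API key

    # Single text; `user` is forwarded via **kwargs to openai.Embedding.create.
    vec = get_embedding(
        "hello world",
        engine="text-similarity-davinci-001",
        user="example-user",
    )
    print(len(vec))

    # Batch call; the same keyword arguments are forwarded to the API.
    vecs = get_embeddings(
        ["first text", "second text"],
        engine="text-similarity-babbage-001",
        user="example-user",
    )
    print(len(vecs), len(vecs[0]))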