Commit db3f3526
Changed files (1)
openai
openai/embeddings_utils.py
@@ -15,51 +15,51 @@ from openai.datalib.pandas_helper import pandas as pd
@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
def get_embedding(
    text: str, engine="text-similarity-davinci-001", **kwargs
) -> List[float]:
    """Return the embedding for a single piece of text.

    The function is wrapped by ``@retry`` configured with
    ``wait_random_exponential(min=1, max=20)`` and ``stop_after_attempt(6)``.

    Args:
        text: The string to embed. Newlines are replaced with spaces before
            the request is sent.
        engine: Engine name passed through to ``openai.Embedding.create``.
        **kwargs: Additional keyword arguments forwarded unchanged to
            ``openai.Embedding.create``.

    Returns:
        The ``"embedding"`` value of the first entry in the response's
        ``"data"`` list.
    """
    # replace newlines, which can negatively affect performance.
    text = text.replace("\n", " ")

    response = openai.Embedding.create(input=[text], engine=engine, **kwargs)
    return response["data"][0]["embedding"]
@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
async def aget_embedding(
    text: str, engine="text-similarity-davinci-001", **kwargs
) -> List[float]:
    """Asynchronously return the embedding for a single piece of text.

    The coroutine is wrapped by ``@retry`` configured with
    ``wait_random_exponential(min=1, max=20)`` and ``stop_after_attempt(6)``.

    Args:
        text: The string to embed. Newlines are replaced with spaces before
            the request is sent.
        engine: Engine name passed through to ``openai.Embedding.acreate``.
        **kwargs: Additional keyword arguments forwarded unchanged to
            ``openai.Embedding.acreate``.

    Returns:
        The ``"embedding"`` value of the first entry in the response's
        ``"data"`` list.
    """
    # replace newlines, which can negatively affect performance.
    text = text.replace("\n", " ")

    response = await openai.Embedding.acreate(
        input=[text], engine=engine, **kwargs
    )
    return response["data"][0]["embedding"]
@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
def get_embeddings(
    list_of_text: List[str], engine="text-similarity-babbage-001", **kwargs
) -> List[List[float]]:
    """Return embeddings for a batch of texts, in the same order as the input.

    The function is wrapped by ``@retry`` configured with
    ``wait_random_exponential(min=1, max=20)`` and ``stop_after_attempt(6)``.

    Args:
        list_of_text: The strings to embed; at most 2048 items. Newlines in
            each string are replaced with spaces before the request is sent.
        engine: Engine name passed through to ``openai.Embedding.create``.
        **kwargs: Additional keyword arguments forwarded unchanged to
            ``openai.Embedding.create``.

    Returns:
        One embedding per input string, ordered by each response item's
        ``"index"`` field.

    Raises:
        AssertionError: If ``list_of_text`` has more than 2048 items.
    """
    # NOTE(review): this check runs inside the @retry-decorated body, so a
    # failed assertion may itself be retried before surfacing — confirm
    # against the retry library's default retry condition.
    assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."

    # replace newlines, which can negatively affect performance.
    list_of_text = [text.replace("\n", " ") for text in list_of_text]

    data = openai.Embedding.create(input=list_of_text, engine=engine, **kwargs).data
    data = sorted(data, key=lambda x: x["index"])  # maintain the same order as input.
    return [d["embedding"] for d in data]
@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
async def aget_embeddings(
    list_of_text: List[str], engine="text-similarity-babbage-001", **kwargs
) -> List[List[float]]:
    """Asynchronously return embeddings for a batch of texts, in input order.

    The coroutine is wrapped by ``@retry`` configured with
    ``wait_random_exponential(min=1, max=20)`` and ``stop_after_attempt(6)``.

    Args:
        list_of_text: The strings to embed; at most 2048 items. Newlines in
            each string are replaced with spaces before the request is sent.
        engine: Engine name passed through to ``openai.Embedding.acreate``.
        **kwargs: Additional keyword arguments forwarded unchanged to
            ``openai.Embedding.acreate``.

    Returns:
        One embedding per input string, ordered by each response item's
        ``"index"`` field.

    Raises:
        AssertionError: If ``list_of_text`` has more than 2048 items.
    """
    # NOTE(review): this check runs inside the @retry-decorated body, so a
    # failed assertion may itself be retried before surfacing — confirm
    # against the retry library's default retry condition.
    assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."

    # replace newlines, which can negatively affect performance.
    list_of_text = [text.replace("\n", " ") for text in list_of_text]

    response = await openai.Embedding.acreate(
        input=list_of_text, engine=engine, **kwargs
    )
    data = sorted(response.data, key=lambda x: x["index"])  # maintain the same order as input.
    return [d["embedding"] for d in data]