Harrison/batch embeds (#972)

Co-authored-by: John Dagdelen <jdagdelen@users.noreply.github.com>
Co-authored-by: Harrison Chase <harrisonchase@Harrisons-MBP.attlocal.net>
This commit is contained in:
Harrison Chase 2023-02-10 06:59:50 -08:00 committed by GitHub
parent ba54d36787
commit 91c6cea227
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 26 additions and 8 deletions

View File

@@ -75,20 +75,27 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
text = text.replace("\n", " ")
return self.client.create(input=[text], engine=engine)["data"][0]["embedding"]
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def embed_documents(
self, texts: List[str], chunk_size: int = 1000
) -> List[List[float]]:
"""Call out to OpenAI's embedding endpoint for embedding search docs.
Args:
texts: The list of texts to embed.
chunk_size: The maximum number of texts to send to OpenAI at once
(max 1000).
Returns:
List of embeddings, one for each text.
"""
responses = [
self._embedding_func(text, engine=self.document_model_name)
for text in texts
]
return responses
# handle large batches of texts
results = []
for i in range(0, len(texts), chunk_size):
response = self.client.create(
input=texts[i : i + chunk_size], engine=self.document_model_name
)
results += [r["embedding"] for r in response["data"]]
return results
def embed_query(self, text: str) -> List[float]:
"""Call out to OpenAI's embedding endpoint for embedding query text.

View File

@@ -8,7 +8,18 @@ def test_openai_embedding_documents() -> None:
embedding = OpenAIEmbeddings()
output = embedding.embed_documents(documents)
assert len(output) == 1
assert len(output[0]) == 2048
assert len(output[0]) == 1536
def test_openai_embedding_documents_multiple() -> None:
"""Test openai embeddings."""
documents = ["foo bar", "bar foo", "foo"]
embedding = OpenAIEmbeddings()
output = embedding.embed_documents(documents, chunk_size=2)
assert len(output) == 3
assert len(output[0]) == 1536
assert len(output[1]) == 1536
assert len(output[2]) == 1536
def test_openai_embedding_query() -> None:
@@ -16,4 +27,4 @@ def test_openai_embedding_query() -> None:
document = "foo bar"
embedding = OpenAIEmbeddings()
output = embedding.embed_query(document)
assert len(output) == 2048
assert len(output) == 1536