mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Harrison/batch embeds (#972)
Co-authored-by: John Dagdelen <jdagdelen@users.noreply.github.com> Co-authored-by: Harrison Chase <harrisonchase@Harrisons-MBP.attlocal.net>
This commit is contained in:
parent
ba54d36787
commit
91c6cea227
@ -75,20 +75,27 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
text = text.replace("\n", " ")
|
||||
return self.client.create(input=[text], engine=engine)["data"][0]["embedding"]
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
def embed_documents(
|
||||
self, texts: List[str], chunk_size: int = 1000
|
||||
) -> List[List[float]]:
|
||||
"""Call out to OpenAI's embedding endpoint for embedding search docs.
|
||||
|
||||
Args:
|
||||
texts: The list of texts to embed.
|
||||
chunk_size: The maximum number of texts to send to OpenAI at once
|
||||
(max 1000).
|
||||
|
||||
Returns:
|
||||
List of embeddings, one for each text.
|
||||
"""
|
||||
responses = [
|
||||
self._embedding_func(text, engine=self.document_model_name)
|
||||
for text in texts
|
||||
]
|
||||
return responses
|
||||
# handle large batches of texts
|
||||
results = []
|
||||
for i in range(0, len(texts), chunk_size):
|
||||
response = self.client.create(
|
||||
input=texts[i : i + chunk_size], engine=self.document_model_name
|
||||
)
|
||||
results += [r["embedding"] for r in response["data"]]
|
||||
return results
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""Call out to OpenAI's embedding endpoint for embedding query text.
|
||||
|
@ -8,7 +8,18 @@ def test_openai_embedding_documents() -> None:
|
||||
embedding = OpenAIEmbeddings()
|
||||
output = embedding.embed_documents(documents)
|
||||
assert len(output) == 1
|
||||
assert len(output[0]) == 2048
|
||||
assert len(output[0]) == 1536
|
||||
|
||||
|
||||
def test_openai_embedding_documents_multiple() -> None:
|
||||
"""Test openai embeddings."""
|
||||
documents = ["foo bar", "bar foo", "foo"]
|
||||
embedding = OpenAIEmbeddings()
|
||||
output = embedding.embed_documents(documents, chunk_size=2)
|
||||
assert len(output) == 3
|
||||
assert len(output[0]) == 1536
|
||||
assert len(output[1]) == 1536
|
||||
assert len(output[2]) == 1536
|
||||
|
||||
|
||||
def test_openai_embedding_query() -> None:
|
||||
@ -16,4 +27,4 @@ def test_openai_embedding_query() -> None:
|
||||
document = "foo bar"
|
||||
embedding = OpenAIEmbeddings()
|
||||
output = embedding.embed_query(document)
|
||||
assert len(output) == 2048
|
||||
assert len(output) == 1536
|
||||
|
Loading…
Reference in New Issue
Block a user