diff --git a/langchain/embeddings/openai.py b/langchain/embeddings/openai.py
index 06c972b828..eb48c6229d 100644
--- a/langchain/embeddings/openai.py
+++ b/langchain/embeddings/openai.py
@@ -75,20 +75,35 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
         text = text.replace("\n", " ")
         return self.client.create(input=[text], engine=engine)["data"][0]["embedding"]
 
-    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+    def embed_documents(
+        self, texts: List[str], chunk_size: int = 1000
+    ) -> List[List[float]]:
         """Call out to OpenAI's embedding endpoint for embedding search docs.
 
         Args:
             texts: The list of texts to embed.
+            chunk_size: The maximum number of texts to send to OpenAI at once
+                (max 1000).
 
         Returns:
             List of embeddings, one for each text.
+
+        Raises:
+            ValueError: If chunk_size is not a positive integer.
         """
-        responses = [
-            self._embedding_func(text, engine=self.document_model_name)
-            for text in texts
-        ]
-        return responses
+        if chunk_size <= 0:
+            raise ValueError(f"chunk_size must be positive, got {chunk_size}")
+        # Keep the newline-to-space replacement the per-text helper applied,
+        # so batched documents are embedded from the same text as before.
+        cleaned = [text.replace("\n", " ") for text in texts]
+        # handle large batches of texts
+        results = []
+        for i in range(0, len(cleaned), chunk_size):
+            response = self.client.create(
+                input=cleaned[i : i + chunk_size], engine=self.document_model_name
+            )
+            results += [r["embedding"] for r in response["data"]]
+        return results
 
     def embed_query(self, text: str) -> List[float]:
         """Call out to OpenAI's embedding endpoint for embedding query text.
diff --git a/tests/integration_tests/embeddings/test_openai.py b/tests/integration_tests/embeddings/test_openai.py
index a721f78ca7..13beee61f0 100644
--- a/tests/integration_tests/embeddings/test_openai.py
+++ b/tests/integration_tests/embeddings/test_openai.py
@@ -8,7 +8,17 @@ def test_openai_embedding_documents() -> None:
     embedding = OpenAIEmbeddings()
     output = embedding.embed_documents(documents)
     assert len(output) == 1
-    assert len(output[0]) == 2048
+    assert len(output[0]) == 1536
+
+
+def test_openai_embedding_documents_multiple() -> None:
+    """Check that documents embedded across several chunks all come back."""
+    texts = ["foo bar", "bar foo", "foo"]
+    embedder = OpenAIEmbeddings()
+    vectors = embedder.embed_documents(texts, chunk_size=2)
+    assert len(vectors) == 3
+    for vector in vectors:
+        assert len(vector) == 1536
 
 
 def test_openai_embedding_query() -> None:
@@ -16,4 +26,4 @@ def test_openai_embedding_query() -> None:
     document = "foo bar"
     embedding = OpenAIEmbeddings()
     output = embedding.embed_query(document)
-    assert len(output) == 2048
+    assert len(output) == 1536