From c09f8e4ddc3be791bd0e8c8385ed1871bdd5d681 Mon Sep 17 00:00:00 2001 From: Justin Flick Date: Mon, 29 May 2023 06:57:41 -0700 Subject: [PATCH] Add pagination for Vertex AI embeddings (#5325) Fixes #5316 --------- Co-authored-by: Justin Flick Co-authored-by: Harrison Chase --- langchain/embeddings/vertexai.py | 16 ++++++++++++---- .../embeddings/test_vertexai.py | 19 +++++++++++++++++++ 2 files changed, 31 insertions(+), 4 deletions(-) diff --git a/langchain/embeddings/vertexai.py b/langchain/embeddings/vertexai.py index 0730a5fc..6ff2a95f 100644 --- a/langchain/embeddings/vertexai.py +++ b/langchain/embeddings/vertexai.py @@ -22,17 +22,25 @@ class VertexAIEmbeddings(_VertexAICommon, Embeddings): values["client"] = TextEmbeddingModel.from_pretrained(values["model_name"]) return values - def embed_documents(self, texts: List[str]) -> List[List[float]]: - """Embed a list of strings. + def embed_documents( + self, texts: List[str], batch_size: int = 5 + ) -> List[List[float]]: + """Embed a list of strings. Vertex AI currently + sets a max batch size of 5 strings. Args: texts: List[str] The list of strings to embed. + batch_size: [int] The batch size of embeddings to send to the model Returns: List of embeddings, one for each text. """ - embeddings = self.client.get_embeddings(texts) - return [el.values for el in embeddings] + embeddings = [] + for batch in range(0, len(texts), batch_size): + text_batch = texts[batch : batch + batch_size] + embeddings_batch = self.client.get_embeddings(text_batch) + embeddings.extend([el.values for el in embeddings_batch]) + return embeddings def embed_query(self, text: str) -> List[float]: """Embed a text. diff --git a/tests/integration_tests/embeddings/test_vertexai.py b/tests/integration_tests/embeddings/test_vertexai.py index ce389f89..5d711275 100644 --- a/tests/integration_tests/embeddings/test_vertexai.py +++ b/tests/integration_tests/embeddings/test_vertexai.py @@ -23,3 +23,22 @@ def test_embedding_query() -> None: model = VertexAIEmbeddings() output = model.embed_query(document) assert len(output) == 768 + + +def test_paginated_texts() -> None: + documents = [ + "foo bar", + "foo baz", + "bar foo", + "baz foo", + "bar bar", + "foo foo", + "baz baz", + "baz bar", + ] + model = VertexAIEmbeddings() + output = model.embed_documents(documents) + assert len(output) == 8 + assert len(output[0]) == 768 + assert model._llm_type == "vertexai" + assert model.model_name == model.client._model_id