Add pagination for Vertex AI embeddings (#5325)

Fixes #5316

---------

Co-authored-by: Justin Flick <jflick@homesite.com>
Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
Justin Flick committed 12 months ago via GitHub
parent 3e16468423
commit c09f8e4ddc

@@ -22,17 +22,25 @@ class VertexAIEmbeddings(_VertexAICommon, Embeddings):
         values["client"] = TextEmbeddingModel.from_pretrained(values["model_name"])
         return values
 
-    def embed_documents(self, texts: List[str]) -> List[List[float]]:
-        """Embed a list of strings.
+    def embed_documents(
+        self, texts: List[str], batch_size: int = 5
+    ) -> List[List[float]]:
+        """Embed a list of strings. Vertex AI currently
+        sets a max batch size of 5 strings.
 
         Args:
             texts: List[str] The list of strings to embed.
+            batch_size: [int] The batch size of embeddings to send to the model
 
         Returns:
             List of embeddings, one for each text.
         """
-        embeddings = self.client.get_embeddings(texts)
-        return [el.values for el in embeddings]
+        embeddings = []
+        for batch in range(0, len(texts), batch_size):
+            text_batch = texts[batch : batch + batch_size]
+            embeddings_batch = self.client.get_embeddings(text_batch)
+            embeddings.extend([el.values for el in embeddings_batch])
+        return embeddings
 
     def embed_query(self, text: str) -> List[float]:
         """Embed a text.

@@ -23,3 +23,22 @@ def test_embedding_query() -> None:
     model = VertexAIEmbeddings()
     output = model.embed_query(document)
     assert len(output) == 768
+
+
+def test_paginated_texts() -> None:
+    documents = [
+        "foo bar",
+        "foo baz",
+        "bar foo",
+        "baz foo",
+        "bar bar",
+        "foo foo",
+        "baz baz",
+        "baz bar",
+    ]
+    model = VertexAIEmbeddings()
+    output = model.embed_documents(documents)
+    assert len(output) == 8
+    assert len(output[0]) == 768
+    assert model._llm_type == "vertexai"
+    assert model.model_name == model.client._model_id
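
The test exercises the pagination path because its 8 documents exceed the 5-string limit. A small self-contained sketch of the same slicing logic (plain lists, no API calls) shows how the loop partitions the input:

texts = ["foo bar", "foo baz", "bar foo", "baz foo",
         "bar bar", "foo foo", "baz baz", "baz bar"]
batch_size = 5

# Same range/slice pattern used in embed_documents above.
batches = [texts[i : i + batch_size] for i in range(0, len(texts), batch_size)]

assert [len(b) for b in batches] == [5, 3]  # two requests: 5 texts, then 3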
