langchain/libs/community/tests/unit_tests/embeddings/test_vertexai.py
Vlad Kolesnikov 11fda490ca
community[minor]: New model parameters and dynamic batching for VertexAIEmbeddings (#13999)
- **Description:** VertexAIEmbeddings performance improvements
  - **Twitter handle:** @vladkol

## Improvements

- Dynamic batch size, starting at 250 and decreasing down to 5. The maximum
batch size varies across regions.
Some regions support larger batches, which significantly improves
performance.
When embedding large numbers of texts in `us-central1`, the performance gain
can be up to 3.5x.
The dynamic batching also ensures every batch stays below the 20K-token
limit.
- New model parameter `embeddings_type`, which maps to the `task_type`
parameter of the API. Newer model versions support [different embeddings
task types](https://cloud.google.com/vertex-ai/docs/generative-ai/embeddings/get-text-embeddings#api_changes_to_models_released_on_or_after_august_2023).
2023-12-17 22:24:22 -05:00

64 lines
1.5 KiB
Python

"""Test Vertex AI embeddings API wrapper.
"""
from langchain_community.embeddings import VertexAIEmbeddings
def test_split_by_punctuation() -> None:
    """Each word, punctuation mark and whitespace char is a separate part."""
    text = "Hello, my friend!\nHow are you?\nI have 2 news:\n\n\t- Good,\n\t- Bad."
    # One row per line of ``text``; consecutive newlines stay separate parts.
    expected = [
        "Hello", ",", " ", "my", " ", "friend", "!", "\n",
        "How", " ", "are", " ", "you", "?", "\n",
        "I", " ", "have", " ", "2", " ", "news", ":", "\n", "\n",
        "\t", "-", " ", "Good", ",", "\n",
        "\t", "-", " ", "Bad", ".",
    ]
    result = VertexAIEmbeddings._split_by_punctuation(text)
    assert result == expected
def test_batching() -> None:
    """Check that `_prepare_batches` honors both batch size and token limits.

    Three scenarios are covered: the batch-size limit dominating, the
    per-request token limit dominating, and a leftover partial batch.
    """
    # "foo " * 500 is 500 words plus 500 spaces, i.e. 1000 parts once split
    # by punctuation/whitespace; the expectations below treat each such
    # text as ~2000 tokens (so at most 10 fit under the 20K-token limit).
    long_text = "foo " * 500
    long_texts = [long_text] * 250
    documents251 = ["foo bar"] * 251

    five_elem = VertexAIEmbeddings._prepare_batches(long_texts, 5)
    default250_elem = VertexAIEmbeddings._prepare_batches(long_texts, 250)
    batches251 = VertexAIEmbeddings._prepare_batches(documents251, 250)

    # Batch-size limit dominates: 250 / 5 = 50 batches of exactly 5 items.
    assert len(five_elem) == 50
    assert all(len(batch) == 5 for batch in five_elem)

    # Token limit dominates: at ~2000 tokens per text, a batch must not hold
    # more than 10 texts (20K tokens), so 250 texts need 25 batches.
    assert len(default250_elem) == 25
    assert all(len(batch) == 10 for batch in default250_elem)

    # Short texts: one full batch of 250, then the single leftover text.
    assert len(batches251) == 2
    assert len(batches251[0]) == 250
    assert len(batches251[1]) == 1

    # Batching must neither drop, duplicate, nor reorder texts.
    for batches, texts in (
        (five_elem, long_texts),
        (default250_elem, long_texts),
        (batches251, documents251),
    ):
        assert [text for batch in batches for text in batch] == texts