From aacc6198b9e4703ca6400dca90b8ade0c2da33bd Mon Sep 17 00:00:00 2001
From: Anders Swanson <91502735+anders-swanson@users.noreply.github.com>
Date: Mon, 17 Jun 2024 15:06:45 -0700
Subject: [PATCH] community: OCI GenAI embedding batch size (#22986)

Thank you for contributing to LangChain!

- [x] **PR title**: "community: OCI GenAI embedding batch size"

- [x] **PR message**:
    - **Issue:** #22985

- [ ] **Add tests and docs**: N/A

- [x] **Lint and test**: Run `make format`, `make lint` and `make test` from the root of the package(s) you've modified. See contribution guidelines for more: https://python.langchain.com/docs/contributing/

Additional guidelines:
- Make sure optional dependencies are imported within a function.
- Please do not add dependencies to pyproject.toml files (even optional ones) unless they are required for unit tests.
- Most PRs should not touch more than one package.
- Changes should be backwards compatible.
- If you are adding something to community, do not re-import it in langchain.

If no one reviews your PR within a few days, please @-mention one of baskaryan, efriis, eyurtsev, ccurme, vbarda, hwchase17.
---------

Signed-off-by: Anders Swanson
Co-authored-by: Chester Curme
---
 .../embeddings/oci_generative_ai.py | 29 +++++++++++++------
 1 file changed, 20 insertions(+), 9 deletions(-)

diff --git a/libs/community/langchain_community/embeddings/oci_generative_ai.py b/libs/community/langchain_community/embeddings/oci_generative_ai.py
index b3e428fe2f..dcfa38f648 100644
--- a/libs/community/langchain_community/embeddings/oci_generative_ai.py
+++ b/libs/community/langchain_community/embeddings/oci_generative_ai.py
@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import Any, Dict, List, Mapping, Optional
+from typing import Any, Dict, Iterator, List, Mapping, Optional
 
 from langchain_core.embeddings import Embeddings
 from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
@@ -80,6 +80,10 @@ class OCIGenAIEmbeddings(BaseModel, Embeddings):
     truncate: Optional[str] = "END"
     """Truncate embeddings that are too long from start or end ("NONE"|"START"|"END")"""
 
+    batch_size: int = 96
+    """Batch size of OCI GenAI embedding requests. OCI GenAI may handle up to 96 texts
+    per request"""
+
     class Config:
         """Configuration for this pydantic object."""
 
@@ -182,16 +186,23 @@ class OCIGenAIEmbeddings(BaseModel, Embeddings):
         else:
             serving_mode = models.OnDemandServingMode(model_id=self.model_id)
 
-        invocation_obj = models.EmbedTextDetails(
-            serving_mode=serving_mode,
-            compartment_id=self.compartment_id,
-            truncate=self.truncate,
-            inputs=texts,
-        )
+        embeddings = []
+
+        def split_texts() -> Iterator[List[str]]:
+            for i in range(0, len(texts), self.batch_size):
+                yield texts[i : i + self.batch_size]
 
-        response = self.client.embed_text(invocation_obj)
+        for chunk in split_texts():
+            invocation_obj = models.EmbedTextDetails(
+                serving_mode=serving_mode,
+                compartment_id=self.compartment_id,
+                truncate=self.truncate,
+                inputs=chunk,
+            )
+            response = self.client.embed_text(invocation_obj)
+            embeddings.extend(response.data.embeddings)
 
-        return response.data.embeddings
+        return embeddings
 
     def embed_query(self, text: str) -> List[float]:
         """Call out to OCIGenAI's embedding endpoint.