diff --git a/libs/community/langchain_community/embeddings/oci_generative_ai.py b/libs/community/langchain_community/embeddings/oci_generative_ai.py
index b3e428fe2f..dcfa38f648 100644
--- a/libs/community/langchain_community/embeddings/oci_generative_ai.py
+++ b/libs/community/langchain_community/embeddings/oci_generative_ai.py
@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import Any, Dict, List, Mapping, Optional
+from typing import Any, Dict, Iterator, List, Mapping, Optional
 
 from langchain_core.embeddings import Embeddings
 from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
@@ -80,6 +80,10 @@ class OCIGenAIEmbeddings(BaseModel, Embeddings):
     truncate: Optional[str] = "END"
     """Truncate embeddings that are too long from start or end ("NONE"|"START"|"END")"""
 
+    batch_size: int = 96
+    """Batch size of OCI GenAI embedding requests. OCI GenAI may handle up to 96 texts
+    per request"""
+
     class Config:
         """Configuration for this pydantic object."""
 
@@ -182,16 +186,23 @@ class OCIGenAIEmbeddings(BaseModel, Embeddings):
         else:
             serving_mode = models.OnDemandServingMode(model_id=self.model_id)
 
-        invocation_obj = models.EmbedTextDetails(
-            serving_mode=serving_mode,
-            compartment_id=self.compartment_id,
-            truncate=self.truncate,
-            inputs=texts,
-        )
+        embeddings = []
+
+        def split_texts() -> Iterator[List[str]]:
+            for i in range(0, len(texts), self.batch_size):
+                yield texts[i : i + self.batch_size]
 
-        response = self.client.embed_text(invocation_obj)
+        for chunk in split_texts():
+            invocation_obj = models.EmbedTextDetails(
+                serving_mode=serving_mode,
+                compartment_id=self.compartment_id,
+                truncate=self.truncate,
+                inputs=chunk,
+            )
+            response = self.client.embed_text(invocation_obj)
+            embeddings.extend(response.data.embeddings)
 
-        return response.data.embeddings
+        return embeddings
 
     def embed_query(self, text: str) -> List[float]:
         """Call out to OCIGenAI's embedding endpoint.
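
Below is a minimal, self-contained sketch of the chunking pattern this patch introduces, for sanity-checking the slicing arithmetic without an OCI client. The name split_into_batches and the dummy inputs are illustrative only (the patch itself uses a nested split_texts generator inside embed_documents); nothing here is part of langchain_community.

    from typing import Iterator, List

    def split_into_batches(
        texts: List[str], batch_size: int = 96
    ) -> Iterator[List[str]]:
        """Yield successive slices of ``texts``, each at most ``batch_size`` long."""
        # Hypothetical standalone version of the split_texts generator in the diff.
        for i in range(0, len(texts), batch_size):
            yield texts[i : i + batch_size]

    # 200 dummy documents with the default batch size of 96 produce
    # batches of 96, 96, and 8 -- matching the three embed_text requests
    # that embed_documents would now issue for such an input.
    docs = [f"doc {n}" for n in range(200)]
    assert [len(batch) for batch in split_into_batches(docs)] == [96, 96, 8]

Batching this way keeps each request within the 96-input limit noted in the new batch_size docstring, at the cost of embed_documents issuing one embed_text call per batch instead of a single call for the whole input list.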