community: OCI GenAI embedding batch size (#22986)

Thank you for contributing to LangChain!

- [x] **PR title**: "community: OCI GenAI embedding batch size"



- [x] **PR message**:
    - **Description:** Split OCI GenAI embedding requests into batches of at most `batch_size` texts (default 96), since OCI GenAI may handle up to 96 texts per request.
    - **Issue:** #22985


- [ ] **Add tests and docs**: N/A


- [x] **Lint and test**: Run `make format`, `make lint` and `make test`
from the root of the package(s) you've modified. See contribution
guidelines for more: https://python.langchain.com/docs/contributing/

Additional guidelines:
- Make sure optional dependencies are imported within a function.
- Please do not add dependencies to pyproject.toml files (even optional
ones) unless they are required for unit tests.
- Most PRs should not touch more than one package.
- Changes should be backwards compatible.
- If you are adding something to community, do not re-import it in
langchain.

If no one reviews your PR within a few days, please @-mention one of
baskaryan, efriis, eyurtsev, ccurme, vbarda, hwchase17.

---------

Signed-off-by: Anders Swanson <anders.swanson@oracle.com>
Co-authored-by: Chester Curme <chester.curme@gmail.com>

```diff
@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import Any, Dict, List, Mapping, Optional
+from typing import Any, Dict, Iterator, List, Mapping, Optional

 from langchain_core.embeddings import Embeddings
 from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
@@ -80,6 +80,10 @@ class OCIGenAIEmbeddings(BaseModel, Embeddings):
     truncate: Optional[str] = "END"
     """Truncate embeddings that are too long from start or end ("NONE"|"START"|"END")"""

+    batch_size: int = 96
+    """Batch size of OCI GenAI embedding requests. OCI GenAI may handle up to 96 texts
+    per request"""
+
     class Config:
         """Configuration for this pydantic object."""

@@ -182,16 +186,23 @@ class OCIGenAIEmbeddings(BaseModel, Embeddings):
         else:
             serving_mode = models.OnDemandServingMode(model_id=self.model_id)

-        invocation_obj = models.EmbedTextDetails(
-            serving_mode=serving_mode,
-            compartment_id=self.compartment_id,
-            truncate=self.truncate,
-            inputs=texts,
-        )
+        embeddings = []
+
+        def split_texts() -> Iterator[List[str]]:
+            for i in range(0, len(texts), self.batch_size):
+                yield texts[i : i + self.batch_size]

-        response = self.client.embed_text(invocation_obj)
+        for chunk in split_texts():
+            invocation_obj = models.EmbedTextDetails(
+                serving_mode=serving_mode,
+                compartment_id=self.compartment_id,
+                truncate=self.truncate,
+                inputs=chunk,
+            )
+            response = self.client.embed_text(invocation_obj)
+            embeddings.extend(response.data.embeddings)

-        return response.data.embeddings
+        return embeddings

     def embed_query(self, text: str) -> List[float]:
         """Call out to OCIGenAI's embedding endpoint.
```
