diff --git a/libs/langchain/langchain/embeddings/ernie.py b/libs/langchain/langchain/embeddings/ernie.py index b8213651ad..37723b53ab 100644 --- a/libs/langchain/langchain/embeddings/ernie.py +++ b/libs/langchain/langchain/embeddings/ernie.py @@ -1,5 +1,7 @@ +import asyncio import logging import threading +from functools import partial from typing import Dict, List, Optional import requests @@ -14,6 +16,7 @@ logger = logging.getLogger(__name__) class ErnieEmbeddings(BaseModel, Embeddings): """`Ernie Embeddings V1` embedding models.""" + ernie_api_base: Optional[str] = None ernie_client_id: Optional[str] = None ernie_client_secret: Optional[str] = None access_token: Optional[str] = None @@ -26,6 +29,9 @@ class ErnieEmbeddings(BaseModel, Embeddings): @root_validator() def validate_environment(cls, values: Dict) -> Dict: + values["ernie_api_base"] = get_from_dict_or_env( + values, "ernie_api_base", "ERNIE_API_BASE", "https://aip.baidubce.com" + ) values["ernie_client_id"] = get_from_dict_or_env( values, "ernie_client_id", @@ -40,7 +46,7 @@ class ErnieEmbeddings(BaseModel, Embeddings): def _embedding(self, json: object) -> dict: base_url = ( - "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings" + f"{self.ernie_api_base}/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings" ) resp = requests.post( f"{base_url}/embedding-v1", @@ -71,6 +77,15 @@ class ErnieEmbeddings(BaseModel, Embeddings): self.access_token = str(resp.json().get("access_token")) def embed_documents(self, texts: List[str]) -> List[List[float]]: + """Embed search docs. + + Args: + texts: The list of texts to embed + + Returns: + List[List[float]]: List of embeddings, one for each text. + """ + if not self.access_token: self._refresh_access_token_with_lock() text_in_chunks = [ @@ -90,6 +105,15 @@ class ErnieEmbeddings(BaseModel, Embeddings): return lst def embed_query(self, text: str) -> List[float]: + """Embed query text. + + Args: + text: The text to embed. 
+ + Returns: + List[float]: Embeddings for the text. + """ + if not self.access_token: self._refresh_access_token_with_lock() resp = self._embedding({"input": [text]}) @@ -100,3 +124,31 @@ class ErnieEmbeddings(BaseModel, Embeddings): else: raise ValueError(f"Error from Ernie: {resp}") return resp["data"][0]["embedding"] + + async def aembed_query(self, text: str) -> List[float]: + """Asynchronously embed query text. + + Args: + text: The text to embed. + + Returns: + List[float]: Embeddings for the text. + """ + + return await asyncio.get_running_loop().run_in_executor( + None, partial(self.embed_query, text) + ) + + async def aembed_documents(self, texts: List[str]) -> List[List[float]]: + """Asynchronously embed search docs. + + Args: + texts: The list of texts to embed + + Returns: + List[List[float]]: List of embeddings, one for each text. + """ + + result = await asyncio.gather(*[self.aembed_query(text) for text in texts]) + + return list(result)