Supported custom ernie_api_base & Implemented asynchronous for ErnieEmbeddings (#10398)

Description: Supported custom ernie_api_base & Implemented asynchronous
for ErnieEmbeddings
 - ernie_api_base:Support Ernie Service custom endpoints
 - Support asynchronous 

Issue: None
Dependencies: None
Tag maintainer:
Twitter handle: @JohnMai95
This commit is contained in:
John Mai 2023-09-10 07:57:16 +08:00 committed by GitHub
parent e0d45e6a09
commit ee3f950a67
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,5 +1,7 @@
import asyncio
import logging import logging
import threading import threading
from functools import partial
from typing import Dict, List, Optional from typing import Dict, List, Optional
import requests import requests
@ -14,6 +16,7 @@ logger = logging.getLogger(__name__)
class ErnieEmbeddings(BaseModel, Embeddings): class ErnieEmbeddings(BaseModel, Embeddings):
"""`Ernie Embeddings V1` embedding models.""" """`Ernie Embeddings V1` embedding models."""
ernie_api_base: Optional[str] = None
ernie_client_id: Optional[str] = None ernie_client_id: Optional[str] = None
ernie_client_secret: Optional[str] = None ernie_client_secret: Optional[str] = None
access_token: Optional[str] = None access_token: Optional[str] = None
@ -26,6 +29,9 @@ class ErnieEmbeddings(BaseModel, Embeddings):
@root_validator() @root_validator()
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
values["ernie_api_base"] = get_from_dict_or_env(
values, "ernie_api_base", "ERNIE_API_BASE", "https://aip.baidubce.com"
)
values["ernie_client_id"] = get_from_dict_or_env( values["ernie_client_id"] = get_from_dict_or_env(
values, values,
"ernie_client_id", "ernie_client_id",
@ -40,7 +46,7 @@ class ErnieEmbeddings(BaseModel, Embeddings):
def _embedding(self, json: object) -> dict: def _embedding(self, json: object) -> dict:
base_url = ( base_url = (
"https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings" f"{self.ernie_api_base}/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings"
) )
resp = requests.post( resp = requests.post(
f"{base_url}/embedding-v1", f"{base_url}/embedding-v1",
@ -71,6 +77,15 @@ class ErnieEmbeddings(BaseModel, Embeddings):
self.access_token = str(resp.json().get("access_token")) self.access_token = str(resp.json().get("access_token"))
def embed_documents(self, texts: List[str]) -> List[List[float]]: def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Embed search docs.
Args:
texts: The list of texts to embed
Returns:
List[List[float]]: List of embeddings, one for each text.
"""
if not self.access_token: if not self.access_token:
self._refresh_access_token_with_lock() self._refresh_access_token_with_lock()
text_in_chunks = [ text_in_chunks = [
@ -90,6 +105,15 @@ class ErnieEmbeddings(BaseModel, Embeddings):
return lst return lst
def embed_query(self, text: str) -> List[float]: def embed_query(self, text: str) -> List[float]:
"""Embed query text.
Args:
text: The text to embed.
Returns:
List[float]: Embeddings for the text.
"""
if not self.access_token: if not self.access_token:
self._refresh_access_token_with_lock() self._refresh_access_token_with_lock()
resp = self._embedding({"input": [text]}) resp = self._embedding({"input": [text]})
@ -100,3 +124,31 @@ class ErnieEmbeddings(BaseModel, Embeddings):
else: else:
raise ValueError(f"Error from Ernie: {resp}") raise ValueError(f"Error from Ernie: {resp}")
return resp["data"][0]["embedding"] return resp["data"][0]["embedding"]
async def aembed_query(self, text: str) -> List[float]:
"""Asynchronous Embed query text.
Args:
text: The text to embed.
Returns:
List[float]: Embeddings for the text.
"""
return await asyncio.get_running_loop().run_in_executor(
None, partial(self.embed_query, text)
)
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
"""Asynchronous Embed search docs.
Args:
texts: The list of texts to embed
Returns:
List[List[float]]: List of embeddings, one for each text.
"""
result = await asyncio.gather(*[self.aembed_query(text) for text in texts])
return list(result)