mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Support custom ernie_api_base and add asynchronous methods to ErnieEmbeddings (#10398)
Description: Support a custom ernie_api_base and add asynchronous methods to ErnieEmbeddings. - ernie_api_base: supports custom Ernie service endpoints. - Adds asynchronous embedding support (aembed_query / aembed_documents). Issue: None Dependencies: None Tag maintainer: Twitter handle: @JohnMai95
This commit is contained in:
parent
e0d45e6a09
commit
ee3f950a67
@ -1,5 +1,7 @@
|
|||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
|
from functools import partial
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
@ -14,6 +16,7 @@ logger = logging.getLogger(__name__)
|
|||||||
class ErnieEmbeddings(BaseModel, Embeddings):
|
class ErnieEmbeddings(BaseModel, Embeddings):
|
||||||
"""`Ernie Embeddings V1` embedding models."""
|
"""`Ernie Embeddings V1` embedding models."""
|
||||||
|
|
||||||
|
ernie_api_base: Optional[str] = None
|
||||||
ernie_client_id: Optional[str] = None
|
ernie_client_id: Optional[str] = None
|
||||||
ernie_client_secret: Optional[str] = None
|
ernie_client_secret: Optional[str] = None
|
||||||
access_token: Optional[str] = None
|
access_token: Optional[str] = None
|
||||||
@ -26,6 +29,9 @@ class ErnieEmbeddings(BaseModel, Embeddings):
|
|||||||
|
|
||||||
@root_validator()
|
@root_validator()
|
||||||
def validate_environment(cls, values: Dict) -> Dict:
|
def validate_environment(cls, values: Dict) -> Dict:
|
||||||
|
values["ernie_api_base"] = get_from_dict_or_env(
|
||||||
|
values, "ernie_api_base", "ERNIE_API_BASE", "https://aip.baidubce.com"
|
||||||
|
)
|
||||||
values["ernie_client_id"] = get_from_dict_or_env(
|
values["ernie_client_id"] = get_from_dict_or_env(
|
||||||
values,
|
values,
|
||||||
"ernie_client_id",
|
"ernie_client_id",
|
||||||
@ -40,7 +46,7 @@ class ErnieEmbeddings(BaseModel, Embeddings):
|
|||||||
|
|
||||||
def _embedding(self, json: object) -> dict:
|
def _embedding(self, json: object) -> dict:
|
||||||
base_url = (
|
base_url = (
|
||||||
"https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings"
|
f"{self.ernie_api_base}/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings"
|
||||||
)
|
)
|
||||||
resp = requests.post(
|
resp = requests.post(
|
||||||
f"{base_url}/embedding-v1",
|
f"{base_url}/embedding-v1",
|
||||||
@ -71,6 +77,15 @@ class ErnieEmbeddings(BaseModel, Embeddings):
|
|||||||
self.access_token = str(resp.json().get("access_token"))
|
self.access_token = str(resp.json().get("access_token"))
|
||||||
|
|
||||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||||
|
"""Embed search docs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
texts: The list of texts to embed
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[List[float]]: List of embeddings, one for each text.
|
||||||
|
"""
|
||||||
|
|
||||||
if not self.access_token:
|
if not self.access_token:
|
||||||
self._refresh_access_token_with_lock()
|
self._refresh_access_token_with_lock()
|
||||||
text_in_chunks = [
|
text_in_chunks = [
|
||||||
@ -90,6 +105,15 @@ class ErnieEmbeddings(BaseModel, Embeddings):
|
|||||||
return lst
|
return lst
|
||||||
|
|
||||||
def embed_query(self, text: str) -> List[float]:
|
def embed_query(self, text: str) -> List[float]:
|
||||||
|
"""Embed query text.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: The text to embed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[float]: Embeddings for the text.
|
||||||
|
"""
|
||||||
|
|
||||||
if not self.access_token:
|
if not self.access_token:
|
||||||
self._refresh_access_token_with_lock()
|
self._refresh_access_token_with_lock()
|
||||||
resp = self._embedding({"input": [text]})
|
resp = self._embedding({"input": [text]})
|
||||||
@ -100,3 +124,31 @@ class ErnieEmbeddings(BaseModel, Embeddings):
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Error from Ernie: {resp}")
|
raise ValueError(f"Error from Ernie: {resp}")
|
||||||
return resp["data"][0]["embedding"]
|
return resp["data"][0]["embedding"]
|
||||||
|
|
||||||
|
async def aembed_query(self, text: str) -> List[float]:
    """Embed a single query text without blocking the event loop.

    The synchronous ``embed_query`` is run on the running loop's default
    executor and its result is awaited.

    Args:
        text: The text to embed.

    Returns:
        List[float]: Embeddings for the text.
    """
    loop = asyncio.get_running_loop()
    # run_in_executor forwards positional arguments to the callable, so the
    # text can be passed directly.
    embedding = await loop.run_in_executor(None, self.embed_query, text)
    return embedding
|
||||||
|
|
||||||
|
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
    """Embed search docs without blocking the event loop.

    Runs the synchronous ``embed_documents`` once on the running loop's
    default executor, so the texts are sent using that method's chunking
    instead of one request (and one executor job) per text.

    Args:
        texts: The list of texts to embed

    Returns:
        List[List[float]]: List of embeddings, one for each text.
    """
    loop = asyncio.get_running_loop()
    # A single executor job keeps the request count equal to the synchronous
    # path; fanning out through aembed_query would issue len(texts) requests.
    embeddings = await loop.run_in_executor(None, self.embed_documents, texts)
    return list(embeddings)
|
||||||
|
Loading…
Reference in New Issue
Block a user