2023-12-11 21:53:30 +00:00
|
|
|
import asyncio
|
|
|
|
import logging
|
|
|
|
import threading
|
|
|
|
from typing import Dict, List, Optional
|
|
|
|
|
|
|
|
import requests
|
2024-01-15 19:14:44 +00:00
|
|
|
from langchain_core._api.deprecation import deprecated
|
2023-12-11 21:53:30 +00:00
|
|
|
from langchain_core.embeddings import Embeddings
|
|
|
|
from langchain_core.pydantic_v1 import BaseModel, root_validator
|
2023-12-29 20:34:03 +00:00
|
|
|
from langchain_core.runnables.config import run_in_executor
|
2023-12-11 21:53:30 +00:00
|
|
|
from langchain_core.utils import get_from_dict_or_env
|
|
|
|
|
|
|
|
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
2024-01-15 19:14:44 +00:00
|
|
|
@deprecated(
    since="0.0.13",
    alternative="langchain_community.embeddings.QianfanEmbeddingsEndpoint",
)
class ErnieEmbeddings(BaseModel, Embeddings):
    """`Ernie Embeddings V1` embedding models.

    Sends texts to the ``embedding-v1`` endpoint under ``ernie_api_base`` and
    returns the embedding vectors from the response. An access token is
    obtained from the ``/oauth/2.0/token`` endpoint using ``ernie_client_id``
    and ``ernie_client_secret`` the first time it is needed, and again
    whenever the embedding endpoint answers with ``error_code`` 111.

    This class is marked deprecated since 0.0.13 in favour of
    ``langchain_community.embeddings.QianfanEmbeddingsEndpoint``.
    """

    # Base URL of the service. ``validate_environment`` fills it from the
    # ``ERNIE_API_BASE`` environment variable, falling back to
    # "https://aip.baidubce.com".
    ernie_api_base: Optional[str] = None
    # OAuth client credentials. ``validate_environment`` fills them from
    # ``ERNIE_CLIENT_ID`` / ``ERNIE_CLIENT_SECRET`` when not passed explicitly.
    ernie_client_id: Optional[str] = None
    ernie_client_secret: Optional[str] = None
    # Token sent as the ``access_token`` query parameter. When unset it is
    # fetched on the first embed call.
    access_token: Optional[str] = None

    # Maximum number of texts sent per request by ``embed_documents``.
    chunk_size: int = 16

    # Not referenced by any method of this class; the endpoint path is fixed
    # in ``_embedding``.
    model_name = "ErnieBot-Embedding-V1"

    # Lock held while a token refresh request is in flight. It is created once
    # in the class body.
    # NOTE(review): presumably shared by all instances - confirm against the
    # pydantic v1 handling of underscore-prefixed class attributes.
    _lock = threading.Lock()

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve connection settings from explicit values or the environment.

        Args:
            values: The field values collected for the model.

        Returns:
            Dict: ``values`` with ``ernie_api_base``, ``ernie_client_id`` and
            ``ernie_client_secret`` replaced by the resolved settings.
        """
        values["ernie_api_base"] = get_from_dict_or_env(
            values, "ernie_api_base", "ERNIE_API_BASE", "https://aip.baidubce.com"
        )
        # No default is supplied for the credentials; what happens when they
        # are absent is decided by ``get_from_dict_or_env``.
        values["ernie_client_id"] = get_from_dict_or_env(
            values,
            "ernie_client_id",
            "ERNIE_CLIENT_ID",
        )
        values["ernie_client_secret"] = get_from_dict_or_env(
            values,
            "ernie_client_secret",
            "ERNIE_CLIENT_SECRET",
        )
        return values

    def _embedding(self, json: object) -> dict:
        """POST one request to the ``embedding-v1`` endpoint.

        Args:
            json: The JSON-serialisable request body.

        Returns:
            dict: The decoded JSON response, returned as-is (error responses
            are not interpreted here).
        """
        base_url = (
            f"{self.ernie_api_base}/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings"
        )
        resp = requests.post(
            f"{base_url}/embedding-v1",
            headers={
                "Content-Type": "application/json",
            },
            params={"access_token": self.access_token},
            json=json,
        )
        return resp.json()

    def _refresh_access_token_with_lock(self) -> None:
        """Fetch a new access token and store it on ``self.access_token``.

        The request is made while holding ``self._lock``.

        Raises:
            ValueError: If the token response carries no ``access_token``.
        """
        with self._lock:
            logger.debug("Refreshing access token")
            token_url: str = f"{self.ernie_api_base}/oauth/2.0/token"
            resp = requests.post(
                token_url,
                headers={
                    "Content-Type": "application/json",
                    "Accept": "application/json",
                },
                params={
                    "grant_type": "client_credentials",
                    "client_id": self.ernie_client_id,
                    "client_secret": self.ernie_client_secret,
                },
            )
            payload = resp.json()
            token = payload.get("access_token")
            if not token:
                # Without this check a missing token would be stored as the
                # truthy string "None" and then sent as a credential.
                raise ValueError(f"Error from Ernie: {payload}")
            self.access_token = str(token)

    def _embed_inputs(self, inputs: List[str]) -> dict:
        """Embed one batch of texts, refreshing the token once if required.

        Args:
            inputs: The texts sent as the ``input`` field of a single request.

        Returns:
            dict: The response of the successful request.

        Raises:
            ValueError: If the response (after at most one token refresh and
                retry) still reports an ``error_code``.
        """
        resp = self._embedding({"input": inputs})
        # error_code 111 is handled as "token must be renewed": refresh it and
        # repeat the request exactly once.
        if resp.get("error_code") == 111:
            self._refresh_access_token_with_lock()
            resp = self._embedding({"input": inputs})
        # Covers both a non-111 first response and a failed retry, so callers
        # never index into an error payload.
        if resp.get("error_code"):
            raise ValueError(f"Error from Ernie: {resp}")
        return resp

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed search docs.

        The texts are sent in batches of at most ``chunk_size`` items and the
        embeddings are returned in input order.

        Args:
            texts: The list of texts to embed

        Returns:
            List[List[float]]: List of embeddings, one for each text.

        Raises:
            ValueError: If the service reports an error for any batch.
        """
        if not self.access_token:
            self._refresh_access_token_with_lock()

        embeddings: List[List[float]] = []
        for start in range(0, len(texts), self.chunk_size):
            batch = texts[start : start + self.chunk_size]
            resp = self._embed_inputs(list(batch))
            embeddings.extend(item["embedding"] for item in resp["data"])
        return embeddings

    def embed_query(self, text: str) -> List[float]:
        """Embed query text.

        Args:
            text: The text to embed.

        Returns:
            List[float]: Embeddings for the text.

        Raises:
            ValueError: If the service reports an error.
        """
        if not self.access_token:
            self._refresh_access_token_with_lock()

        resp = self._embed_inputs([text])
        return resp["data"][0]["embedding"]

    async def aembed_query(self, text: str) -> List[float]:
        """Asynchronous Embed query text.

        Runs ``embed_query`` through ``run_in_executor`` and awaits the result.

        Args:
            text: The text to embed.

        Returns:
            List[float]: Embeddings for the text.
        """
        return await run_in_executor(None, self.embed_query, text)

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Asynchronous Embed search docs.

        Each text is embedded with its own ``aembed_query`` call; the calls
        are awaited together and ``chunk_size`` is not used on this path.

        Args:
            texts: The list of texts to embed

        Returns:
            List[List[float]]: List of embeddings, one for each text.
        """
        result = await asyncio.gather(*[self.aembed_query(text) for text in texts])

        return list(result)
|