mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Supported custom ernie_api_base for Ernie (#10416)
Description: Supported custom ernie_api_base for Ernie - ernie_api_base:Support Ernie custom endpoints - Rectifying omitted code modifications. #10398 Issue: None Dependencies: None Tag maintainer: @baskaryan Twitter handle: @JohnMai95
This commit is contained in:
parent
70b6897dc1
commit
b50d724114
@ -56,6 +56,9 @@ class ErnieBotChat(BaseChatModel):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
ernie_api_base: Optional[str] = None
|
||||||
|
"""Baidu application custom endpoints"""
|
||||||
|
|
||||||
ernie_client_id: Optional[str] = None
|
ernie_client_id: Optional[str] = None
|
||||||
"""Baidu application client id"""
|
"""Baidu application client id"""
|
||||||
|
|
||||||
@ -84,6 +87,9 @@ class ErnieBotChat(BaseChatModel):
|
|||||||
|
|
||||||
@root_validator()
|
@root_validator()
|
||||||
def validate_environment(cls, values: Dict) -> Dict:
|
def validate_environment(cls, values: Dict) -> Dict:
|
||||||
|
values["ernie_api_base"] = get_from_dict_or_env(
|
||||||
|
values, "ernie_api_base", "ERNIE_API_BASE", "https://aip.baidubce.com"
|
||||||
|
)
|
||||||
values["ernie_client_id"] = get_from_dict_or_env(
|
values["ernie_client_id"] = get_from_dict_or_env(
|
||||||
values,
|
values,
|
||||||
"ernie_client_id",
|
"ernie_client_id",
|
||||||
@ -97,7 +103,7 @@ class ErnieBotChat(BaseChatModel):
|
|||||||
return values
|
return values
|
||||||
|
|
||||||
def _chat(self, payload: object) -> dict:
|
def _chat(self, payload: object) -> dict:
|
||||||
base_url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat"
|
base_url = f"{self.ernie_api_base}/rpc/2.0/ai_custom/v1/wenxinworkshop/chat"
|
||||||
model_paths = {
|
model_paths = {
|
||||||
"ERNIE-Bot-turbo": "eb-instant",
|
"ERNIE-Bot-turbo": "eb-instant",
|
||||||
"ERNIE-Bot": "completions",
|
"ERNIE-Bot": "completions",
|
||||||
@ -125,7 +131,7 @@ class ErnieBotChat(BaseChatModel):
|
|||||||
def _refresh_access_token_with_lock(self) -> None:
|
def _refresh_access_token_with_lock(self) -> None:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
logger.debug("Refreshing access token")
|
logger.debug("Refreshing access token")
|
||||||
base_url: str = "https://aip.baidubce.com/oauth/2.0/token"
|
base_url: str = f"{self.ernie_api_base}/oauth/2.0/token"
|
||||||
resp = requests.post(
|
resp = requests.post(
|
||||||
base_url,
|
base_url,
|
||||||
timeout=10,
|
timeout=10,
|
||||||
|
@ -61,7 +61,7 @@ class ErnieEmbeddings(BaseModel, Embeddings):
|
|||||||
def _refresh_access_token_with_lock(self) -> None:
|
def _refresh_access_token_with_lock(self) -> None:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
logger.debug("Refreshing access token")
|
logger.debug("Refreshing access token")
|
||||||
base_url: str = "https://aip.baidubce.com/oauth/2.0/token"
|
base_url: str = f"{self.ernie_api_base}/oauth/2.0/token"
|
||||||
resp = requests.post(
|
resp = requests.post(
|
||||||
base_url,
|
base_url,
|
||||||
headers={
|
headers={
|
||||||
|
Loading…
Reference in New Issue
Block a user