Supported custom ernie_api_base for Ernie (#10416)

Description: Supported custom ernie_api_base for Ernie
 - ernie_api_base:Support Ernie custom endpoints
 - Rectifying omitted code modifications. #10398

Issue: None
Dependencies: None
Tag maintainer: @baskaryan 
Twitter handle: @JohnMai95
This commit is contained in:
John Mai 2023-09-12 06:50:07 +08:00 committed by GitHub
parent 70b6897dc1
commit b50d724114
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 9 additions and 3 deletions

View File

@ -56,6 +56,9 @@ class ErnieBotChat(BaseChatModel):
""" """
ernie_api_base: Optional[str] = None
"""Baidu application custom endpoints"""
ernie_client_id: Optional[str] = None ernie_client_id: Optional[str] = None
"""Baidu application client id""" """Baidu application client id"""
@ -84,6 +87,9 @@ class ErnieBotChat(BaseChatModel):
@root_validator() @root_validator()
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
values["ernie_api_base"] = get_from_dict_or_env(
values, "ernie_api_base", "ERNIE_API_BASE", "https://aip.baidubce.com"
)
values["ernie_client_id"] = get_from_dict_or_env( values["ernie_client_id"] = get_from_dict_or_env(
values, values,
"ernie_client_id", "ernie_client_id",
@ -97,7 +103,7 @@ class ErnieBotChat(BaseChatModel):
return values return values
def _chat(self, payload: object) -> dict: def _chat(self, payload: object) -> dict:
base_url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat" base_url = f"{self.ernie_api_base}/rpc/2.0/ai_custom/v1/wenxinworkshop/chat"
model_paths = { model_paths = {
"ERNIE-Bot-turbo": "eb-instant", "ERNIE-Bot-turbo": "eb-instant",
"ERNIE-Bot": "completions", "ERNIE-Bot": "completions",
@ -125,7 +131,7 @@ class ErnieBotChat(BaseChatModel):
def _refresh_access_token_with_lock(self) -> None: def _refresh_access_token_with_lock(self) -> None:
with self._lock: with self._lock:
logger.debug("Refreshing access token") logger.debug("Refreshing access token")
base_url: str = "https://aip.baidubce.com/oauth/2.0/token" base_url: str = f"{self.ernie_api_base}/oauth/2.0/token"
resp = requests.post( resp = requests.post(
base_url, base_url,
timeout=10, timeout=10,

View File

@ -61,7 +61,7 @@ class ErnieEmbeddings(BaseModel, Embeddings):
def _refresh_access_token_with_lock(self) -> None: def _refresh_access_token_with_lock(self) -> None:
with self._lock: with self._lock:
logger.debug("Refreshing access token") logger.debug("Refreshing access token")
base_url: str = "https://aip.baidubce.com/oauth/2.0/token" base_url: str = f"{self.ernie_api_base}/oauth/2.0/token"
resp = requests.post( resp = requests.post(
base_url, base_url,
headers={ headers={