From b50d724114ba230515913b7b3872c2de5d84fb55 Mon Sep 17 00:00:00 2001 From: John Mai Date: Tue, 12 Sep 2023 06:50:07 +0800 Subject: [PATCH] Supported custom ernie_api_base for Ernie (#10416) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Description: Supported custom ernie_api_base for Ernie - ernie_api_baseļ¼šSupport Ernie custom endpoints - Rectifying omitted code modifications. #10398 Issue: None Dependencies: None Tag maintainer: @baskaryan Twitter handle: @JohnMai95 --- libs/langchain/langchain/chat_models/ernie.py | 10 ++++++++-- libs/langchain/langchain/embeddings/ernie.py | 2 +- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/libs/langchain/langchain/chat_models/ernie.py b/libs/langchain/langchain/chat_models/ernie.py index 367341c11f..dd7c37ed96 100644 --- a/libs/langchain/langchain/chat_models/ernie.py +++ b/libs/langchain/langchain/chat_models/ernie.py @@ -56,6 +56,9 @@ class ErnieBotChat(BaseChatModel): """ + ernie_api_base: Optional[str] = None + """Baidu application custom endpoints""" + ernie_client_id: Optional[str] = None """Baidu application client id""" @@ -84,6 +87,9 @@ class ErnieBotChat(BaseChatModel): @root_validator() def validate_environment(cls, values: Dict) -> Dict: + values["ernie_api_base"] = get_from_dict_or_env( + values, "ernie_api_base", "ERNIE_API_BASE", "https://aip.baidubce.com" + ) values["ernie_client_id"] = get_from_dict_or_env( values, "ernie_client_id", @@ -97,7 +103,7 @@ class ErnieBotChat(BaseChatModel): return values def _chat(self, payload: object) -> dict: - base_url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat" + base_url = f"{self.ernie_api_base}/rpc/2.0/ai_custom/v1/wenxinworkshop/chat" model_paths = { "ERNIE-Bot-turbo": "eb-instant", "ERNIE-Bot": "completions", @@ -125,7 +131,7 @@ class ErnieBotChat(BaseChatModel): def _refresh_access_token_with_lock(self) -> None: with self._lock: logger.debug("Refreshing access token") - base_url: str = "https://aip.baidubce.com/oauth/2.0/token" + base_url: str = f"{self.ernie_api_base}/oauth/2.0/token" resp = requests.post( base_url, timeout=10, diff --git a/libs/langchain/langchain/embeddings/ernie.py b/libs/langchain/langchain/embeddings/ernie.py index 37723b53ab..77ed2f7641 100644 --- a/libs/langchain/langchain/embeddings/ernie.py +++ b/libs/langchain/langchain/embeddings/ernie.py @@ -61,7 +61,7 @@ class ErnieEmbeddings(BaseModel, Embeddings): def _refresh_access_token_with_lock(self) -> None: with self._lock: logger.debug("Refreshing access token") - base_url: str = "https://aip.baidubce.com/oauth/2.0/token" + base_url: str = f"{self.ernie_api_base}/oauth/2.0/token" resp = requests.post( base_url, headers={