diff --git a/libs/langchain/langchain/chat_models/baichuan.py b/libs/langchain/langchain/chat_models/baichuan.py index 91be5c38cf..b0cb78a119 100644 --- a/libs/langchain/langchain/chat_models/baichuan.py +++ b/libs/langchain/langchain/chat_models/baichuan.py @@ -28,6 +28,8 @@ from langchain.utils import get_from_dict_or_env, get_pydantic_field_names logger = logging.getLogger(__name__) +DEFAULT_API_BASE = "https://api.baichuan-ai.com/v1" + def _convert_message_to_dict(message: BaseMessage) -> dict: message_dict: Dict[str, Any] @@ -101,7 +103,7 @@ class ChatBaichuan(BaseChatModel): def lc_serializable(self) -> bool: return True - baichuan_api_base: str = "https://api.baichuan-ai.com" + baichuan_api_base: str = Field(default=DEFAULT_API_BASE) """Baichuan custom endpoints""" baichuan_api_key: Optional[str] = None """Baichuan API Key""" @@ -162,6 +164,7 @@ class ChatBaichuan(BaseChatModel): values, "baichuan_api_base", "BAICHUAN_API_BASE", + DEFAULT_API_BASE, ) values["baichuan_api_key"] = get_from_dict_or_env( values, @@ -252,7 +255,7 @@ class ChatBaichuan(BaseChatModel): timestamp = int(time.time()) - url = f"{self.baichuan_api_base}/v1" + url = self.baichuan_api_base if self.streaming: url = f"{url}/stream" url = f"{url}/chat" diff --git a/libs/langchain/langchain/chat_models/hunyuan.py b/libs/langchain/langchain/chat_models/hunyuan.py index 3f0b7261b8..b87f2748ed 100644 --- a/libs/langchain/langchain/chat_models/hunyuan.py +++ b/libs/langchain/langchain/chat_models/hunyuan.py @@ -31,8 +31,8 @@ from langchain.utils import get_from_dict_or_env, get_pydantic_field_names logger = logging.getLogger(__name__) -DEFAULT_HUNYUAN_API_BASE = "https://hunyuan.cloud.tencent.com" -DEFAULT_HUNYUAN_PATH = "/hyllm/v1/chat/completions" +DEFAULT_API_BASE = "https://hunyuan.cloud.tencent.com" +DEFAULT_PATH = "/hyllm/v1/chat/completions" def _convert_message_to_dict(message: BaseMessage) -> dict: @@ -141,7 +141,7 @@ class ChatHunyuan(BaseChatModel): def lc_serializable(self) -> bool: return True - hunyuan_api_base: str = "https://hunyuan.cloud.tencent.com" + hunyuan_api_base: str = Field(default=DEFAULT_API_BASE) """Hunyuan custom endpoints""" hunyuan_app_id: Optional[str] = None """Hunyuan App ID""" @@ -201,6 +201,7 @@ class ChatHunyuan(BaseChatModel): values, "hunyuan_api_base", "HUNYUAN_API_BASE", + DEFAULT_API_BASE, ) values["hunyuan_app_id"] = get_from_dict_or_env( values, @@ -303,7 +304,7 @@ class ChatHunyuan(BaseChatModel): if self.streaming: payload["stream"] = 1 - url = self.hunyuan_api_base + DEFAULT_HUNYUAN_PATH + url = self.hunyuan_api_base + DEFAULT_PATH res = requests.post( url=url,