Baichuan & Hunyuan set default api_base (#12059)

### Description
Baichuan & Hunyuan set default api_base env
pull/12223/head
John Mai 10 months ago committed by GitHub
parent 283a3ecc9c
commit ebf749c40c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -28,6 +28,8 @@ from langchain.utils import get_from_dict_or_env, get_pydantic_field_names
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
DEFAULT_API_BASE = "https://api.baichuan-ai.com/v1"
def _convert_message_to_dict(message: BaseMessage) -> dict: def _convert_message_to_dict(message: BaseMessage) -> dict:
message_dict: Dict[str, Any] message_dict: Dict[str, Any]
@ -101,7 +103,7 @@ class ChatBaichuan(BaseChatModel):
def lc_serializable(self) -> bool: def lc_serializable(self) -> bool:
return True return True
baichuan_api_base: str = "https://api.baichuan-ai.com" baichuan_api_base: str = Field(default=DEFAULT_API_BASE)
"""Baichuan custom endpoints""" """Baichuan custom endpoints"""
baichuan_api_key: Optional[str] = None baichuan_api_key: Optional[str] = None
"""Baichuan API Key""" """Baichuan API Key"""
@ -162,6 +164,7 @@ class ChatBaichuan(BaseChatModel):
values, values,
"baichuan_api_base", "baichuan_api_base",
"BAICHUAN_API_BASE", "BAICHUAN_API_BASE",
DEFAULT_API_BASE,
) )
values["baichuan_api_key"] = get_from_dict_or_env( values["baichuan_api_key"] = get_from_dict_or_env(
values, values,
@ -252,7 +255,7 @@ class ChatBaichuan(BaseChatModel):
timestamp = int(time.time()) timestamp = int(time.time())
url = f"{self.baichuan_api_base}/v1" url = self.baichuan_api_base
if self.streaming: if self.streaming:
url = f"{url}/stream" url = f"{url}/stream"
url = f"{url}/chat" url = f"{url}/chat"

@ -31,8 +31,8 @@ from langchain.utils import get_from_dict_or_env, get_pydantic_field_names
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
DEFAULT_HUNYUAN_API_BASE = "https://hunyuan.cloud.tencent.com" DEFAULT_API_BASE = "https://hunyuan.cloud.tencent.com"
DEFAULT_HUNYUAN_PATH = "/hyllm/v1/chat/completions" DEFAULT_PATH = "/hyllm/v1/chat/completions"
def _convert_message_to_dict(message: BaseMessage) -> dict: def _convert_message_to_dict(message: BaseMessage) -> dict:
@ -141,7 +141,7 @@ class ChatHunyuan(BaseChatModel):
def lc_serializable(self) -> bool: def lc_serializable(self) -> bool:
return True return True
hunyuan_api_base: str = "https://hunyuan.cloud.tencent.com" hunyuan_api_base: str = Field(default=DEFAULT_API_BASE)
"""Hunyuan custom endpoints""" """Hunyuan custom endpoints"""
hunyuan_app_id: Optional[str] = None hunyuan_app_id: Optional[str] = None
"""Hunyuan App ID""" """Hunyuan App ID"""
@ -201,6 +201,7 @@ class ChatHunyuan(BaseChatModel):
values, values,
"hunyuan_api_base", "hunyuan_api_base",
"HUNYUAN_API_BASE", "HUNYUAN_API_BASE",
DEFAULT_API_BASE,
) )
values["hunyuan_app_id"] = get_from_dict_or_env( values["hunyuan_app_id"] = get_from_dict_or_env(
values, values,
@ -303,7 +304,7 @@ class ChatHunyuan(BaseChatModel):
if self.streaming: if self.streaming:
payload["stream"] = 1 payload["stream"] = 1
url = self.hunyuan_api_base + DEFAULT_HUNYUAN_PATH url = self.hunyuan_api_base + DEFAULT_PATH
res = requests.post( res = requests.post(
url=url, url=url,

Loading…
Cancel
Save