|
|
@ -28,6 +28,8 @@ from langchain.utils import get_from_dict_or_env, get_pydantic_field_names
|
|
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
DEFAULT_API_BASE = "https://api.baichuan-ai.com/v1"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _convert_message_to_dict(message: BaseMessage) -> dict:
|
|
|
|
def _convert_message_to_dict(message: BaseMessage) -> dict:
|
|
|
|
message_dict: Dict[str, Any]
|
|
|
|
message_dict: Dict[str, Any]
|
|
|
@ -101,7 +103,7 @@ class ChatBaichuan(BaseChatModel):
|
|
|
|
def lc_serializable(self) -> bool:
|
|
|
|
def lc_serializable(self) -> bool:
|
|
|
|
return True
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
|
|
baichuan_api_base: str = "https://api.baichuan-ai.com"
|
|
|
|
baichuan_api_base: str = Field(default=DEFAULT_API_BASE)
|
|
|
|
"""Baichuan custom endpoints"""
|
|
|
|
"""Baichuan custom endpoints"""
|
|
|
|
baichuan_api_key: Optional[str] = None
|
|
|
|
baichuan_api_key: Optional[str] = None
|
|
|
|
"""Baichuan API Key"""
|
|
|
|
"""Baichuan API Key"""
|
|
|
@ -162,6 +164,7 @@ class ChatBaichuan(BaseChatModel):
|
|
|
|
values,
|
|
|
|
values,
|
|
|
|
"baichuan_api_base",
|
|
|
|
"baichuan_api_base",
|
|
|
|
"BAICHUAN_API_BASE",
|
|
|
|
"BAICHUAN_API_BASE",
|
|
|
|
|
|
|
|
DEFAULT_API_BASE,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
values["baichuan_api_key"] = get_from_dict_or_env(
|
|
|
|
values["baichuan_api_key"] = get_from_dict_or_env(
|
|
|
|
values,
|
|
|
|
values,
|
|
|
@ -252,7 +255,7 @@ class ChatBaichuan(BaseChatModel):
|
|
|
|
|
|
|
|
|
|
|
|
timestamp = int(time.time())
|
|
|
|
timestamp = int(time.time())
|
|
|
|
|
|
|
|
|
|
|
|
url = f"{self.baichuan_api_base}/v1"
|
|
|
|
url = self.baichuan_api_base
|
|
|
|
if self.streaming:
|
|
|
|
if self.streaming:
|
|
|
|
url = f"{url}/stream"
|
|
|
|
url = f"{url}/stream"
|
|
|
|
url = f"{url}/chat"
|
|
|
|
url = f"{url}/chat"
|
|
|
|