@ -1,7 +1,5 @@
import hashlib
import json
import logging
import time
from typing import Any , Dict , Iterator , List , Mapping , Optional , Type
import requests
@ -30,7 +28,7 @@ from langchain_core.utils import (
logger = logging . getLogger ( __name__ )
DEFAULT_API_BASE = " https://api.baichuan-ai.com/v1 "
DEFAULT_API_BASE = " https://api.baichuan-ai.com/v1 /chat/completions "
def _convert_message_to_dict ( message : BaseMessage ) - > dict :
@ -73,14 +71,6 @@ def _convert_delta_to_message_chunk(
return default_class ( content = content )
# signature generation
def _signature ( secret_key : SecretStr , payload : Dict [ str , Any ] , timestamp : int ) - > str :
input_str = secret_key . get_secret_value ( ) + json . dumps ( payload ) + str ( timestamp )
md5 = hashlib . md5 ( )
md5 . update ( input_str . encode ( " utf-8 " ) )
return md5 . hexdigest ( )
class ChatBaichuan ( BaseChatModel ) :
""" Baichuan chat models API by Baichuan Intelligent Technology.
@ -91,7 +81,6 @@ class ChatBaichuan(BaseChatModel):
def lc_secrets ( self ) - > Dict [ str , str ] :
return {
" baichuan_api_key " : " BAICHUAN_API_KEY " ,
" baichuan_secret_key " : " BAICHUAN_SECRET_KEY " ,
}
@property
@ -103,14 +92,14 @@ class ChatBaichuan(BaseChatModel):
baichuan_api_key : Optional [ SecretStr ] = None
""" Baichuan API Key """
baichuan_secret_key : Optional [ SecretStr ] = None
""" Baichuan Secret Key"""
""" [DEPRECATED, keeping it for for backward compatibility] Baichuan Secret Key"""
streaming : bool = False
""" Whether to stream the results or not. """
request_timeout : int = 60
""" request timeout for chat http requests """
model = " Baichuan2-53B "
""" model name of Baichuan, default is `Baichuan2-53B`. """
model = " Baichuan2-Turbo-192K "
""" model name of Baichuan, default is `Baichuan2-Turbo-192K`,
other options include ` Baichuan2 - Turbo ` """
temperature : float = 0.3
""" What sampling temperature to use. """
top_k : int = 5
@ -168,13 +157,6 @@ class ChatBaichuan(BaseChatModel):
" BAICHUAN_API_KEY " ,
)
)
values [ " baichuan_secret_key " ] = convert_to_secret_str (
get_from_dict_or_env (
values ,
" baichuan_secret_key " ,
" BAICHUAN_SECRET_KEY " ,
)
)
return values
@ -187,6 +169,7 @@ class ChatBaichuan(BaseChatModel):
" top_p " : self . top_p ,
" top_k " : self . top_k ,
" with_search_enhance " : self . with_search_enhance ,
" stream " : self . streaming ,
}
return { * * normal_params , * * self . model_kwargs }
@ -205,12 +188,9 @@ class ChatBaichuan(BaseChatModel):
return generate_from_stream ( stream_iter )
res = self . _chat ( messages , * * kwargs )
if res . status_code != 200 :
raise ValueError ( f " Error from Baichuan api response: { res } " )
response = res . json ( )
if response . get ( " code " ) != 0 :
raise ValueError ( f " Error from Baichuan api response: { response } " )
return self . _create_chat_result ( response )
def _stream (
@ -221,43 +201,49 @@ class ChatBaichuan(BaseChatModel):
* * kwargs : Any ,
) - > Iterator [ ChatGenerationChunk ] :
res = self . _chat ( messages , * * kwargs )
if res . status_code != 200 :
raise ValueError ( f " Error from Baichuan api response: { res } " )
default_chunk_class = AIMessageChunk
for chunk in res . iter_lines ( ) :
chunk = chunk . decode ( " utf-8 " ) . strip ( " \r \n " )
parts = chunk . split ( " data: " , 1 )
chunk = parts [ 1 ] if len ( parts ) > 1 else None
if chunk is None :
continue
if chunk == " [DONE] " :
break
response = json . loads ( chunk )
if response . get ( " code " ) != 0 :
raise ValueError ( f " Error from Baichuan api response: { response } " )
data = response . get ( " data " )
for m in data . get ( " messages " ) :
chunk = _convert_delta_to_message_chunk ( m , default_chunk_class )
for m in response . get ( " choices " ) :
chunk = _convert_delta_to_message_chunk (
m . get ( " delta " ) , default_chunk_class
)
default_chunk_class = chunk . __class__
yield ChatGenerationChunk ( message = chunk )
if run_manager :
run_manager . on_llm_new_token ( chunk . content )
def _chat ( self , messages : List [ BaseMessage ] , * * kwargs : Any ) - > requests . Response :
if self . baichuan_secret_key is None :
raise ValueError ( " Baichuan secret key is not set. " )
parameters = { * * self . _default_params , * * kwargs }
model = parameters . pop ( " model " )
headers = parameters . pop ( " headers " , { } )
temperature = parameters . pop ( " temperature " , 0.3 )
top_k = parameters . pop ( " top_k " , 5 )
top_p = parameters . pop ( " top_p " , 0.85 )
with_search_enhance = parameters . pop ( " with_search_enhance " , False )
stream = parameters . pop ( " stream " , False )
payload = {
" model " : model ,
" messages " : [ _convert_message_to_dict ( m ) for m in messages ] ,
" parameters " : parameters ,
" top_k " : top_k ,
" top_p " : top_p ,
" temperature " : temperature ,
" with_search_enhance " : with_search_enhance ,
" stream " : stream ,
}
timestamp = int ( time . time ( ) )
url = self . baichuan_api_base
if self . streaming :
url = f " { url } /stream "
url = f " { url } /chat "
api_key = " "
if self . baichuan_api_key :
api_key = self . baichuan_api_key . get_secret_value ( )
@ -268,13 +254,6 @@ class ChatBaichuan(BaseChatModel):
headers = {
" Content-Type " : " application/json " ,
" Authorization " : f " Bearer { api_key } " ,
" X-BC-Timestamp " : str ( timestamp ) ,
" X-BC-Signature " : _signature (
secret_key = self . baichuan_secret_key ,
payload = payload ,
timestamp = timestamp ,
) ,
" X-BC-Sign-Algo " : " MD5 " ,
* * headers ,
} ,
json = payload ,
@ -284,8 +263,8 @@ class ChatBaichuan(BaseChatModel):
def _create_chat_result ( self , response : Mapping [ str , Any ] ) - > ChatResult :
generations = [ ]
for m in response [ " data " ] [ " messag es" ] :
message = _convert_dict_to_message ( m)
for c in response [ " choic es" ] :
message = _convert_dict_to_message ( c[ " message" ] )
gen = ChatGeneration ( message = message )
generations . append ( gen )