@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Any , Dict , Iterator , List , Mapping , Optional
from typing import Any , Dict , Iterator , List , Mapping , Optional , cast
from langchain_core . messages import (
AIMessage ,
@ -33,9 +33,7 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
def convert_dict_to_message ( _dict : Mapping [ str , Any ] ) - > AIMessage :
content = _dict . get ( " choice " , { } ) . get ( " message " , { } ) . get ( " content " , " " )
return AIMessage (
content = content ,
)
return AIMessage ( content = content )
class VolcEngineMaasChat ( BaseChatModel , VolcEngineMaasBase ) :
@ -118,7 +116,7 @@ class VolcEngineMaasChat(BaseChatModel, VolcEngineMaasBase):
msg = convert_dict_to_message ( res )
yield ChatGenerationChunk ( message = AIMessageChunk ( content = msg . content ) )
if run_manager :
run_manager . on_llm_new_token ( msg. content )
run_manager . on_llm_new_token ( cast( str , msg. content ) )
def _generate (
self ,
@ -135,7 +133,7 @@ class VolcEngineMaasChat(BaseChatModel, VolcEngineMaasBase):
params = self . _convert_prompt_msg_params ( messages , * * kwargs )
res = self . client . chat ( params )
msg = convert_dict_to_message ( res )
completion = msg. content
completion = cast( str , msg. content )
message = AIMessage ( content = completion )
return ChatResult ( generations = [ ChatGeneration ( message = message ) ] )