@ -166,6 +166,16 @@ class ChatCohere(BaseChatModel, BaseCohere):
if run_manager :
await run_manager . on_llm_new_token ( delta )
def _get_generation_info ( self , response : Any ) - > Dict [ str , Any ] :
""" Get the generation info from cohere API response. """
return {
" documents " : response . documents ,
" citations " : response . citations ,
" search_results " : response . search_results ,
" search_queries " : response . search_queries ,
" token_count " : response . token_count ,
}
def _generate (
self ,
messages : List [ BaseMessage ] ,
@ -185,7 +195,7 @@ class ChatCohere(BaseChatModel, BaseCohere):
message = AIMessage ( content = response . text )
generation_info = None
if hasattr ( response , " documents " ) :
generation_info = { " documents " : response . documents }
generation_info = self . _get_generation_info ( response )
return ChatResult (
generations = [
ChatGeneration ( message = message , generation_info = generation_info )
@ -211,7 +221,7 @@ class ChatCohere(BaseChatModel, BaseCohere):
message = AIMessage ( content = response . text )
generation_info = None
if hasattr ( response , " documents " ) :
generation_info = { " documents " : response . documents }
generation_info = self . _get_generation_info ( response )
return ChatResult (
generations = [
ChatGeneration ( message = message , generation_info = generation_info )