From 9678797625c73b270a4dd6449c9b19f2f031238d Mon Sep 17 00:00:00 2001 From: mackong Date: Fri, 23 Feb 2024 08:15:21 +0800 Subject: [PATCH] community[patch]: callback before yield for _stream/_astream (#17907) - Description: callback on_llm_new_token before yield chunk for _stream/_astream for some chat models, making all chat models behave consistently. - Issue: N/A - Dependencies: N/A --- .../langchain_community/chat_models/anthropic.py | 4 ++-- .../langchain_community/chat_models/baichuan.py | 2 +- .../chat_models/baidu_qianfan_endpoint.py | 4 ++-- .../langchain_community/chat_models/cohere.py | 4 ++-- .../langchain_community/chat_models/deepinfra.py | 4 ++-- .../langchain_community/chat_models/edenai.py | 14 ++++++-------- .../langchain_community/chat_models/fireworks.py | 16 ++++++++++------ .../langchain_community/chat_models/gigachat.py | 4 ++-- .../chat_models/gpt_router.py | 8 ++++---- .../langchain_community/chat_models/hunyuan.py | 2 +- .../langchain_community/chat_models/jinachat.py | 4 ++-- .../langchain_community/chat_models/konko.py | 8 +++++--- .../langchain_community/chat_models/litellm.py | 4 ++-- .../chat_models/litellm_router.py | 4 ++-- .../chat_models/llama_edge.py | 6 +++--- .../langchain_community/chat_models/ollama.py | 3 +++ .../langchain_community/chat_models/openai.py | 16 ++++++++++------ .../langchain_community/chat_models/sparkllm.py | 2 +- .../langchain_community/chat_models/tongyi.py | 4 ++-- .../chat_models/volcengine_maas.py | 2 +- .../langchain_community/chat_models/yuan2.py | 4 ++-- .../langchain_community/chat_models/zhipuai.py | 2 +- 22 files changed, 66 insertions(+), 55 deletions(-) diff --git a/libs/community/langchain_community/chat_models/anthropic.py b/libs/community/langchain_community/chat_models/anthropic.py index 682f361679..8f900d2bd0 100644 --- a/libs/community/langchain_community/chat_models/anthropic.py +++ b/libs/community/langchain_community/chat_models/anthropic.py @@ -143,9 +143,9 @@ class 
ChatAnthropic(BaseChatModel, _AnthropicCommon): for data in stream_resp: delta = data.completion chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta)) - yield chunk if run_manager: run_manager.on_llm_new_token(delta, chunk=chunk) + yield chunk async def _astream( self, @@ -163,9 +163,9 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon): async for data in stream_resp: delta = data.completion chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta)) - yield chunk if run_manager: await run_manager.on_llm_new_token(delta, chunk=chunk) + yield chunk def _generate( self, diff --git a/libs/community/langchain_community/chat_models/baichuan.py b/libs/community/langchain_community/chat_models/baichuan.py index 978f8b1ddc..1b3d8ac093 100644 --- a/libs/community/langchain_community/chat_models/baichuan.py +++ b/libs/community/langchain_community/chat_models/baichuan.py @@ -219,9 +219,9 @@ class ChatBaichuan(BaseChatModel): ) default_chunk_class = chunk.__class__ cg_chunk = ChatGenerationChunk(message=chunk) - yield cg_chunk if run_manager: run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk) + yield cg_chunk def _chat(self, messages: List[BaseMessage], **kwargs: Any) -> requests.Response: parameters = {**self._default_params, **kwargs} diff --git a/libs/community/langchain_community/chat_models/baidu_qianfan_endpoint.py b/libs/community/langchain_community/chat_models/baidu_qianfan_endpoint.py index cf656817a9..309d0621fe 100644 --- a/libs/community/langchain_community/chat_models/baidu_qianfan_endpoint.py +++ b/libs/community/langchain_community/chat_models/baidu_qianfan_endpoint.py @@ -346,9 +346,9 @@ class QianfanChatEndpoint(BaseChatModel): ), generation_info=msg.additional_kwargs, ) - yield chunk if run_manager: run_manager.on_llm_new_token(chunk.text, chunk=chunk) + yield chunk async def _astream( self, @@ -372,6 +372,6 @@ class QianfanChatEndpoint(BaseChatModel): ), generation_info=msg.additional_kwargs, ) - yield chunk if run_manager: 
await run_manager.on_llm_new_token(chunk.text, chunk=chunk) + yield chunk diff --git a/libs/community/langchain_community/chat_models/cohere.py b/libs/community/langchain_community/chat_models/cohere.py index af7bc307ad..657bca68a3 100644 --- a/libs/community/langchain_community/chat_models/cohere.py +++ b/libs/community/langchain_community/chat_models/cohere.py @@ -148,9 +148,9 @@ class ChatCohere(BaseChatModel, BaseCohere): if data.event_type == "text-generation": delta = data.text chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta)) - yield chunk if run_manager: run_manager.on_llm_new_token(delta, chunk=chunk) + yield chunk async def _astream( self, @@ -166,9 +166,9 @@ class ChatCohere(BaseChatModel, BaseCohere): if data.event_type == "text-generation": delta = data.text chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta)) - yield chunk if run_manager: await run_manager.on_llm_new_token(delta, chunk=chunk) + yield chunk def _get_generation_info(self, response: Any) -> Dict[str, Any]: """Get the generation info from cohere API response.""" diff --git a/libs/community/langchain_community/chat_models/deepinfra.py b/libs/community/langchain_community/chat_models/deepinfra.py index 156865c487..0886e52086 100644 --- a/libs/community/langchain_community/chat_models/deepinfra.py +++ b/libs/community/langchain_community/chat_models/deepinfra.py @@ -329,9 +329,9 @@ class ChatDeepInfra(BaseChatModel): chunk = _handle_sse_line(line) if chunk: cg_chunk = ChatGenerationChunk(message=chunk, generation_info=None) - yield cg_chunk if run_manager: run_manager.on_llm_new_token(str(chunk.content), chunk=cg_chunk) + yield cg_chunk async def _astream( self, @@ -352,11 +352,11 @@ class ChatDeepInfra(BaseChatModel): chunk = _handle_sse_line(line) if chunk: cg_chunk = ChatGenerationChunk(message=chunk, generation_info=None) - yield cg_chunk if run_manager: await run_manager.on_llm_new_token( str(chunk.content), chunk=cg_chunk ) + yield cg_chunk async def 
_agenerate( self, diff --git a/libs/community/langchain_community/chat_models/edenai.py b/libs/community/langchain_community/chat_models/edenai.py index 3e329d91d7..a4252b804c 100644 --- a/libs/community/langchain_community/chat_models/edenai.py +++ b/libs/community/langchain_community/chat_models/edenai.py @@ -206,12 +206,10 @@ class ChatEdenAI(BaseChatModel): for chunk_response in response.iter_lines(): chunk = json.loads(chunk_response.decode()) token = chunk["text"] - chat_generatio_chunk = ChatGenerationChunk( - message=AIMessageChunk(content=token) - ) - yield chat_generatio_chunk + cg_chunk = ChatGenerationChunk(message=AIMessageChunk(content=token)) if run_manager: - run_manager.on_llm_new_token(token, chunk=chat_generatio_chunk) + run_manager.on_llm_new_token(token, chunk=cg_chunk) + yield cg_chunk async def _astream( self, @@ -246,14 +244,14 @@ class ChatEdenAI(BaseChatModel): async for chunk_response in response.content: chunk = json.loads(chunk_response.decode()) token = chunk["text"] - chat_generation_chunk = ChatGenerationChunk( + cg_chunk = ChatGenerationChunk( message=AIMessageChunk(content=token) ) - yield chat_generation_chunk if run_manager: await run_manager.on_llm_new_token( - token=chunk["text"], chunk=chat_generation_chunk + token=chunk["text"], chunk=cg_chunk ) + yield cg_chunk def _generate( self, diff --git a/libs/community/langchain_community/chat_models/fireworks.py b/libs/community/langchain_community/chat_models/fireworks.py index 0ad8144d03..7c7c127046 100644 --- a/libs/community/langchain_community/chat_models/fireworks.py +++ b/libs/community/langchain_community/chat_models/fireworks.py @@ -219,10 +219,12 @@ class ChatFireworks(BaseChatModel): dict(finish_reason=finish_reason) if finish_reason is not None else None ) default_chunk_class = chunk.__class__ - chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info) - yield chunk + cg_chunk = ChatGenerationChunk( + message=chunk, generation_info=generation_info + ) if 
run_manager: - run_manager.on_llm_new_token(chunk.text, chunk=chunk) + run_manager.on_llm_new_token(chunk.text, chunk=cg_chunk) + yield cg_chunk async def _astream( self, @@ -250,10 +252,12 @@ class ChatFireworks(BaseChatModel): dict(finish_reason=finish_reason) if finish_reason is not None else None ) default_chunk_class = chunk.__class__ - chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info) - yield chunk + cg_chunk = ChatGenerationChunk( + message=chunk, generation_info=generation_info + ) if run_manager: - await run_manager.on_llm_new_token(token=chunk.text, chunk=chunk) + await run_manager.on_llm_new_token(token=chunk.text, chunk=cg_chunk) + yield cg_chunk def conditional_decorator( diff --git a/libs/community/langchain_community/chat_models/gigachat.py b/libs/community/langchain_community/chat_models/gigachat.py index fc009e569d..ca02400ea0 100644 --- a/libs/community/langchain_community/chat_models/gigachat.py +++ b/libs/community/langchain_community/chat_models/gigachat.py @@ -155,9 +155,9 @@ class GigaChat(_BaseGigaChat, BaseChatModel): if chunk.choices: content = chunk.choices[0].delta.content cg_chunk = ChatGenerationChunk(message=AIMessageChunk(content=content)) - yield cg_chunk if run_manager: run_manager.on_llm_new_token(content, chunk=cg_chunk) + yield cg_chunk async def _astream( self, @@ -172,9 +172,9 @@ class GigaChat(_BaseGigaChat, BaseChatModel): if chunk.choices: content = chunk.choices[0].delta.content cg_chunk = ChatGenerationChunk(message=AIMessageChunk(content=content)) - yield cg_chunk if run_manager: await run_manager.on_llm_new_token(content, chunk=cg_chunk) + yield cg_chunk def get_num_tokens(self, text: str) -> int: """Count approximate number of tokens""" diff --git a/libs/community/langchain_community/chat_models/gpt_router.py b/libs/community/langchain_community/chat_models/gpt_router.py index fe919f4969..626b9a7655 100644 --- a/libs/community/langchain_community/chat_models/gpt_router.py +++ 
b/libs/community/langchain_community/chat_models/gpt_router.py @@ -325,13 +325,13 @@ class GPTRouter(BaseChatModel): chunk.data, default_chunk_class ) - yield chunk - if run_manager: run_manager.on_llm_new_token( token=chunk.message.content, chunk=chunk.message ) + yield chunk + async def _astream( self, messages: List[BaseMessage], @@ -358,13 +358,13 @@ class GPTRouter(BaseChatModel): chunk.data, default_chunk_class ) - yield chunk - if run_manager: await run_manager.on_llm_new_token( token=chunk.message.content, chunk=chunk.message ) + yield chunk + def _create_message_dicts( self, messages: List[BaseMessage], stop: Optional[List[str]] ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: diff --git a/libs/community/langchain_community/chat_models/hunyuan.py b/libs/community/langchain_community/chat_models/hunyuan.py index b20b1b921e..7eef887a5f 100644 --- a/libs/community/langchain_community/chat_models/hunyuan.py +++ b/libs/community/langchain_community/chat_models/hunyuan.py @@ -276,9 +276,9 @@ class ChatHunyuan(BaseChatModel): ) default_chunk_class = chunk.__class__ cg_chunk = ChatGenerationChunk(message=chunk) - yield cg_chunk if run_manager: run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk) + yield cg_chunk def _chat(self, messages: List[BaseMessage], **kwargs: Any) -> requests.Response: if self.hunyuan_secret_key is None: diff --git a/libs/community/langchain_community/chat_models/jinachat.py b/libs/community/langchain_community/chat_models/jinachat.py index 2fb0978139..e6744c8ed6 100644 --- a/libs/community/langchain_community/chat_models/jinachat.py +++ b/libs/community/langchain_community/chat_models/jinachat.py @@ -313,9 +313,9 @@ class JinaChat(BaseChatModel): chunk = _convert_delta_to_message_chunk(delta, default_chunk_class) default_chunk_class = chunk.__class__ cg_chunk = ChatGenerationChunk(message=chunk) - yield cg_chunk if run_manager: run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk) + yield cg_chunk def _generate( self, @@ 
-373,9 +373,9 @@ class JinaChat(BaseChatModel): chunk = _convert_delta_to_message_chunk(delta, default_chunk_class) default_chunk_class = chunk.__class__ cg_chunk = ChatGenerationChunk(message=chunk) - yield cg_chunk if run_manager: await run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk) + yield cg_chunk async def _agenerate( self, diff --git a/libs/community/langchain_community/chat_models/konko.py b/libs/community/langchain_community/chat_models/konko.py index 4f385148c6..afeab01312 100644 --- a/libs/community/langchain_community/chat_models/konko.py +++ b/libs/community/langchain_community/chat_models/konko.py @@ -217,10 +217,12 @@ class ChatKonko(ChatOpenAI): dict(finish_reason=finish_reason) if finish_reason is not None else None ) default_chunk_class = chunk.__class__ - chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info) - yield chunk + cg_chunk = ChatGenerationChunk( + message=chunk, generation_info=generation_info + ) if run_manager: - run_manager.on_llm_new_token(chunk.text, chunk=chunk) + run_manager.on_llm_new_token(chunk.text, chunk=cg_chunk) + yield cg_chunk def _generate( self, diff --git a/libs/community/langchain_community/chat_models/litellm.py b/libs/community/langchain_community/chat_models/litellm.py index 631c16d85a..b80fa2a6be 100644 --- a/libs/community/langchain_community/chat_models/litellm.py +++ b/libs/community/langchain_community/chat_models/litellm.py @@ -356,9 +356,9 @@ class ChatLiteLLM(BaseChatModel): chunk = _convert_delta_to_message_chunk(delta, default_chunk_class) default_chunk_class = chunk.__class__ cg_chunk = ChatGenerationChunk(message=chunk) - yield cg_chunk if run_manager: run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk) + yield cg_chunk async def _astream( self, @@ -380,9 +380,9 @@ class ChatLiteLLM(BaseChatModel): chunk = _convert_delta_to_message_chunk(delta, default_chunk_class) default_chunk_class = chunk.__class__ cg_chunk = ChatGenerationChunk(message=chunk) - yield 
cg_chunk if run_manager: await run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk) + yield cg_chunk async def _agenerate( self, diff --git a/libs/community/langchain_community/chat_models/litellm_router.py b/libs/community/langchain_community/chat_models/litellm_router.py index 6b098bb480..bd008a18be 100644 --- a/libs/community/langchain_community/chat_models/litellm_router.py +++ b/libs/community/langchain_community/chat_models/litellm_router.py @@ -124,9 +124,9 @@ class ChatLiteLLMRouter(ChatLiteLLM): chunk = _convert_delta_to_message_chunk(delta, default_chunk_class) default_chunk_class = chunk.__class__ cg_chunk = ChatGenerationChunk(message=chunk) - yield cg_chunk if run_manager: run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk, **params) + yield cg_chunk async def _astream( self, @@ -150,11 +150,11 @@ class ChatLiteLLMRouter(ChatLiteLLM): chunk = _convert_delta_to_message_chunk(delta, default_chunk_class) default_chunk_class = chunk.__class__ cg_chunk = ChatGenerationChunk(message=chunk) - yield cg_chunk if run_manager: await run_manager.on_llm_new_token( chunk.content, chunk=cg_chunk, **params ) + yield cg_chunk async def _agenerate( self, diff --git a/libs/community/langchain_community/chat_models/llama_edge.py b/libs/community/langchain_community/chat_models/llama_edge.py index 5cd8d72edc..603546a9f7 100644 --- a/libs/community/langchain_community/chat_models/llama_edge.py +++ b/libs/community/langchain_community/chat_models/llama_edge.py @@ -188,12 +188,12 @@ class LlamaEdgeChatService(BaseChatModel): else None ) default_chunk_class = chunk.__class__ - chunk = ChatGenerationChunk( + cg_chunk = ChatGenerationChunk( message=chunk, generation_info=generation_info ) - yield chunk if run_manager: - run_manager.on_llm_new_token(chunk.text, chunk=chunk) + run_manager.on_llm_new_token(chunk.text, chunk=cg_chunk) + yield cg_chunk def _chat(self, messages: List[BaseMessage], **kwargs: Any) -> requests.Response: if self.service_url is None: diff 
--git a/libs/community/langchain_community/chat_models/ollama.py b/libs/community/langchain_community/chat_models/ollama.py index 03a98182dc..a35adcd966 100644 --- a/libs/community/langchain_community/chat_models/ollama.py +++ b/libs/community/langchain_community/chat_models/ollama.py @@ -318,6 +318,7 @@ class ChatOllama(BaseChatModel, _OllamaCommon): if run_manager: run_manager.on_llm_new_token( chunk.text, + chunk=chunk, verbose=self.verbose, ) yield chunk @@ -337,6 +338,7 @@ class ChatOllama(BaseChatModel, _OllamaCommon): if run_manager: await run_manager.on_llm_new_token( chunk.text, + chunk=chunk, verbose=self.verbose, ) yield chunk @@ -356,6 +358,7 @@ class ChatOllama(BaseChatModel, _OllamaCommon): if run_manager: run_manager.on_llm_new_token( chunk.text, + chunk=chunk, verbose=self.verbose, ) yield chunk diff --git a/libs/community/langchain_community/chat_models/openai.py b/libs/community/langchain_community/chat_models/openai.py index 0279d03ed9..9573b9bd2b 100644 --- a/libs/community/langchain_community/chat_models/openai.py +++ b/libs/community/langchain_community/chat_models/openai.py @@ -411,10 +411,12 @@ class ChatOpenAI(BaseChatModel): dict(finish_reason=finish_reason) if finish_reason is not None else None ) default_chunk_class = chunk.__class__ - chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info) - yield chunk + cg_chunk = ChatGenerationChunk( + message=chunk, generation_info=generation_info + ) if run_manager: - run_manager.on_llm_new_token(chunk.text, chunk=chunk) + run_manager.on_llm_new_token(chunk.text, chunk=cg_chunk) + yield cg_chunk def _generate( self, @@ -501,10 +503,12 @@ class ChatOpenAI(BaseChatModel): dict(finish_reason=finish_reason) if finish_reason is not None else None ) default_chunk_class = chunk.__class__ - chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info) - yield chunk + cg_chunk = ChatGenerationChunk( + message=chunk, generation_info=generation_info + ) if run_manager: - 
await run_manager.on_llm_new_token(token=chunk.text, chunk=chunk) + await run_manager.on_llm_new_token(token=chunk.text, chunk=cg_chunk) + yield cg_chunk async def _agenerate( self, diff --git a/libs/community/langchain_community/chat_models/sparkllm.py b/libs/community/langchain_community/chat_models/sparkllm.py index 3cc504ee6a..02c39d885b 100644 --- a/libs/community/langchain_community/chat_models/sparkllm.py +++ b/libs/community/langchain_community/chat_models/sparkllm.py @@ -237,9 +237,9 @@ class ChatSparkLLM(BaseChatModel): delta = content["data"] chunk = _convert_delta_to_message_chunk(delta, default_chunk_class) cg_chunk = ChatGenerationChunk(message=chunk) - yield cg_chunk if run_manager: run_manager.on_llm_new_token(str(chunk.content), chunk=cg_chunk) + yield cg_chunk def _generate( self, diff --git a/libs/community/langchain_community/chat_models/tongyi.py b/libs/community/langchain_community/chat_models/tongyi.py index 9d76e8b120..1eb881d0fb 100644 --- a/libs/community/langchain_community/chat_models/tongyi.py +++ b/libs/community/langchain_community/chat_models/tongyi.py @@ -342,9 +342,9 @@ class ChatTongyi(BaseChatModel): chunk = ChatGenerationChunk( **self._chat_generation_from_qwen_resp(stream_resp, is_chunk=True) ) - yield chunk if run_manager: run_manager.on_llm_new_token(chunk.text, chunk=chunk) + yield chunk async def _astream( self, @@ -360,9 +360,9 @@ class ChatTongyi(BaseChatModel): chunk = ChatGenerationChunk( **self._chat_generation_from_qwen_resp(stream_resp, is_chunk=True) ) - yield chunk if run_manager: await run_manager.on_llm_new_token(chunk.text, chunk=chunk) + yield chunk def _invocation_params( self, messages: List[BaseMessage], stop: Any, **kwargs: Any diff --git a/libs/community/langchain_community/chat_models/volcengine_maas.py b/libs/community/langchain_community/chat_models/volcengine_maas.py index 8178e4bee1..df348971de 100644 --- a/libs/community/langchain_community/chat_models/volcengine_maas.py +++ 
b/libs/community/langchain_community/chat_models/volcengine_maas.py @@ -117,9 +117,9 @@ class VolcEngineMaasChat(BaseChatModel, VolcEngineMaasBase): if res: msg = convert_dict_to_message(res) chunk = ChatGenerationChunk(message=AIMessageChunk(content=msg.content)) - yield chunk if run_manager: run_manager.on_llm_new_token(cast(str, msg.content), chunk=chunk) + yield chunk def _generate( self, diff --git a/libs/community/langchain_community/chat_models/yuan2.py b/libs/community/langchain_community/chat_models/yuan2.py index d16e629be2..9e7ad33229 100644 --- a/libs/community/langchain_community/chat_models/yuan2.py +++ b/libs/community/langchain_community/chat_models/yuan2.py @@ -273,9 +273,9 @@ class ChatYuan2(BaseChatModel): message=chunk, generation_info=generation_info, ) - yield cg_chunk if run_manager: run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk) + yield cg_chunk def _generate( self, @@ -356,9 +356,9 @@ class ChatYuan2(BaseChatModel): message=chunk, generation_info=generation_info, ) - yield cg_chunk if run_manager: await run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk) + yield cg_chunk async def _agenerate( self, diff --git a/libs/community/langchain_community/chat_models/zhipuai.py b/libs/community/langchain_community/chat_models/zhipuai.py index 35114ce52c..9306e13022 100644 --- a/libs/community/langchain_community/chat_models/zhipuai.py +++ b/libs/community/langchain_community/chat_models/zhipuai.py @@ -328,9 +328,9 @@ class ChatZhipuAI(BaseChatModel): if r.event == "add": delta = r.data chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta)) - yield chunk if run_manager: run_manager.on_llm_new_token(delta, chunk=chunk) + yield chunk elif r.event == "error": raise ValueError(f"Error from ZhipuAI API response: {r.data}")