community[patch]: callback before yield for _stream/_astream (#17907)

- Description: call the on_llm_new_token callback before yielding the chunk in
_stream/_astream for several chat models, so that all chat models behave
consistently.
- Issue: N/A
- Dependencies: N/A
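
The pattern applied across every file below is summarized by this minimal sketch of a custom chat model. The class name and the stand-in token list are illustrative only, not code from this PR; the point is the ordering inside the loop, where on_llm_new_token fires before the chunk is yielded, so callback handlers see each token no later than the consumer of the stream.

    from typing import Any, Iterator, List, Optional

    from langchain_core.callbacks import CallbackManagerForLLMRun
    from langchain_core.language_models.chat_models import BaseChatModel
    from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
    from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult


    class MyChatModel(BaseChatModel):
        """Hypothetical chat model used only to illustrate the callback ordering."""

        @property
        def _llm_type(self) -> str:
            return "my-chat-model"

        def _generate(
            self,
            messages: List[BaseMessage],
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
        ) -> ChatResult:
            # Non-streaming path: return the stand-in tokens as one message.
            text = "".join(["Hello", ", ", "world"])
            return ChatResult(
                generations=[ChatGeneration(message=AIMessage(content=text))]
            )

        def _stream(
            self,
            messages: List[BaseMessage],
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
        ) -> Iterator[ChatGenerationChunk]:
            for token in ["Hello", ", ", "world"]:  # stand-in for a provider stream
                chunk = ChatGenerationChunk(message=AIMessageChunk(content=token))
                # Fire the callback first, then hand the chunk to the caller
                # (previously several models yielded first and called back after).
                if run_manager:
                    run_manager.on_llm_new_token(token, chunk=chunk)
                yield chunk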
mackong 2024-02-23 08:15:21 +08:00 committed by GitHub
parent 15e42f1799
commit 9678797625
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
22 changed files with 66 additions and 55 deletions
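
For context, a hedged usage sketch (the handler class and the hypothetical MyChatModel from the sketch above are illustrative, not part of this diff) showing why the ordering is observable: with the reordered code, a handler's on_llm_new_token fires for each token before the matching chunk reaches the loop consuming .stream().

    from typing import Any

    from langchain_core.callbacks import BaseCallbackHandler


    class TokenPrinter(BaseCallbackHandler):
        """Illustrative handler: logs each token as soon as the callback fires."""

        def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
            print(f"[callback] {token!r}")


    # Reusing the hypothetical MyChatModel defined in the sketch above.
    model = MyChatModel(callbacks=[TokenPrinter()])
    for message_chunk in model.stream("hi"):
        # Each "[callback]" line prints before the corresponding chunk arrives
        # here, which is the consistent behaviour this patch enforces.
        print(f"[consumer] {message_chunk.content!r}")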

View File

@@ -143,9 +143,9 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
         for data in stream_resp:
             delta = data.completion
             chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
-            yield chunk
             if run_manager:
                 run_manager.on_llm_new_token(delta, chunk=chunk)
+            yield chunk

     async def _astream(
         self,
@@ -163,9 +163,9 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
         async for data in stream_resp:
             delta = data.completion
             chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
-            yield chunk
             if run_manager:
                 await run_manager.on_llm_new_token(delta, chunk=chunk)
+            yield chunk

     def _generate(
         self,

View File

@@ -219,9 +219,9 @@ class ChatBaichuan(BaseChatModel):
             )
             default_chunk_class = chunk.__class__
             cg_chunk = ChatGenerationChunk(message=chunk)
-            yield cg_chunk
             if run_manager:
                 run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)
+            yield cg_chunk

     def _chat(self, messages: List[BaseMessage], **kwargs: Any) -> requests.Response:
         parameters = {**self._default_params, **kwargs}

View File

@@ -346,9 +346,9 @@ class QianfanChatEndpoint(BaseChatModel):
                 ),
                 generation_info=msg.additional_kwargs,
             )
-            yield chunk
             if run_manager:
                 run_manager.on_llm_new_token(chunk.text, chunk=chunk)
+            yield chunk

     async def _astream(
         self,
@@ -372,6 +372,6 @@ class QianfanChatEndpoint(BaseChatModel):
                 ),
                 generation_info=msg.additional_kwargs,
             )
-            yield chunk
             if run_manager:
                 await run_manager.on_llm_new_token(chunk.text, chunk=chunk)
+            yield chunk

View File

@@ -148,9 +148,9 @@ class ChatCohere(BaseChatModel, BaseCohere):
             if data.event_type == "text-generation":
                 delta = data.text
                 chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
-                yield chunk
                 if run_manager:
                     run_manager.on_llm_new_token(delta, chunk=chunk)
+                yield chunk

     async def _astream(
         self,
@@ -166,9 +166,9 @@ class ChatCohere(BaseChatModel, BaseCohere):
             if data.event_type == "text-generation":
                 delta = data.text
                 chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
-                yield chunk
                 if run_manager:
                     await run_manager.on_llm_new_token(delta, chunk=chunk)
+                yield chunk

     def _get_generation_info(self, response: Any) -> Dict[str, Any]:
         """Get the generation info from cohere API response."""

View File

@@ -329,9 +329,9 @@ class ChatDeepInfra(BaseChatModel):
             chunk = _handle_sse_line(line)
             if chunk:
                 cg_chunk = ChatGenerationChunk(message=chunk, generation_info=None)
-                yield cg_chunk
                 if run_manager:
                     run_manager.on_llm_new_token(str(chunk.content), chunk=cg_chunk)
+                yield cg_chunk

     async def _astream(
         self,
@@ -352,11 +352,11 @@ class ChatDeepInfra(BaseChatModel):
             chunk = _handle_sse_line(line)
             if chunk:
                 cg_chunk = ChatGenerationChunk(message=chunk, generation_info=None)
-                yield cg_chunk
                 if run_manager:
                     await run_manager.on_llm_new_token(
                         str(chunk.content), chunk=cg_chunk
                     )
+                yield cg_chunk

     async def _agenerate(
         self,

View File

@@ -206,12 +206,10 @@ class ChatEdenAI(BaseChatModel):
         for chunk_response in response.iter_lines():
             chunk = json.loads(chunk_response.decode())
             token = chunk["text"]
-            chat_generatio_chunk = ChatGenerationChunk(
-                message=AIMessageChunk(content=token)
-            )
-            yield chat_generatio_chunk
+            cg_chunk = ChatGenerationChunk(message=AIMessageChunk(content=token))
             if run_manager:
-                run_manager.on_llm_new_token(token, chunk=chat_generatio_chunk)
+                run_manager.on_llm_new_token(token, chunk=cg_chunk)
+            yield cg_chunk

     async def _astream(
         self,
@@ -246,14 +244,14 @@ class ChatEdenAI(BaseChatModel):
         async for chunk_response in response.content:
             chunk = json.loads(chunk_response.decode())
             token = chunk["text"]
-            chat_generation_chunk = ChatGenerationChunk(
+            cg_chunk = ChatGenerationChunk(
                 message=AIMessageChunk(content=token)
             )
-            yield chat_generation_chunk
             if run_manager:
                 await run_manager.on_llm_new_token(
-                    token=chunk["text"], chunk=chat_generation_chunk
+                    token=chunk["text"], chunk=cg_chunk
                 )
+            yield cg_chunk

     def _generate(
         self,

View File

@@ -219,10 +219,12 @@ class ChatFireworks(BaseChatModel):
                 dict(finish_reason=finish_reason) if finish_reason is not None else None
             )
             default_chunk_class = chunk.__class__
-            chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info)
-            yield chunk
+            cg_chunk = ChatGenerationChunk(
+                message=chunk, generation_info=generation_info
+            )
             if run_manager:
-                run_manager.on_llm_new_token(chunk.text, chunk=chunk)
+                run_manager.on_llm_new_token(chunk.text, chunk=cg_chunk)
+            yield cg_chunk

     async def _astream(
         self,
@@ -250,10 +252,12 @@ class ChatFireworks(BaseChatModel):
                 dict(finish_reason=finish_reason) if finish_reason is not None else None
             )
             default_chunk_class = chunk.__class__
-            chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info)
-            yield chunk
+            cg_chunk = ChatGenerationChunk(
+                message=chunk, generation_info=generation_info
+            )
             if run_manager:
-                await run_manager.on_llm_new_token(token=chunk.text, chunk=chunk)
+                await run_manager.on_llm_new_token(token=chunk.text, chunk=cg_chunk)
+            yield cg_chunk


 def conditional_decorator(

View File

@@ -155,9 +155,9 @@ class GigaChat(_BaseGigaChat, BaseChatModel):
             if chunk.choices:
                 content = chunk.choices[0].delta.content
                 cg_chunk = ChatGenerationChunk(message=AIMessageChunk(content=content))
-                yield cg_chunk
                 if run_manager:
                     run_manager.on_llm_new_token(content, chunk=cg_chunk)
+                yield cg_chunk

     async def _astream(
         self,
@@ -172,9 +172,9 @@ class GigaChat(_BaseGigaChat, BaseChatModel):
             if chunk.choices:
                 content = chunk.choices[0].delta.content
                 cg_chunk = ChatGenerationChunk(message=AIMessageChunk(content=content))
-                yield cg_chunk
                 if run_manager:
                     await run_manager.on_llm_new_token(content, chunk=cg_chunk)
+                yield cg_chunk

     def get_num_tokens(self, text: str) -> int:
         """Count approximate number of tokens"""

View File

@@ -325,13 +325,13 @@ class GPTRouter(BaseChatModel):
                 chunk.data, default_chunk_class
             )
-            yield chunk
             if run_manager:
                 run_manager.on_llm_new_token(
                     token=chunk.message.content, chunk=chunk.message
                 )
+            yield chunk

     async def _astream(
         self,
         messages: List[BaseMessage],
@@ -358,13 +358,13 @@ class GPTRouter(BaseChatModel):
                 chunk.data, default_chunk_class
             )
-            yield chunk
             if run_manager:
                 await run_manager.on_llm_new_token(
                     token=chunk.message.content, chunk=chunk.message
                 )
+            yield chunk

     def _create_message_dicts(
         self, messages: List[BaseMessage], stop: Optional[List[str]]
     ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:

View File

@@ -276,9 +276,9 @@ class ChatHunyuan(BaseChatModel):
             )
             default_chunk_class = chunk.__class__
             cg_chunk = ChatGenerationChunk(message=chunk)
-            yield cg_chunk
             if run_manager:
                 run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)
+            yield cg_chunk

     def _chat(self, messages: List[BaseMessage], **kwargs: Any) -> requests.Response:
         if self.hunyuan_secret_key is None:

View File

@@ -313,9 +313,9 @@ class JinaChat(BaseChatModel):
             chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
             default_chunk_class = chunk.__class__
             cg_chunk = ChatGenerationChunk(message=chunk)
-            yield cg_chunk
             if run_manager:
                 run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)
+            yield cg_chunk

     def _generate(
         self,
@@ -373,9 +373,9 @@ class JinaChat(BaseChatModel):
             chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
             default_chunk_class = chunk.__class__
             cg_chunk = ChatGenerationChunk(message=chunk)
-            yield cg_chunk
             if run_manager:
                 await run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)
+            yield cg_chunk

     async def _agenerate(
         self,

View File

@@ -217,10 +217,12 @@ class ChatKonko(ChatOpenAI):
                 dict(finish_reason=finish_reason) if finish_reason is not None else None
             )
             default_chunk_class = chunk.__class__
-            chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info)
-            yield chunk
+            cg_chunk = ChatGenerationChunk(
+                message=chunk, generation_info=generation_info
+            )
             if run_manager:
-                run_manager.on_llm_new_token(chunk.text, chunk=chunk)
+                run_manager.on_llm_new_token(chunk.text, chunk=cg_chunk)
+            yield cg_chunk

     def _generate(
         self,

View File

@@ -356,9 +356,9 @@ class ChatLiteLLM(BaseChatModel):
             chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
             default_chunk_class = chunk.__class__
             cg_chunk = ChatGenerationChunk(message=chunk)
-            yield cg_chunk
             if run_manager:
                 run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)
+            yield cg_chunk

     async def _astream(
         self,
@@ -380,9 +380,9 @@ class ChatLiteLLM(BaseChatModel):
             chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
             default_chunk_class = chunk.__class__
             cg_chunk = ChatGenerationChunk(message=chunk)
-            yield cg_chunk
             if run_manager:
                 await run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)
+            yield cg_chunk

     async def _agenerate(
         self,

View File

@@ -124,9 +124,9 @@ class ChatLiteLLMRouter(ChatLiteLLM):
             chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
             default_chunk_class = chunk.__class__
             cg_chunk = ChatGenerationChunk(message=chunk)
-            yield cg_chunk
             if run_manager:
                 run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk, **params)
+            yield cg_chunk

     async def _astream(
         self,
@@ -150,11 +150,11 @@ class ChatLiteLLMRouter(ChatLiteLLM):
             chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
             default_chunk_class = chunk.__class__
             cg_chunk = ChatGenerationChunk(message=chunk)
-            yield cg_chunk
             if run_manager:
                 await run_manager.on_llm_new_token(
                     chunk.content, chunk=cg_chunk, **params
                 )
+            yield cg_chunk

     async def _agenerate(
         self,

View File

@@ -188,12 +188,12 @@ class LlamaEdgeChatService(BaseChatModel):
                     else None
                 )
                 default_chunk_class = chunk.__class__
-                chunk = ChatGenerationChunk(
+                cg_chunk = ChatGenerationChunk(
                     message=chunk, generation_info=generation_info
                 )
-                yield chunk
                 if run_manager:
-                    run_manager.on_llm_new_token(chunk.text, chunk=chunk)
+                    run_manager.on_llm_new_token(chunk.text, chunk=cg_chunk)
+                yield cg_chunk

     def _chat(self, messages: List[BaseMessage], **kwargs: Any) -> requests.Response:
         if self.service_url is None:

View File

@@ -318,6 +318,7 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
                 if run_manager:
                     run_manager.on_llm_new_token(
                         chunk.text,
+                        chunk=chunk,
                         verbose=self.verbose,
                     )
                 yield chunk
@@ -337,6 +338,7 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
                 if run_manager:
                     await run_manager.on_llm_new_token(
                         chunk.text,
+                        chunk=chunk,
                         verbose=self.verbose,
                     )
                 yield chunk
@@ -356,6 +358,7 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
                 if run_manager:
                     run_manager.on_llm_new_token(
                         chunk.text,
+                        chunk=chunk,
                         verbose=self.verbose,
                     )
                 yield chunk

View File

@@ -411,10 +411,12 @@ class ChatOpenAI(BaseChatModel):
                 dict(finish_reason=finish_reason) if finish_reason is not None else None
             )
             default_chunk_class = chunk.__class__
-            chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info)
-            yield chunk
+            cg_chunk = ChatGenerationChunk(
+                message=chunk, generation_info=generation_info
+            )
             if run_manager:
-                run_manager.on_llm_new_token(chunk.text, chunk=chunk)
+                run_manager.on_llm_new_token(chunk.text, chunk=cg_chunk)
+            yield cg_chunk

     def _generate(
         self,
@@ -501,10 +503,12 @@ class ChatOpenAI(BaseChatModel):
                 dict(finish_reason=finish_reason) if finish_reason is not None else None
             )
             default_chunk_class = chunk.__class__
-            chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info)
-            yield chunk
+            cg_chunk = ChatGenerationChunk(
+                message=chunk, generation_info=generation_info
+            )
             if run_manager:
-                await run_manager.on_llm_new_token(token=chunk.text, chunk=chunk)
+                await run_manager.on_llm_new_token(token=chunk.text, chunk=cg_chunk)
+            yield cg_chunk

     async def _agenerate(
         self,

View File

@@ -237,9 +237,9 @@ class ChatSparkLLM(BaseChatModel):
             delta = content["data"]
             chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
             cg_chunk = ChatGenerationChunk(message=chunk)
-            yield cg_chunk
             if run_manager:
                 run_manager.on_llm_new_token(str(chunk.content), chunk=cg_chunk)
+            yield cg_chunk

     def _generate(
         self,

View File

@@ -342,9 +342,9 @@ class ChatTongyi(BaseChatModel):
             chunk = ChatGenerationChunk(
                 **self._chat_generation_from_qwen_resp(stream_resp, is_chunk=True)
             )
-            yield chunk
             if run_manager:
                 run_manager.on_llm_new_token(chunk.text, chunk=chunk)
+            yield chunk

     async def _astream(
         self,
@@ -360,9 +360,9 @@ class ChatTongyi(BaseChatModel):
             chunk = ChatGenerationChunk(
                 **self._chat_generation_from_qwen_resp(stream_resp, is_chunk=True)
             )
-            yield chunk
             if run_manager:
                 await run_manager.on_llm_new_token(chunk.text, chunk=chunk)
+            yield chunk

     def _invocation_params(
         self, messages: List[BaseMessage], stop: Any, **kwargs: Any

View File

@@ -117,9 +117,9 @@ class VolcEngineMaasChat(BaseChatModel, VolcEngineMaasBase):
             if res:
                 msg = convert_dict_to_message(res)
                 chunk = ChatGenerationChunk(message=AIMessageChunk(content=msg.content))
-                yield chunk
                 if run_manager:
                     run_manager.on_llm_new_token(cast(str, msg.content), chunk=chunk)
+                yield chunk

     def _generate(
         self,

View File

@@ -273,9 +273,9 @@ class ChatYuan2(BaseChatModel):
                 message=chunk,
                 generation_info=generation_info,
             )
-            yield cg_chunk
             if run_manager:
                 run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)
+            yield cg_chunk

     def _generate(
         self,
@@ -356,9 +356,9 @@ class ChatYuan2(BaseChatModel):
                 message=chunk,
                 generation_info=generation_info,
             )
-            yield cg_chunk
             if run_manager:
                 await run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)
+            yield cg_chunk

     async def _agenerate(
         self,

View File

@@ -328,9 +328,9 @@ class ChatZhipuAI(BaseChatModel):
             if r.event == "add":
                 delta = r.data
                 chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
-                yield chunk
                 if run_manager:
                     run_manager.on_llm_new_token(delta, chunk=chunk)
+                yield chunk
             elif r.event == "error":
                 raise ValueError(f"Error from ZhipuAI API response: {r.data}")