community[patch]: callback before yield for _stream/_astream (#17907)

- Description: invoke the on_llm_new_token callback before yielding the
chunk in _stream/_astream for several chat models, so that all chat
models behave consistently.
- Issue: N/A
- Dependencies: N/A
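
For reference, the ordering this change enforces looks roughly like the sketch below. It is a minimal, hypothetical `_stream` body, not code from this diff: `MyChatModel` and `self._client.stream(...)` are placeholders standing in for a concrete `BaseChatModel` subclass and its provider client. The async `_astream` variants follow the same pattern, with the callback awaited before the yield.

```python
from typing import Any, Iterator, List, Optional

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.messages import AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGenerationChunk


class MyChatModel:  # stand-in for a BaseChatModel subclass
    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        # self._client.stream(...) is a hypothetical provider streaming call.
        for token in self._client.stream(messages, stop=stop, **kwargs):
            chunk = ChatGenerationChunk(message=AIMessageChunk(content=token))
            # Fire the callback first, so on_llm_new_token sees the chunk
            # before any downstream consumer receives it...
            if run_manager:
                run_manager.on_llm_new_token(token, chunk=chunk)
            # ...and only then yield it.
            yield chunk
```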
mackong 7 months ago committed by GitHub
parent 15e42f1799
commit 9678797625

@@ -143,9 +143,9 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
for data in stream_resp:
delta = data.completion
chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
yield chunk
if run_manager:
run_manager.on_llm_new_token(delta, chunk=chunk)
yield chunk
async def _astream(
self,
@@ -163,9 +163,9 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
async for data in stream_resp:
delta = data.completion
chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
yield chunk
if run_manager:
await run_manager.on_llm_new_token(delta, chunk=chunk)
yield chunk
def _generate(
self,

@@ -219,9 +219,9 @@ class ChatBaichuan(BaseChatModel):
)
default_chunk_class = chunk.__class__
cg_chunk = ChatGenerationChunk(message=chunk)
yield cg_chunk
if run_manager:
run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)
yield cg_chunk
def _chat(self, messages: List[BaseMessage], **kwargs: Any) -> requests.Response:
parameters = {**self._default_params, **kwargs}

@@ -346,9 +346,9 @@ class QianfanChatEndpoint(BaseChatModel):
),
generation_info=msg.additional_kwargs,
)
yield chunk
if run_manager:
run_manager.on_llm_new_token(chunk.text, chunk=chunk)
yield chunk
async def _astream(
self,
@@ -372,6 +372,6 @@ class QianfanChatEndpoint(BaseChatModel):
),
generation_info=msg.additional_kwargs,
)
yield chunk
if run_manager:
await run_manager.on_llm_new_token(chunk.text, chunk=chunk)
yield chunk

@@ -148,9 +148,9 @@ class ChatCohere(BaseChatModel, BaseCohere):
if data.event_type == "text-generation":
delta = data.text
chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
yield chunk
if run_manager:
run_manager.on_llm_new_token(delta, chunk=chunk)
yield chunk
async def _astream(
self,
@@ -166,9 +166,9 @@ class ChatCohere(BaseChatModel, BaseCohere):
if data.event_type == "text-generation":
delta = data.text
chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
yield chunk
if run_manager:
await run_manager.on_llm_new_token(delta, chunk=chunk)
yield chunk
def _get_generation_info(self, response: Any) -> Dict[str, Any]:
"""Get the generation info from cohere API response."""

@@ -329,9 +329,9 @@ class ChatDeepInfra(BaseChatModel):
chunk = _handle_sse_line(line)
if chunk:
cg_chunk = ChatGenerationChunk(message=chunk, generation_info=None)
yield cg_chunk
if run_manager:
run_manager.on_llm_new_token(str(chunk.content), chunk=cg_chunk)
yield cg_chunk
async def _astream(
self,
@@ -352,11 +352,11 @@ class ChatDeepInfra(BaseChatModel):
chunk = _handle_sse_line(line)
if chunk:
cg_chunk = ChatGenerationChunk(message=chunk, generation_info=None)
yield cg_chunk
if run_manager:
await run_manager.on_llm_new_token(
str(chunk.content), chunk=cg_chunk
)
yield cg_chunk
async def _agenerate(
self,

@@ -206,12 +206,10 @@ class ChatEdenAI(BaseChatModel):
for chunk_response in response.iter_lines():
chunk = json.loads(chunk_response.decode())
token = chunk["text"]
chat_generatio_chunk = ChatGenerationChunk(
message=AIMessageChunk(content=token)
)
yield chat_generatio_chunk
cg_chunk = ChatGenerationChunk(message=AIMessageChunk(content=token))
if run_manager:
run_manager.on_llm_new_token(token, chunk=chat_generatio_chunk)
run_manager.on_llm_new_token(token, chunk=cg_chunk)
yield cg_chunk
async def _astream(
self,
@@ -246,14 +244,14 @@ class ChatEdenAI(BaseChatModel):
async for chunk_response in response.content:
chunk = json.loads(chunk_response.decode())
token = chunk["text"]
chat_generation_chunk = ChatGenerationChunk(
cg_chunk = ChatGenerationChunk(
message=AIMessageChunk(content=token)
)
yield chat_generation_chunk
if run_manager:
await run_manager.on_llm_new_token(
token=chunk["text"], chunk=chat_generation_chunk
token=chunk["text"], chunk=cg_chunk
)
yield cg_chunk
def _generate(
self,

@@ -219,10 +219,12 @@ class ChatFireworks(BaseChatModel):
dict(finish_reason=finish_reason) if finish_reason is not None else None
)
default_chunk_class = chunk.__class__
chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info)
yield chunk
cg_chunk = ChatGenerationChunk(
message=chunk, generation_info=generation_info
)
if run_manager:
run_manager.on_llm_new_token(chunk.text, chunk=chunk)
run_manager.on_llm_new_token(chunk.text, chunk=cg_chunk)
yield cg_chunk
async def _astream(
self,
@@ -250,10 +252,12 @@ class ChatFireworks(BaseChatModel):
dict(finish_reason=finish_reason) if finish_reason is not None else None
)
default_chunk_class = chunk.__class__
chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info)
yield chunk
cg_chunk = ChatGenerationChunk(
message=chunk, generation_info=generation_info
)
if run_manager:
await run_manager.on_llm_new_token(token=chunk.text, chunk=chunk)
await run_manager.on_llm_new_token(token=chunk.text, chunk=cg_chunk)
yield cg_chunk
def conditional_decorator(

@@ -155,9 +155,9 @@ class GigaChat(_BaseGigaChat, BaseChatModel):
if chunk.choices:
content = chunk.choices[0].delta.content
cg_chunk = ChatGenerationChunk(message=AIMessageChunk(content=content))
yield cg_chunk
if run_manager:
run_manager.on_llm_new_token(content, chunk=cg_chunk)
yield cg_chunk
async def _astream(
self,
@@ -172,9 +172,9 @@ class GigaChat(_BaseGigaChat, BaseChatModel):
if chunk.choices:
content = chunk.choices[0].delta.content
cg_chunk = ChatGenerationChunk(message=AIMessageChunk(content=content))
yield cg_chunk
if run_manager:
await run_manager.on_llm_new_token(content, chunk=cg_chunk)
yield cg_chunk
def get_num_tokens(self, text: str) -> int:
"""Count approximate number of tokens"""

@@ -325,13 +325,13 @@ class GPTRouter(BaseChatModel):
chunk.data, default_chunk_class
)
yield chunk
if run_manager:
run_manager.on_llm_new_token(
token=chunk.message.content, chunk=chunk.message
)
yield chunk
async def _astream(
self,
messages: List[BaseMessage],
@@ -358,13 +358,13 @@ class GPTRouter(BaseChatModel):
chunk.data, default_chunk_class
)
yield chunk
if run_manager:
await run_manager.on_llm_new_token(
token=chunk.message.content, chunk=chunk.message
)
yield chunk
def _create_message_dicts(
self, messages: List[BaseMessage], stop: Optional[List[str]]
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:

@@ -276,9 +276,9 @@ class ChatHunyuan(BaseChatModel):
)
default_chunk_class = chunk.__class__
cg_chunk = ChatGenerationChunk(message=chunk)
yield cg_chunk
if run_manager:
run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)
yield cg_chunk
def _chat(self, messages: List[BaseMessage], **kwargs: Any) -> requests.Response:
if self.hunyuan_secret_key is None:

@@ -313,9 +313,9 @@ class JinaChat(BaseChatModel):
chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
default_chunk_class = chunk.__class__
cg_chunk = ChatGenerationChunk(message=chunk)
yield cg_chunk
if run_manager:
run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)
yield cg_chunk
def _generate(
self,
@@ -373,9 +373,9 @@ class JinaChat(BaseChatModel):
chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
default_chunk_class = chunk.__class__
cg_chunk = ChatGenerationChunk(message=chunk)
yield cg_chunk
if run_manager:
await run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)
yield cg_chunk
async def _agenerate(
self,

@@ -217,10 +217,12 @@ class ChatKonko(ChatOpenAI):
dict(finish_reason=finish_reason) if finish_reason is not None else None
)
default_chunk_class = chunk.__class__
chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info)
yield chunk
cg_chunk = ChatGenerationChunk(
message=chunk, generation_info=generation_info
)
if run_manager:
run_manager.on_llm_new_token(chunk.text, chunk=chunk)
run_manager.on_llm_new_token(chunk.text, chunk=cg_chunk)
yield cg_chunk
def _generate(
self,

@@ -356,9 +356,9 @@ class ChatLiteLLM(BaseChatModel):
chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
default_chunk_class = chunk.__class__
cg_chunk = ChatGenerationChunk(message=chunk)
yield cg_chunk
if run_manager:
run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)
yield cg_chunk
async def _astream(
self,
@@ -380,9 +380,9 @@ class ChatLiteLLM(BaseChatModel):
chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
default_chunk_class = chunk.__class__
cg_chunk = ChatGenerationChunk(message=chunk)
yield cg_chunk
if run_manager:
await run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)
yield cg_chunk
async def _agenerate(
self,

@@ -124,9 +124,9 @@ class ChatLiteLLMRouter(ChatLiteLLM):
chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
default_chunk_class = chunk.__class__
cg_chunk = ChatGenerationChunk(message=chunk)
yield cg_chunk
if run_manager:
run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk, **params)
yield cg_chunk
async def _astream(
self,
@@ -150,11 +150,11 @@ class ChatLiteLLMRouter(ChatLiteLLM):
chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
default_chunk_class = chunk.__class__
cg_chunk = ChatGenerationChunk(message=chunk)
yield cg_chunk
if run_manager:
await run_manager.on_llm_new_token(
chunk.content, chunk=cg_chunk, **params
)
yield cg_chunk
async def _agenerate(
self,

@@ -188,12 +188,12 @@ class LlamaEdgeChatService(BaseChatModel):
else None
)
default_chunk_class = chunk.__class__
chunk = ChatGenerationChunk(
cg_chunk = ChatGenerationChunk(
message=chunk, generation_info=generation_info
)
yield chunk
if run_manager:
run_manager.on_llm_new_token(chunk.text, chunk=chunk)
run_manager.on_llm_new_token(chunk.text, chunk=cg_chunk)
yield cg_chunk
def _chat(self, messages: List[BaseMessage], **kwargs: Any) -> requests.Response:
if self.service_url is None:

@@ -318,6 +318,7 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
if run_manager:
run_manager.on_llm_new_token(
chunk.text,
chunk=chunk,
verbose=self.verbose,
)
yield chunk
@@ -337,6 +338,7 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
if run_manager:
await run_manager.on_llm_new_token(
chunk.text,
chunk=chunk,
verbose=self.verbose,
)
yield chunk
@@ -356,6 +358,7 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
if run_manager:
run_manager.on_llm_new_token(
chunk.text,
chunk=chunk,
verbose=self.verbose,
)
yield chunk

@@ -411,10 +411,12 @@ class ChatOpenAI(BaseChatModel):
dict(finish_reason=finish_reason) if finish_reason is not None else None
)
default_chunk_class = chunk.__class__
chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info)
yield chunk
cg_chunk = ChatGenerationChunk(
message=chunk, generation_info=generation_info
)
if run_manager:
run_manager.on_llm_new_token(chunk.text, chunk=chunk)
run_manager.on_llm_new_token(chunk.text, chunk=cg_chunk)
yield cg_chunk
def _generate(
self,
@@ -501,10 +503,12 @@ class ChatOpenAI(BaseChatModel):
dict(finish_reason=finish_reason) if finish_reason is not None else None
)
default_chunk_class = chunk.__class__
chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info)
yield chunk
cg_chunk = ChatGenerationChunk(
message=chunk, generation_info=generation_info
)
if run_manager:
await run_manager.on_llm_new_token(token=chunk.text, chunk=chunk)
await run_manager.on_llm_new_token(token=chunk.text, chunk=cg_chunk)
yield cg_chunk
async def _agenerate(
self,

@@ -237,9 +237,9 @@ class ChatSparkLLM(BaseChatModel):
delta = content["data"]
chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
cg_chunk = ChatGenerationChunk(message=chunk)
yield cg_chunk
if run_manager:
run_manager.on_llm_new_token(str(chunk.content), chunk=cg_chunk)
yield cg_chunk
def _generate(
self,

@@ -342,9 +342,9 @@ class ChatTongyi(BaseChatModel):
chunk = ChatGenerationChunk(
**self._chat_generation_from_qwen_resp(stream_resp, is_chunk=True)
)
yield chunk
if run_manager:
run_manager.on_llm_new_token(chunk.text, chunk=chunk)
yield chunk
async def _astream(
self,
@@ -360,9 +360,9 @@ class ChatTongyi(BaseChatModel):
chunk = ChatGenerationChunk(
**self._chat_generation_from_qwen_resp(stream_resp, is_chunk=True)
)
yield chunk
if run_manager:
await run_manager.on_llm_new_token(chunk.text, chunk=chunk)
yield chunk
def _invocation_params(
self, messages: List[BaseMessage], stop: Any, **kwargs: Any

@@ -117,9 +117,9 @@ class VolcEngineMaasChat(BaseChatModel, VolcEngineMaasBase):
if res:
msg = convert_dict_to_message(res)
chunk = ChatGenerationChunk(message=AIMessageChunk(content=msg.content))
yield chunk
if run_manager:
run_manager.on_llm_new_token(cast(str, msg.content), chunk=chunk)
yield chunk
def _generate(
self,

@@ -273,9 +273,9 @@ class ChatYuan2(BaseChatModel):
message=chunk,
generation_info=generation_info,
)
yield cg_chunk
if run_manager:
run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)
yield cg_chunk
def _generate(
self,
@@ -356,9 +356,9 @@ class ChatYuan2(BaseChatModel):
message=chunk,
generation_info=generation_info,
)
yield cg_chunk
if run_manager:
await run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)
yield cg_chunk
async def _agenerate(
self,

@@ -328,9 +328,9 @@ class ChatZhipuAI(BaseChatModel):
if r.event == "add":
delta = r.data
chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
yield chunk
if run_manager:
run_manager.on_llm_new_token(delta, chunk=chunk)
yield chunk
elif r.event == "error":
raise ValueError(f"Error from ZhipuAI API response: {r.data}")
