infra: fake model invokes callback prior to yielding token (#18286)

## PR title
core[patch]: Invoke callback prior to yielding

## PR message
Description: Invoke the on_llm_new_token callback prior to yielding the token in
the _stream and _astream methods.
Issue: https://github.com/langchain-ai/langchain/issues/16913
Dependencies: None
Twitter handle: None
pull/18284/head
William De Vena 6 months ago committed by GitHub
parent 31b4e78174
commit 42341bc787
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -225,9 +225,9 @@ class GenericFakeChatModel(BaseChatModel):
for token in content_chunks:
chunk = ChatGenerationChunk(message=AIMessageChunk(content=token))
yield chunk
if run_manager:
run_manager.on_llm_new_token(token, chunk=chunk)
yield chunk
if message.additional_kwargs:
for key, value in message.additional_kwargs.items():
@ -247,12 +247,12 @@ class GenericFakeChatModel(BaseChatModel):
},
)
)
yield chunk
if run_manager:
run_manager.on_llm_new_token(
"",
chunk=chunk, # No token for function call
)
yield chunk
else:
chunk = ChatGenerationChunk(
message=AIMessageChunk(
@ -260,24 +260,24 @@ class GenericFakeChatModel(BaseChatModel):
additional_kwargs={"function_call": {fkey: fvalue}},
)
)
yield chunk
if run_manager:
run_manager.on_llm_new_token(
"",
chunk=chunk, # No token for function call
)
yield chunk
else:
chunk = ChatGenerationChunk(
message=AIMessageChunk(
content="", additional_kwargs={key: value}
)
)
yield chunk
if run_manager:
run_manager.on_llm_new_token(
"",
chunk=chunk, # No token for function call
)
yield chunk
async def _astream(
self,

@ -396,14 +396,6 @@ async def test_event_stream_with_simple_chain() -> None:
"run_id": "",
"tags": ["my_chain", "my_model", "seq:step:2"],
},
{
"data": {"chunk": AIMessageChunk(content="hello")},
"event": "on_chain_stream",
"metadata": {"foo": "bar"},
"name": "my_chain",
"run_id": "",
"tags": ["my_chain"],
},
{
"data": {"chunk": AIMessageChunk(content="hello")},
"event": "on_chat_model_stream",
@ -413,7 +405,7 @@ async def test_event_stream_with_simple_chain() -> None:
"tags": ["my_chain", "my_model", "seq:step:2"],
},
{
"data": {"chunk": AIMessageChunk(content=" ")},
"data": {"chunk": AIMessageChunk(content="hello")},
"event": "on_chain_stream",
"metadata": {"foo": "bar"},
"name": "my_chain",
@ -429,7 +421,7 @@ async def test_event_stream_with_simple_chain() -> None:
"tags": ["my_chain", "my_model", "seq:step:2"],
},
{
"data": {"chunk": AIMessageChunk(content="world!")},
"data": {"chunk": AIMessageChunk(content=" ")},
"event": "on_chain_stream",
"metadata": {"foo": "bar"},
"name": "my_chain",
@ -444,6 +436,14 @@ async def test_event_stream_with_simple_chain() -> None:
"run_id": "",
"tags": ["my_chain", "my_model", "seq:step:2"],
},
{
"data": {"chunk": AIMessageChunk(content="world!")},
"event": "on_chain_stream",
"metadata": {"foo": "bar"},
"name": "my_chain",
"run_id": "",
"tags": ["my_chain"],
},
{
"data": {
"input": {

Loading…
Cancel
Save