mirror of https://github.com/hwchase17/langchain
langchain[patch]: Improve stream_log with AgentExecutor and Runnable Agent (#15792)
This PR fixes an issue where AgentExecutor with RunnableAgent did not let users see individual LLM tokens unless streaming=True was set explicitly on the underlying chat model. The majority of this PR is testing code:
1. Create a test chat model that makes it easier to test streaming and supports AIMessages that include function invocation information.
2. Tests for the chat model.
3. Tests for RunnableAgent (previously untested).
4. Tests for the OpenAI agent (previously untested).
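For context, here is a minimal usage sketch of the behavior this change enables (the `agent` and `tools` variables are illustrative placeholders, not part of this commit): token-level chunks from the underlying chat model now show up in the log patches emitted by AgentExecutor.astream_log even though streaming=True was never set on the model.

    # Hypothetical sketch; assumes `agent` is any Runnable-based agent and
    # `tools` is its tool list; neither is defined in this commit.
    from langchain.agents import AgentExecutor

    async def stream_agent_tokens(agent, tools) -> None:
        executor = AgentExecutor(agent=agent, tools=tools)
        async for log_patch in executor.astream_log({"input": "hi"}):
            # After this fix the patches include per-token output from the
            # nested chat model run, not just the final agent output.
            print(log_patch)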
parent 85a4594ed7
commit feb41c5e28
@@ -0,0 +1,185 @@
"""Tests for verifying that testing utility code works as expected."""
from itertools import cycle
from typing import Any, Dict, List, Optional, Union
from uuid import UUID

from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk

from langchain.callbacks.base import AsyncCallbackHandler
from tests.unit_tests.llms.fake_chat_model import GenericFakeChatModel

def test_generic_fake_chat_model_invoke() -> None:
    # Will alternate between responding with hello and goodbye
    infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")])
    model = GenericFakeChatModel(messages=infinite_cycle)
    response = model.invoke("meow")
    assert response == AIMessage(content="hello")
    response = model.invoke("kitty")
    assert response == AIMessage(content="goodbye")
    response = model.invoke("meow")
    assert response == AIMessage(content="hello")

async def test_generic_fake_chat_model_ainvoke() -> None:
    # Will alternate between responding with hello and goodbye
    infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")])
    model = GenericFakeChatModel(messages=infinite_cycle)
    response = await model.ainvoke("meow")
    assert response == AIMessage(content="hello")
    response = await model.ainvoke("kitty")
    assert response == AIMessage(content="goodbye")
    response = await model.ainvoke("meow")
    assert response == AIMessage(content="hello")

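# NOTE: GenericFakeChatModel streams a message by splitting its content on
# spaces and emitting each piece, including the separator, as its own chunk
# ("hello goodbye" -> "hello", " ", "goodbye"), as the assertions below show.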
async def test_generic_fake_chat_model_stream() -> None:
    """Test streaming."""
    infinite_cycle = cycle(
        [
            AIMessage(content="hello goodbye"),
        ]
    )
    model = GenericFakeChatModel(messages=infinite_cycle)
    chunks = [chunk async for chunk in model.astream("meow")]
    assert chunks == [
        AIMessageChunk(content="hello"),
        AIMessageChunk(content=" "),
        AIMessageChunk(content="goodbye"),
    ]

    chunks = [chunk for chunk in model.stream("meow")]
    assert chunks == [
        AIMessageChunk(content="hello"),
        AIMessageChunk(content=" "),
        AIMessageChunk(content="goodbye"),
    ]

    # Test streaming of additional kwargs.
    # Relying on insertion order of the additional kwargs dict
    message = AIMessage(content="", additional_kwargs={"foo": 42, "bar": 24})
    model = GenericFakeChatModel(messages=cycle([message]))
    chunks = [chunk async for chunk in model.astream("meow")]
    assert chunks == [
        AIMessageChunk(content="", additional_kwargs={"foo": 42}),
        AIMessageChunk(content="", additional_kwargs={"bar": 24}),
    ]

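    # Nested additional_kwargs, such as an OpenAI-style function_call, are
    # streamed piecewise as well: the function name arrives in one chunk and
    # the arguments string is split across several chunks, as asserted below.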
    message = AIMessage(
        content="",
        additional_kwargs={
            "function_call": {
                "name": "move_file",
                "arguments": '{\n "source_path": "foo",\n "'
                'destination_path": "bar"\n}',
            }
        },
    )
    model = GenericFakeChatModel(messages=cycle([message]))
    chunks = [chunk async for chunk in model.astream("meow")]

    assert chunks == [
        AIMessageChunk(
            content="", additional_kwargs={"function_call": {"name": "move_file"}}
        ),
        AIMessageChunk(
            content="",
            additional_kwargs={
                "function_call": {"arguments": '{\n "source_path": "foo"'}
            },
        ),
        AIMessageChunk(
            content="", additional_kwargs={"function_call": {"arguments": ","}}
        ),
        AIMessageChunk(
            content="",
            additional_kwargs={
                "function_call": {"arguments": '\n "destination_path": "bar"\n}'}
            },
        ),
    ]

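    # AIMessageChunk implements __add__ by concatenating content and merging
    # additional_kwargs, so summing the streamed chunks should reconstruct
    # the original message, which the loop below verifies.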
    accumulate_chunks = None
    for chunk in chunks:
        if accumulate_chunks is None:
            accumulate_chunks = chunk
        else:
            accumulate_chunks += chunk

    assert accumulate_chunks == AIMessageChunk(
        content="",
        additional_kwargs={
            "function_call": {
                "name": "move_file",
                "arguments": '{\n "source_path": "foo",\n "'
                'destination_path": "bar"\n}',
            }
        },
    )

async def test_generic_fake_chat_model_astream_log() -> None:
    """Test astream_log."""
    infinite_cycle = cycle([AIMessage(content="hello goodbye")])
    model = GenericFakeChatModel(messages=infinite_cycle)
    log_patches = [
        log_patch async for log_patch in model.astream_log("meow", diff=False)
    ]
    final = log_patches[-1]
    assert final.state["streamed_output"] == [
        AIMessageChunk(content="hello"),
        AIMessageChunk(content=" "),
        AIMessageChunk(content="goodbye"),
    ]

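# The test below checks that per-token callbacks fire during streaming: the
# custom handler appends every token passed to on_llm_new_token, and the
# recorded tokens must match the streamed chunks exactly.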
async def test_callback_handlers() -> None:
    """Verify that model is implemented correctly with handlers working."""

    class MyCustomAsyncHandler(AsyncCallbackHandler):
        def __init__(self, store: List[str]) -> None:
            self.store = store

        async def on_chat_model_start(
            self,
            serialized: Dict[str, Any],
            messages: List[List[BaseMessage]],
            *,
            run_id: UUID,
            parent_run_id: Optional[UUID] = None,
            tags: Optional[List[str]] = None,
            metadata: Optional[Dict[str, Any]] = None,
            **kwargs: Any,
        ) -> Any:
            # Do nothing
            # Required to implement since this is an abstract method
            pass

        async def on_llm_new_token(
            self,
            token: str,
            *,
            chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
            run_id: UUID,
            parent_run_id: Optional[UUID] = None,
            tags: Optional[List[str]] = None,
            **kwargs: Any,
        ) -> None:
            self.store.append(token)

    infinite_cycle = cycle(
        [
            AIMessage(content="hello goodbye"),
        ]
    )
    model = GenericFakeChatModel(messages=infinite_cycle)
    tokens: List[str] = []
    # Stream with the handler attached; each token is recorded via
    # on_llm_new_token.
    results = list(model.stream("meow", {"callbacks": [MyCustomAsyncHandler(tokens)]}))
    assert results == [
        AIMessageChunk(content="hello"),
        AIMessageChunk(content=" "),
        AIMessageChunk(content="goodbye"),
    ]
    assert tokens == ["hello", " ", "goodbye"]