core: Assign missing message ids in BaseChatModel (#19863)

- This ensures ids are stable across streamed chunks
- Multiple messages in batch call get separate ids
- Also fix ids being dropped when combining message chunks

Thank you for contributing to LangChain!

- [ ] **PR title**: "package: description"
- Where "package" is whichever of langchain, community, core,
experimental, etc. is being modified. Use "docs: ..." for purely docs
changes, "templates: ..." for template changes, "infra: ..." for CI
changes.
  - Example: "community: add foobar LLM"


- [ ] **PR message**: ***Delete this entire checklist*** and replace
with
    - **Description:** a description of the change
    - **Issue:** the issue # it fixes, if applicable
    - **Dependencies:** any dependencies required for this change
- **Twitter handle:** if your PR gets announced, and you'd like a
mention, we'll gladly shout you out!


- [ ] **Add tests and docs**: If you're adding a new integration, please
include
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in
`docs/docs/integrations` directory.


- [ ] **Lint and test**: Run `make format`, `make lint` and `make test`
from the root of the package(s) you've modified. See contribution
guidelines for more: https://python.langchain.com/docs/contributing/

Additional guidelines:
- Make sure optional dependencies are imported within a function.
- Please do not add dependencies to pyproject.toml files (even optional
ones) unless they are required for unit tests.
- Most PRs should not touch more than one package.
- Changes should be backwards compatible.
- If you are adding something to community, do not re-import it in
langchain.

If no one reviews your PR within a few days, please @-mention one of
baskaryan, efriis, eyurtsev, hwchase17.
pull/18932/head^2
Nuno Campos 6 months ago committed by GitHub
parent e830a4e731
commit 2ae6dcdf01
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -224,6 +224,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
run_manager.on_llm_new_token(
cast(str, chunk.message.content), chunk=chunk
)
if chunk.message.id is None:
chunk.message.id = f"run-{run_manager.run_id}"
chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
yield chunk.message
if generation is None:
@ -294,6 +296,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
await run_manager.on_llm_new_token(
cast(str, chunk.message.content), chunk=chunk
)
if chunk.message.id is None:
chunk.message.id = f"run-{run_manager.run_id}"
chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
yield chunk.message
if generation is None:
@ -607,6 +611,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
chunks: List[ChatGenerationChunk] = []
for chunk in self._stream(messages, stop=stop, **kwargs):
if run_manager:
if chunk.message.id is None:
chunk.message.id = f"run-{run_manager.run_id}"
run_manager.on_llm_new_token(
cast(str, chunk.message.content), chunk=chunk
)
@ -622,7 +628,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
result = self._generate(messages, stop=stop, **kwargs)
# Add response metadata to each generation
for generation in result.generations:
for idx, generation in enumerate(result.generations):
if run_manager and generation.message.id is None:
generation.message.id = f"run-{run_manager.run_id}-{idx}"
generation.message.response_metadata = _gen_info_and_msg_metadata(
generation
)
@ -684,6 +692,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
chunks: List[ChatGenerationChunk] = []
async for chunk in self._astream(messages, stop=stop, **kwargs):
if run_manager:
if chunk.message.id is None:
chunk.message.id = f"run-{run_manager.run_id}"
await run_manager.on_llm_new_token(
cast(str, chunk.message.content), chunk=chunk
)
@ -699,7 +709,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
result = await self._agenerate(messages, stop=stop, **kwargs)
# Add response metadata to each generation
for generation in result.generations:
for idx, generation in enumerate(result.generations):
if run_manager and generation.message.id is None:
generation.message.id = f"run-{run_manager.run_id}-{idx}"
generation.message.response_metadata = _gen_info_and_msg_metadata(
generation
)

@ -223,7 +223,9 @@ class GenericFakeChatModel(BaseChatModel):
content_chunks = cast(List[str], re.split(r"(\s)", content))
for token in content_chunks:
chunk = ChatGenerationChunk(message=AIMessageChunk(content=token))
chunk = ChatGenerationChunk(
message=AIMessageChunk(content=token, id=message.id)
)
if run_manager:
run_manager.on_llm_new_token(token, chunk=chunk)
yield chunk
@ -240,6 +242,7 @@ class GenericFakeChatModel(BaseChatModel):
for fvalue_chunk in fvalue_chunks:
chunk = ChatGenerationChunk(
message=AIMessageChunk(
id=message.id,
content="",
additional_kwargs={
"function_call": {fkey: fvalue_chunk}
@ -255,6 +258,7 @@ class GenericFakeChatModel(BaseChatModel):
else:
chunk = ChatGenerationChunk(
message=AIMessageChunk(
id=message.id,
content="",
additional_kwargs={"function_call": {fkey: fvalue}},
)
@ -268,7 +272,7 @@ class GenericFakeChatModel(BaseChatModel):
else:
chunk = ChatGenerationChunk(
message=AIMessageChunk(
content="", additional_kwargs={key: value}
id=message.id, content="", additional_kwargs={key: value}
)
)
if run_manager:

@ -971,4 +971,10 @@ _JS_SERIALIZABLE_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = {
"tool",
"ToolMessageChunk",
),
("langchain_core", "prompts", "image", "ImagePromptTemplate"): (
"langchain_core",
"prompts",
"image",
"ImagePromptTemplate",
),
}

@ -56,6 +56,7 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
response_metadata=merge_dicts(
self.response_metadata, other.response_metadata
),
id=self.id,
)
return super().__add__(other)

@ -34,6 +34,8 @@ class BaseMessage(Serializable):
name: Optional[str] = None
id: Optional[str] = None
"""An optional unique identifier for the message. This should ideally be
provided by the provider/model which created the message."""
class Config:
extra = Extra.allow

@ -54,6 +54,7 @@ class ChatMessageChunk(ChatMessage, BaseMessageChunk):
response_metadata=merge_dicts(
self.response_metadata, other.response_metadata
),
id=self.id,
)
elif isinstance(other, BaseMessageChunk):
return self.__class__(
@ -65,6 +66,7 @@ class ChatMessageChunk(ChatMessage, BaseMessageChunk):
response_metadata=merge_dicts(
self.response_metadata, other.response_metadata
),
id=self.id,
)
else:
return super().__add__(other)

@ -54,6 +54,7 @@ class FunctionMessageChunk(FunctionMessage, BaseMessageChunk):
response_metadata=merge_dicts(
self.response_metadata, other.response_metadata
),
id=self.id,
)
return super().__add__(other)

@ -54,6 +54,7 @@ class ToolMessageChunk(ToolMessage, BaseMessageChunk):
response_metadata=merge_dicts(
self.response_metadata, other.response_metadata
),
id=self.id,
)
return super().__add__(other)

@ -116,7 +116,9 @@ def node_data_str(node: Node) -> str:
return data if not data.startswith("Runnable") else data[8:]
def node_data_json(node: Node) -> Dict[str, Union[str, Dict[str, Any]]]:
def node_data_json(
node: Node, *, with_schemas: bool = False
) -> Dict[str, Union[str, Dict[str, Any]]]:
from langchain_core.load.serializable import to_json_not_implemented
from langchain_core.runnables.base import Runnable, RunnableSerializable
@ -137,10 +139,17 @@ def node_data_json(node: Node) -> Dict[str, Union[str, Dict[str, Any]]]:
},
}
elif inspect.isclass(node.data) and issubclass(node.data, BaseModel):
return {
"type": "schema",
"data": node.data.schema(),
}
return (
{
"type": "schema",
"data": node.data.schema(),
}
if with_schemas
else {
"type": "schema",
"data": node_data_str(node),
}
)
else:
return {
"type": "unknown",
@ -156,7 +165,7 @@ class Graph:
edges: List[Edge] = field(default_factory=list)
branches: Optional[Dict[str, List[Branch]]] = field(default_factory=dict)
def to_json(self) -> Dict[str, List[Dict[str, Any]]]:
def to_json(self, *, with_schemas: bool = False) -> Dict[str, List[Dict[str, Any]]]:
"""Convert the graph to a JSON-serializable format."""
stable_node_ids = {
node.id: i if is_uuid(node.id) else node.id
@ -165,7 +174,10 @@ class Graph:
return {
"nodes": [
{"id": stable_node_ids[node.id], **node_data_json(node)}
{
"id": stable_node_ids[node.id],
**node_data_json(node, with_schemas=with_schemas),
}
for node in self.nodes.values()
],
"edges": [

@ -8,6 +8,7 @@ from langchain_core.language_models import GenericFakeChatModel, ParrotFakeChatM
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.messages.human import HumanMessage
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
from tests.unit_tests.stubs import AnyStr
def test_generic_fake_chat_model_invoke() -> None:
@ -15,11 +16,11 @@ def test_generic_fake_chat_model_invoke() -> None:
infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")])
model = GenericFakeChatModel(messages=infinite_cycle)
response = model.invoke("meow")
assert response == AIMessage(content="hello")
assert response == AIMessage(content="hello", id=AnyStr())
response = model.invoke("kitty")
assert response == AIMessage(content="goodbye")
assert response == AIMessage(content="goodbye", id=AnyStr())
response = model.invoke("meow")
assert response == AIMessage(content="hello")
assert response == AIMessage(content="hello", id=AnyStr())
async def test_generic_fake_chat_model_ainvoke() -> None:
@ -27,11 +28,11 @@ async def test_generic_fake_chat_model_ainvoke() -> None:
infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")])
model = GenericFakeChatModel(messages=infinite_cycle)
response = await model.ainvoke("meow")
assert response == AIMessage(content="hello")
assert response == AIMessage(content="hello", id=AnyStr())
response = await model.ainvoke("kitty")
assert response == AIMessage(content="goodbye")
assert response == AIMessage(content="goodbye", id=AnyStr())
response = await model.ainvoke("meow")
assert response == AIMessage(content="hello")
assert response == AIMessage(content="hello", id=AnyStr())
async def test_generic_fake_chat_model_stream() -> None:
@ -44,17 +45,19 @@ async def test_generic_fake_chat_model_stream() -> None:
model = GenericFakeChatModel(messages=infinite_cycle)
chunks = [chunk async for chunk in model.astream("meow")]
assert chunks == [
AIMessageChunk(content="hello"),
AIMessageChunk(content=" "),
AIMessageChunk(content="goodbye"),
AIMessageChunk(content="hello", id=AnyStr()),
AIMessageChunk(content=" ", id=AnyStr()),
AIMessageChunk(content="goodbye", id=AnyStr()),
]
assert len({chunk.id for chunk in chunks}) == 1
chunks = [chunk for chunk in model.stream("meow")]
assert chunks == [
AIMessageChunk(content="hello"),
AIMessageChunk(content=" "),
AIMessageChunk(content="goodbye"),
AIMessageChunk(content="hello", id=AnyStr()),
AIMessageChunk(content=" ", id=AnyStr()),
AIMessageChunk(content="goodbye", id=AnyStr()),
]
assert len({chunk.id for chunk in chunks}) == 1
# Test streaming of additional kwargs.
# Relying on insertion order of the additional kwargs dict
@ -62,9 +65,10 @@ async def test_generic_fake_chat_model_stream() -> None:
model = GenericFakeChatModel(messages=cycle([message]))
chunks = [chunk async for chunk in model.astream("meow")]
assert chunks == [
AIMessageChunk(content="", additional_kwargs={"foo": 42}),
AIMessageChunk(content="", additional_kwargs={"bar": 24}),
AIMessageChunk(content="", additional_kwargs={"foo": 42}, id=AnyStr()),
AIMessageChunk(content="", additional_kwargs={"bar": 24}, id=AnyStr()),
]
assert len({chunk.id for chunk in chunks}) == 1
message = AIMessage(
content="",
@ -81,24 +85,31 @@ async def test_generic_fake_chat_model_stream() -> None:
assert chunks == [
AIMessageChunk(
content="", additional_kwargs={"function_call": {"name": "move_file"}}
content="",
additional_kwargs={"function_call": {"name": "move_file"}},
id=AnyStr(),
),
AIMessageChunk(
content="",
additional_kwargs={
"function_call": {"arguments": '{\n "source_path": "foo"'}
"function_call": {"arguments": '{\n "source_path": "foo"'},
},
id=AnyStr(),
),
AIMessageChunk(
content="", additional_kwargs={"function_call": {"arguments": ","}}
content="",
additional_kwargs={"function_call": {"arguments": ","}},
id=AnyStr(),
),
AIMessageChunk(
content="",
additional_kwargs={
"function_call": {"arguments": '\n "destination_path": "bar"\n}'}
"function_call": {"arguments": '\n "destination_path": "bar"\n}'},
},
id=AnyStr(),
),
]
assert len({chunk.id for chunk in chunks}) == 1
accumulate_chunks = None
for chunk in chunks:
@ -116,6 +127,7 @@ async def test_generic_fake_chat_model_stream() -> None:
'destination_path": "bar"\n}',
}
},
id=chunks[0].id,
)
@ -128,10 +140,11 @@ async def test_generic_fake_chat_model_astream_log() -> None:
]
final = log_patches[-1]
assert final.state["streamed_output"] == [
AIMessageChunk(content="hello"),
AIMessageChunk(content=" "),
AIMessageChunk(content="goodbye"),
AIMessageChunk(content="hello", id=AnyStr()),
AIMessageChunk(content=" ", id=AnyStr()),
AIMessageChunk(content="goodbye", id=AnyStr()),
]
assert len({chunk.id for chunk in final.state["streamed_output"]}) == 1
async def test_callback_handlers() -> None:
@ -178,16 +191,19 @@ async def test_callback_handlers() -> None:
# New model
results = list(model.stream("meow", {"callbacks": [MyCustomAsyncHandler(tokens)]}))
assert results == [
AIMessageChunk(content="hello"),
AIMessageChunk(content=" "),
AIMessageChunk(content="goodbye"),
AIMessageChunk(content="hello", id=AnyStr()),
AIMessageChunk(content=" ", id=AnyStr()),
AIMessageChunk(content="goodbye", id=AnyStr()),
]
assert tokens == ["hello", " ", "goodbye"]
assert len({chunk.id for chunk in results}) == 1
def test_chat_model_inputs() -> None:
fake = ParrotFakeChatModel()
assert fake.invoke("hello") == HumanMessage(content="hello")
assert fake.invoke([("ai", "blah")]) == AIMessage(content="blah")
assert fake.invoke([AIMessage(content="blah")]) == AIMessage(content="blah")
assert fake.invoke("hello") == HumanMessage(content="hello", id=AnyStr())
assert fake.invoke([("ai", "blah")]) == AIMessage(content="blah", id=AnyStr())
assert fake.invoke([AIMessage(content="blah")]) == AIMessage(
content="blah", id=AnyStr()
)

@ -21,6 +21,7 @@ from tests.unit_tests.fake.callbacks import (
FakeAsyncCallbackHandler,
FakeCallbackHandler,
)
from tests.unit_tests.stubs import AnyStr
@pytest.fixture
@ -140,10 +141,10 @@ async def test_astream_fallback_to_ainvoke() -> None:
model = ModelWithGenerate()
chunks = [chunk for chunk in model.stream("anything")]
assert chunks == [AIMessage(content="hello")]
assert chunks == [AIMessage(content="hello", id=AnyStr())]
chunks = [chunk async for chunk in model.astream("anything")]
assert chunks == [AIMessage(content="hello")]
assert chunks == [AIMessage(content="hello", id=AnyStr())]
async def test_astream_implementation_fallback_to_stream() -> None:
@ -178,15 +179,17 @@ async def test_astream_implementation_fallback_to_stream() -> None:
model = ModelWithSyncStream()
chunks = [chunk for chunk in model.stream("anything")]
assert chunks == [
AIMessageChunk(content="a"),
AIMessageChunk(content="b"),
AIMessageChunk(content="a", id=AnyStr()),
AIMessageChunk(content="b", id=AnyStr()),
]
assert len({chunk.id for chunk in chunks}) == 1
assert type(model)._astream == BaseChatModel._astream
astream_chunks = [chunk async for chunk in model.astream("anything")]
assert astream_chunks == [
AIMessageChunk(content="a"),
AIMessageChunk(content="b"),
AIMessageChunk(content="a", id=AnyStr()),
AIMessageChunk(content="b", id=AnyStr()),
]
assert len({chunk.id for chunk in astream_chunks}) == 1
async def test_astream_implementation_uses_astream() -> None:
@ -221,6 +224,7 @@ async def test_astream_implementation_uses_astream() -> None:
model = ModelWithAsyncStream()
chunks = [chunk async for chunk in model.astream("anything")]
assert chunks == [
AIMessageChunk(content="a"),
AIMessageChunk(content="b"),
AIMessageChunk(content="a", id=AnyStr()),
AIMessageChunk(content="b", id=AnyStr()),
]
assert len({chunk.id for chunk in chunks}) == 1

File diff suppressed because one or more lines are too long

@ -33,6 +33,55 @@ def test_graph_sequence(snapshot: SnapshotAssertion) -> None:
sequence = prompt | fake_llm | list_parser
graph = sequence.get_graph()
assert graph.to_json() == {
"nodes": [
{
"id": 0,
"type": "schema",
"data": "PromptInput",
},
{
"id": 1,
"type": "runnable",
"data": {
"id": ["langchain", "prompts", "prompt", "PromptTemplate"],
"name": "PromptTemplate",
},
},
{
"id": 2,
"type": "runnable",
"data": {
"id": ["langchain_core", "language_models", "fake", "FakeListLLM"],
"name": "FakeListLLM",
},
},
{
"id": 3,
"type": "runnable",
"data": {
"id": [
"langchain",
"output_parsers",
"list",
"CommaSeparatedListOutputParser",
],
"name": "CommaSeparatedListOutputParser",
},
},
{
"id": 4,
"type": "schema",
"data": "CommaSeparatedListOutputParserOutput",
},
],
"edges": [
{"source": 0, "target": 1},
{"source": 1, "target": 2},
{"source": 3, "target": 4},
{"source": 2, "target": 3},
],
}
assert graph.to_json(with_schemas=True) == {
"nodes": [
{
"id": 0,
@ -76,9 +125,9 @@ def test_graph_sequence(snapshot: SnapshotAssertion) -> None:
"id": 4,
"type": "schema",
"data": {
"items": {"type": "string"},
"title": "CommaSeparatedListOutputParserOutput",
"type": "array",
"items": {"type": "string"},
},
},
],
@ -115,7 +164,7 @@ def test_graph_sequence_map(snapshot: SnapshotAssertion) -> None:
}
)
graph = sequence.get_graph()
assert graph.to_json() == {
assert graph.to_json(with_schemas=True) == {
"nodes": [
{
"id": 0,
@ -484,5 +533,97 @@ def test_graph_sequence_map(snapshot: SnapshotAssertion) -> None:
{"source": 2, "target": 3},
],
}
assert graph.to_json() == {
"nodes": [
{
"id": 0,
"type": "schema",
"data": "PromptInput",
},
{
"id": 1,
"type": "runnable",
"data": {
"id": ["langchain", "prompts", "prompt", "PromptTemplate"],
"name": "PromptTemplate",
},
},
{
"id": 2,
"type": "runnable",
"data": {
"id": ["langchain_core", "language_models", "fake", "FakeListLLM"],
"name": "FakeListLLM",
},
},
{
"id": 3,
"type": "schema",
"data": "Parallel<as_list,as_str>Input",
},
{
"id": 4,
"type": "schema",
"data": "Parallel<as_list,as_str>Output",
},
{
"id": 5,
"type": "runnable",
"data": {
"id": [
"langchain",
"output_parsers",
"list",
"CommaSeparatedListOutputParser",
],
"name": "CommaSeparatedListOutputParser",
},
},
{
"id": 6,
"type": "schema",
"data": "conditional_str_parser_input",
},
{
"id": 7,
"type": "schema",
"data": "conditional_str_parser_output",
},
{
"id": 8,
"type": "runnable",
"data": {
"id": ["langchain", "schema", "output_parser", "StrOutputParser"],
"name": "StrOutputParser",
},
},
{
"id": 9,
"type": "runnable",
"data": {
"id": [
"langchain_core",
"output_parsers",
"xml",
"XMLOutputParser",
],
"name": "XMLOutputParser",
},
},
],
"edges": [
{"source": 0, "target": 1},
{"source": 1, "target": 2},
{"source": 3, "target": 5},
{"source": 5, "target": 4},
{"source": 6, "target": 8},
{"source": 8, "target": 7},
{"source": 6, "target": 9},
{"source": 9, "target": 7},
{"source": 3, "target": 6},
{"source": 7, "target": 4},
{"source": 2, "target": 3},
],
}
assert graph.draw_ascii() == snapshot(name="ascii")
assert graph.draw_mermaid() == snapshot(name="mermaid")

@ -41,6 +41,7 @@ from langchain_core.messages import (
HumanMessage,
SystemMessage,
)
from langchain_core.messages.base import BaseMessage
from langchain_core.output_parsers import (
BaseOutputParser,
CommaSeparatedListOutputParser,
@ -86,6 +87,7 @@ from langchain_core.tracers import (
RunLogPatch,
)
from langchain_core.tracers.context import collect_runs
from tests.unit_tests.stubs import AnyStr
class FakeTracer(BaseTracer):
@ -106,6 +108,12 @@ class FakeTracer(BaseTracer):
self.uuids_map[uuid] = next(self.uuids_generator)
return self.uuids_map[uuid]
def _replace_message_id(self, maybe_message: Any) -> Any:
if isinstance(maybe_message, BaseMessage):
maybe_message.id = AnyStr()
return maybe_message
def _copy_run(self, run: Run) -> Run:
if run.dotted_order:
levels = run.dotted_order.split(".")
@ -129,6 +137,16 @@ class FakeTracer(BaseTracer):
"child_execution_order": None,
"trace_id": self._replace_uuid(run.trace_id) if run.trace_id else None,
"dotted_order": new_dotted_order,
"inputs": {
k: self._replace_message_id(v) for k, v in run.inputs.items()
}
if isinstance(run.inputs, dict)
else run.inputs,
"outputs": {
k: self._replace_message_id(v) for k, v in run.outputs.items()
}
if isinstance(run.outputs, dict)
else run.outputs,
}
)
@ -1922,7 +1940,7 @@ def test_prompt_with_chat_model(
tracer = FakeTracer()
assert chain.invoke(
{"question": "What is your name?"}, dict(callbacks=[tracer])
) == AIMessage(content="foo")
) == AIMessage(content="foo", id=AnyStr())
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
assert chat_spy.call_args.args[1] == ChatPromptValue(
messages=[
@ -1947,8 +1965,8 @@ def test_prompt_with_chat_model(
],
dict(callbacks=[tracer]),
) == [
AIMessage(content="foo"),
AIMessage(content="foo"),
AIMessage(content="foo", id=AnyStr()),
AIMessage(content="foo", id=AnyStr()),
]
assert prompt_spy.call_args.args[1] == [
{"question": "What is your name?"},
@ -1988,9 +2006,9 @@ def test_prompt_with_chat_model(
assert [
*chain.stream({"question": "What is your name?"}, dict(callbacks=[tracer]))
] == [
AIMessageChunk(content="f"),
AIMessageChunk(content="o"),
AIMessageChunk(content="o"),
AIMessageChunk(content="f", id=AnyStr()),
AIMessageChunk(content="o", id=AnyStr()),
AIMessageChunk(content="o", id=AnyStr()),
]
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
assert chat_spy.call_args.args[1] == ChatPromptValue(
@ -2026,7 +2044,7 @@ async def test_prompt_with_chat_model_async(
tracer = FakeTracer()
assert await chain.ainvoke(
{"question": "What is your name?"}, dict(callbacks=[tracer])
) == AIMessage(content="foo")
) == AIMessage(content="foo", id=AnyStr())
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
assert chat_spy.call_args.args[1] == ChatPromptValue(
messages=[
@ -2051,8 +2069,8 @@ async def test_prompt_with_chat_model_async(
],
dict(callbacks=[tracer]),
) == [
AIMessage(content="foo"),
AIMessage(content="foo"),
AIMessage(content="foo", id=AnyStr()),
AIMessage(content="foo", id=AnyStr()),
]
assert prompt_spy.call_args.args[1] == [
{"question": "What is your name?"},
@ -2095,9 +2113,9 @@ async def test_prompt_with_chat_model_async(
{"question": "What is your name?"}, dict(callbacks=[tracer])
)
] == [
AIMessageChunk(content="f"),
AIMessageChunk(content="o"),
AIMessageChunk(content="o"),
AIMessageChunk(content="f", id=AnyStr()),
AIMessageChunk(content="o", id=AnyStr()),
AIMessageChunk(content="o", id=AnyStr()),
]
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
assert chat_spy.call_args.args[1] == ChatPromptValue(
@ -2762,7 +2780,7 @@ def test_prompt_with_chat_model_and_parser(
HumanMessage(content="What is your name?"),
]
)
assert parser_spy.call_args.args[1] == AIMessage(content="foo, bar")
assert parser_spy.call_args.args[1] == AIMessage(content="foo, bar", id=AnyStr())
assert tracer.runs == snapshot
@ -2895,7 +2913,7 @@ What is your name?"""
),
]
)
assert parser_spy.call_args.args[1] == AIMessage(content="foo, bar")
assert parser_spy.call_args.args[1] == AIMessage(content="foo, bar", id=AnyStr())
assert len([r for r in tracer.runs if r.parent_run_id is None]) == 1
parent_run = next(r for r in tracer.runs if r.parent_run_id is None)
assert len(parent_run.child_runs) == 4
@ -2941,7 +2959,7 @@ def test_seq_prompt_dict(mocker: MockerFixture, snapshot: SnapshotAssertion) ->
assert chain.invoke(
{"question": "What is your name?"}, dict(callbacks=[tracer])
) == {
"chat": AIMessage(content="i'm a chatbot"),
"chat": AIMessage(content="i'm a chatbot", id=AnyStr()),
"llm": "i'm a textbot",
}
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
@ -3151,7 +3169,7 @@ def test_seq_prompt_map(mocker: MockerFixture, snapshot: SnapshotAssertion) -> N
assert chain.invoke(
{"question": "What is your name?"}, dict(callbacks=[tracer])
) == {
"chat": AIMessage(content="i'm a chatbot"),
"chat": AIMessage(content="i'm a chatbot", id=AnyStr()),
"llm": "i'm a textbot",
"passthrough": ChatPromptValue(
messages=[
@ -3360,12 +3378,13 @@ async def test_map_astream() -> None:
assert streamed_chunks[0] in [
{"passthrough": prompt.invoke({"question": "What is your name?"})},
{"llm": "i"},
{"chat": AIMessageChunk(content="i")},
{"chat": AIMessageChunk(content="i", id=AnyStr())},
]
assert len(streamed_chunks) == len(chat_res) + len(llm_res) + 1
assert all(len(c.keys()) == 1 for c in streamed_chunks)
assert final_value is not None
assert final_value.get("chat").content == "i'm a chatbot"
final_value["chat"].id = AnyStr()
assert final_value.get("llm") == "i'm a textbot"
assert final_value.get("passthrough") == prompt.invoke(
{"question": "What is your name?"}

@ -29,6 +29,7 @@ from langchain_core.runnables import (
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.runnables.schema import StreamEvent
from langchain_core.tools import tool
from tests.unit_tests.stubs import AnyStr
def _with_nulled_run_id(events: Sequence[StreamEvent]) -> List[StreamEvent]:
@ -340,7 +341,7 @@ async def test_astream_events_from_model() -> None:
"tags": ["my_model"],
},
{
"data": {"chunk": AIMessageChunk(content="hello")},
"data": {"chunk": AIMessageChunk(content="hello", id=AnyStr())},
"event": "on_chat_model_stream",
"metadata": {"a": "b"},
"name": "my_model",
@ -348,7 +349,7 @@ async def test_astream_events_from_model() -> None:
"tags": ["my_model"],
},
{
"data": {"chunk": AIMessageChunk(content=" ")},
"data": {"chunk": AIMessageChunk(content=" ", id=AnyStr())},
"event": "on_chat_model_stream",
"metadata": {"a": "b"},
"name": "my_model",
@ -356,7 +357,7 @@ async def test_astream_events_from_model() -> None:
"tags": ["my_model"],
},
{
"data": {"chunk": AIMessageChunk(content="world!")},
"data": {"chunk": AIMessageChunk(content="world!", id=AnyStr())},
"event": "on_chat_model_stream",
"metadata": {"a": "b"},
"name": "my_model",
@ -364,7 +365,7 @@ async def test_astream_events_from_model() -> None:
"tags": ["my_model"],
},
{
"data": {"output": AIMessageChunk(content="hello world!")},
"data": {"output": AIMessageChunk(content="hello world!", id=AnyStr())},
"event": "on_chat_model_end",
"metadata": {"a": "b"},
"name": "my_model",
@ -399,7 +400,7 @@ async def test_astream_events_from_model() -> None:
"tags": ["my_model"],
},
{
"data": {"chunk": AIMessageChunk(content="hello")},
"data": {"chunk": AIMessageChunk(content="hello", id=AnyStr())},
"event": "on_chat_model_stream",
"metadata": {"a": "b"},
"name": "my_model",
@ -407,7 +408,7 @@ async def test_astream_events_from_model() -> None:
"tags": ["my_model"],
},
{
"data": {"chunk": AIMessageChunk(content=" ")},
"data": {"chunk": AIMessageChunk(content=" ", id=AnyStr())},
"event": "on_chat_model_stream",
"metadata": {"a": "b"},
"name": "my_model",
@ -415,7 +416,7 @@ async def test_astream_events_from_model() -> None:
"tags": ["my_model"],
},
{
"data": {"chunk": AIMessageChunk(content="world!")},
"data": {"chunk": AIMessageChunk(content="world!", id=AnyStr())},
"event": "on_chat_model_stream",
"metadata": {"a": "b"},
"name": "my_model",
@ -430,7 +431,9 @@ async def test_astream_events_from_model() -> None:
[
{
"generation_info": None,
"message": AIMessage(content="hello world!"),
"message": AIMessage(
content="hello world!", id=AnyStr()
),
"text": "hello world!",
"type": "ChatGeneration",
}
@ -447,7 +450,7 @@ async def test_astream_events_from_model() -> None:
"tags": ["my_model"],
},
{
"data": {"chunk": AIMessage(content="hello world!")},
"data": {"chunk": AIMessage(content="hello world!", id=AnyStr())},
"event": "on_chain_stream",
"metadata": {},
"name": "i_dont_stream",
@ -455,7 +458,7 @@ async def test_astream_events_from_model() -> None:
"tags": [],
},
{
"data": {"output": AIMessage(content="hello world!")},
"data": {"output": AIMessage(content="hello world!", id=AnyStr())},
"event": "on_chain_end",
"metadata": {},
"name": "i_dont_stream",
@ -490,7 +493,7 @@ async def test_astream_events_from_model() -> None:
"tags": ["my_model"],
},
{
"data": {"chunk": AIMessageChunk(content="hello")},
"data": {"chunk": AIMessageChunk(content="hello", id=AnyStr())},
"event": "on_chat_model_stream",
"metadata": {"a": "b"},
"name": "my_model",
@ -498,7 +501,7 @@ async def test_astream_events_from_model() -> None:
"tags": ["my_model"],
},
{
"data": {"chunk": AIMessageChunk(content=" ")},
"data": {"chunk": AIMessageChunk(content=" ", id=AnyStr())},
"event": "on_chat_model_stream",
"metadata": {"a": "b"},
"name": "my_model",
@ -506,7 +509,7 @@ async def test_astream_events_from_model() -> None:
"tags": ["my_model"],
},
{
"data": {"chunk": AIMessageChunk(content="world!")},
"data": {"chunk": AIMessageChunk(content="world!", id=AnyStr())},
"event": "on_chat_model_stream",
"metadata": {"a": "b"},
"name": "my_model",
@ -521,7 +524,9 @@ async def test_astream_events_from_model() -> None:
[
{
"generation_info": None,
"message": AIMessage(content="hello world!"),
"message": AIMessage(
content="hello world!", id=AnyStr()
),
"text": "hello world!",
"type": "ChatGeneration",
}
@ -538,7 +543,7 @@ async def test_astream_events_from_model() -> None:
"tags": ["my_model"],
},
{
"data": {"chunk": AIMessage(content="hello world!")},
"data": {"chunk": AIMessage(content="hello world!", id=AnyStr())},
"event": "on_chain_stream",
"metadata": {},
"name": "ai_dont_stream",
@ -546,7 +551,7 @@ async def test_astream_events_from_model() -> None:
"tags": [],
},
{
"data": {"output": AIMessage(content="hello world!")},
"data": {"output": AIMessage(content="hello world!", id=AnyStr())},
"event": "on_chain_end",
"metadata": {},
"name": "ai_dont_stream",
@ -563,7 +568,10 @@ async def test_event_stream_with_simple_chain() -> None:
).with_config({"run_name": "my_template", "tags": ["my_template"]})
infinite_cycle = cycle(
[AIMessage(content="hello world!"), AIMessage(content="goodbye world!")]
[
AIMessage(content="hello world!", id="ai1"),
AIMessage(content="goodbye world!", id="ai2"),
]
)
# When streaming GenericFakeChatModel breaks AIMessage into chunks based on spaces
model = (
@ -640,7 +648,7 @@ async def test_event_stream_with_simple_chain() -> None:
"tags": ["my_chain", "my_model", "seq:step:2"],
},
{
"data": {"chunk": AIMessageChunk(content="hello")},
"data": {"chunk": AIMessageChunk(content="hello", id="ai1")},
"event": "on_chat_model_stream",
"metadata": {"a": "b", "foo": "bar"},
"name": "my_model",
@ -648,7 +656,7 @@ async def test_event_stream_with_simple_chain() -> None:
"tags": ["my_chain", "my_model", "seq:step:2"],
},
{
"data": {"chunk": AIMessageChunk(content="hello")},
"data": {"chunk": AIMessageChunk(content="hello", id="ai1")},
"event": "on_chain_stream",
"metadata": {"foo": "bar"},
"name": "my_chain",
@ -656,7 +664,7 @@ async def test_event_stream_with_simple_chain() -> None:
"tags": ["my_chain"],
},
{
"data": {"chunk": AIMessageChunk(content=" ")},
"data": {"chunk": AIMessageChunk(content=" ", id="ai1")},
"event": "on_chat_model_stream",
"metadata": {"a": "b", "foo": "bar"},
"name": "my_model",
@ -664,7 +672,7 @@ async def test_event_stream_with_simple_chain() -> None:
"tags": ["my_chain", "my_model", "seq:step:2"],
},
{
"data": {"chunk": AIMessageChunk(content=" ")},
"data": {"chunk": AIMessageChunk(content=" ", id="ai1")},
"event": "on_chain_stream",
"metadata": {"foo": "bar"},
"name": "my_chain",
@ -672,7 +680,7 @@ async def test_event_stream_with_simple_chain() -> None:
"tags": ["my_chain"],
},
{
"data": {"chunk": AIMessageChunk(content="world!")},
"data": {"chunk": AIMessageChunk(content="world!", id="ai1")},
"event": "on_chat_model_stream",
"metadata": {"a": "b", "foo": "bar"},
"name": "my_model",
@ -680,7 +688,7 @@ async def test_event_stream_with_simple_chain() -> None:
"tags": ["my_chain", "my_model", "seq:step:2"],
},
{
"data": {"chunk": AIMessageChunk(content="world!")},
"data": {"chunk": AIMessageChunk(content="world!", id="ai1")},
"event": "on_chain_stream",
"metadata": {"foo": "bar"},
"name": "my_chain",
@ -702,7 +710,9 @@ async def test_event_stream_with_simple_chain() -> None:
[
{
"generation_info": None,
"message": AIMessageChunk(content="hello world!"),
"message": AIMessageChunk(
content="hello world!", id="ai1"
),
"text": "hello world!",
"type": "ChatGenerationChunk",
}
@ -719,7 +729,7 @@ async def test_event_stream_with_simple_chain() -> None:
"tags": ["my_chain", "my_model", "seq:step:2"],
},
{
"data": {"output": AIMessageChunk(content="hello world!")},
"data": {"output": AIMessageChunk(content="hello world!", id="ai1")},
"event": "on_chain_end",
"metadata": {"foo": "bar"},
"name": "my_chain",
@ -1332,8 +1342,8 @@ async def test_runnable_each() -> None:
async def test_events_astream_config() -> None:
"""Test that astream events support accepting config"""
infinite_cycle = cycle([AIMessage(content="hello world!")])
good_world_on_repeat = cycle([AIMessage(content="Goodbye world")])
infinite_cycle = cycle([AIMessage(content="hello world!", id="ai1")])
good_world_on_repeat = cycle([AIMessage(content="Goodbye world", id="ai2")])
model = GenericFakeChatModel(messages=infinite_cycle).configurable_fields(
messages=ConfigurableField(
id="messages",
@ -1343,7 +1353,7 @@ async def test_events_astream_config() -> None:
)
model_02 = model.with_config({"configurable": {"messages": good_world_on_repeat}})
assert model_02.invoke("hello") == AIMessage(content="Goodbye world")
assert model_02.invoke("hello") == AIMessage(content="Goodbye world", id="ai2")
events = await _collect_events(model_02.astream_events("hello", version="v1"))
assert events == [
@ -1356,7 +1366,7 @@ async def test_events_astream_config() -> None:
"tags": [],
},
{
"data": {"chunk": AIMessageChunk(content="Goodbye")},
"data": {"chunk": AIMessageChunk(content="Goodbye", id="ai2")},
"event": "on_chat_model_stream",
"metadata": {},
"name": "RunnableConfigurableFields",
@ -1364,7 +1374,7 @@ async def test_events_astream_config() -> None:
"tags": [],
},
{
"data": {"chunk": AIMessageChunk(content=" ")},
"data": {"chunk": AIMessageChunk(content=" ", id="ai2")},
"event": "on_chat_model_stream",
"metadata": {},
"name": "RunnableConfigurableFields",
@ -1372,7 +1382,7 @@ async def test_events_astream_config() -> None:
"tags": [],
},
{
"data": {"chunk": AIMessageChunk(content="world")},
"data": {"chunk": AIMessageChunk(content="world", id="ai2")},
"event": "on_chat_model_stream",
"metadata": {},
"name": "RunnableConfigurableFields",
@ -1380,7 +1390,7 @@ async def test_events_astream_config() -> None:
"tags": [],
},
{
"data": {"output": AIMessageChunk(content="Goodbye world")},
"data": {"output": AIMessageChunk(content="Goodbye world", id="ai2")},
"event": "on_chat_model_end",
"metadata": {},
"name": "RunnableConfigurableFields",
@ -1418,7 +1428,9 @@ async def test_runnable_with_message_history() -> None:
store[session_id] = []
return InMemoryHistory(messages=store[session_id])
infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="world")])
infinite_cycle = cycle(
[AIMessage(content="hello", id="ai3"), AIMessage(content="world", id="ai4")]
)
prompt = ChatPromptTemplate.from_messages(
[
@ -1441,7 +1453,10 @@ async def test_runnable_with_message_history() -> None:
).ainvoke({"question": "hello"})
assert store == {
"session-123": [HumanMessage(content="hello"), AIMessage(content="hello")]
"session-123": [
HumanMessage(content="hello"),
AIMessage(content="hello", id="ai3"),
]
}
with_message_history.with_config(
@ -1450,8 +1465,8 @@ async def test_runnable_with_message_history() -> None:
assert store == {
"session-123": [
HumanMessage(content="hello"),
AIMessage(content="hello"),
AIMessage(content="hello", id="ai3"),
HumanMessage(content="meow"),
AIMessage(content="world"),
AIMessage(content="world", id="ai4"),
]
}

@ -0,0 +1,6 @@
from typing import Any
class AnyStr(str):
def __eq__(self, other: Any) -> bool:
return isinstance(other, str)

@ -23,15 +23,16 @@ from langchain_core.messages import (
def test_message_chunks() -> None:
assert AIMessageChunk(content="I am") + AIMessageChunk(
assert AIMessageChunk(content="I am", id="ai3") + AIMessageChunk(
content=" indeed."
) == AIMessageChunk(
content="I am indeed."
content="I am indeed.", id="ai3"
), "MessageChunk + MessageChunk should be a MessageChunk"
assert (
AIMessageChunk(content="I am") + HumanMessageChunk(content=" indeed.")
== AIMessageChunk(content="I am indeed.")
AIMessageChunk(content="I am", id="ai2")
+ HumanMessageChunk(content=" indeed.", id="human1")
== AIMessageChunk(content="I am indeed.", id="ai2")
), "MessageChunk + MessageChunk should be a MessageChunk of same class as the left side" # noqa: E501
assert (
@ -69,10 +70,10 @@ def test_message_chunks() -> None:
def test_chat_message_chunks() -> None:
assert ChatMessageChunk(role="User", content="I am") + ChatMessageChunk(
assert ChatMessageChunk(role="User", content="I am", id="ai4") + ChatMessageChunk(
role="User", content=" indeed."
) == ChatMessageChunk(
role="User", content="I am indeed."
id="ai4", role="User", content="I am indeed."
), "ChatMessageChunk + ChatMessageChunk should be a ChatMessageChunk"
with pytest.raises(ValueError):
@ -94,10 +95,10 @@ def test_chat_message_chunks() -> None:
def test_function_message_chunks() -> None:
assert FunctionMessageChunk(name="hello", content="I am") + FunctionMessageChunk(
name="hello", content=" indeed."
) == FunctionMessageChunk(
name="hello", content="I am indeed."
assert FunctionMessageChunk(
name="hello", content="I am", id="ai5"
) + FunctionMessageChunk(name="hello", content=" indeed.") == FunctionMessageChunk(
id="ai5", name="hello", content="I am indeed."
), "FunctionMessageChunk + FunctionMessageChunk should be a FunctionMessageChunk"
with pytest.raises(ValueError):

@ -25,7 +25,7 @@ extended_tests:
poetry run pytest --disable-socket --allow-unix-socket --only-extended tests/unit_tests
test_watch:
poetry run ptw --snapshot-update --now . -- -x --disable-socket --allow-unix-socket tests/unit_tests
poetry run ptw --snapshot-update --now . -- -x --disable-socket --allow-unix-socket --disable-warnings tests/unit_tests
test_watch_extended:
poetry run ptw --snapshot-update --now . -- -x --disable-socket --allow-unix-socket --only-extended tests/unit_tests

@ -35,6 +35,7 @@ from langchain.prompts import ChatPromptTemplate
from langchain.tools import tool
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
from tests.unit_tests.llms.fake_chat_model import GenericFakeChatModel
from tests.unit_tests.stubs import AnyStr
class FakeListLLM(LLM):
@ -839,6 +840,7 @@ async def test_openai_agent_with_streaming() -> None:
log="\nInvoking: `find_pet` with `{'pet': 'cat'}`\n\n\n",
message_log=[
AIMessageChunk(
id=AnyStr(),
content="",
additional_kwargs={
"function_call": {
@ -852,6 +854,7 @@ async def test_openai_agent_with_streaming() -> None:
],
"messages": [
AIMessageChunk(
id=AnyStr(),
content="",
additional_kwargs={
"function_call": {
@ -874,6 +877,7 @@ async def test_openai_agent_with_streaming() -> None:
log="\nInvoking: `find_pet` with `{'pet': 'cat'}`\n\n\n",
message_log=[
AIMessageChunk(
id=AnyStr(),
content="",
additional_kwargs={
"function_call": {
@ -1014,6 +1018,7 @@ async def test_openai_agent_tools_agent() -> None:
log="\nInvoking: `find_pet` with `{'pet': 'cat'}`\n\n\n",
message_log=[
AIMessageChunk(
id=AnyStr(),
content="",
additional_kwargs={
"tool_calls": [
@ -1040,6 +1045,7 @@ async def test_openai_agent_tools_agent() -> None:
],
"messages": [
AIMessageChunk(
id=AnyStr(),
content="",
additional_kwargs={
"tool_calls": [
@ -1067,6 +1073,7 @@ async def test_openai_agent_tools_agent() -> None:
log="\nInvoking: `check_time` with `{}`\n\n\n",
message_log=[
AIMessageChunk(
id=AnyStr(),
content="",
additional_kwargs={
"tool_calls": [
@ -1093,6 +1100,7 @@ async def test_openai_agent_tools_agent() -> None:
],
"messages": [
AIMessageChunk(
id=AnyStr(),
content="",
additional_kwargs={
"tool_calls": [
@ -1124,6 +1132,7 @@ async def test_openai_agent_tools_agent() -> None:
log="\nInvoking: `find_pet` with `{'pet': 'cat'}`\n\n\n",
message_log=[
AIMessageChunk(
id=AnyStr(),
content="",
additional_kwargs={
"tool_calls": [
@ -1166,6 +1175,7 @@ async def test_openai_agent_tools_agent() -> None:
log="\nInvoking: `check_time` with `{}`\n\n\n",
message_log=[
AIMessageChunk(
id=AnyStr(),
content="",
additional_kwargs={
"tool_calls": [

@ -119,7 +119,9 @@ class GenericFakeChatModel(BaseChatModel):
content_chunks = cast(List[str], re.split(r"(\s)", content))
for token in content_chunks:
chunk = ChatGenerationChunk(message=AIMessageChunk(content=token))
chunk = ChatGenerationChunk(
message=AIMessageChunk(id=message.id, content=token)
)
if run_manager:
run_manager.on_llm_new_token(token, chunk=chunk)
yield chunk
@ -136,6 +138,7 @@ class GenericFakeChatModel(BaseChatModel):
for fvalue_chunk in fvalue_chunks:
chunk = ChatGenerationChunk(
message=AIMessageChunk(
id=message.id,
content="",
additional_kwargs={
"function_call": {fkey: fvalue_chunk}
@ -151,6 +154,7 @@ class GenericFakeChatModel(BaseChatModel):
else:
chunk = ChatGenerationChunk(
message=AIMessageChunk(
id=message.id,
content="",
additional_kwargs={"function_call": {fkey: fvalue}},
)
@ -164,7 +168,7 @@ class GenericFakeChatModel(BaseChatModel):
else:
chunk = ChatGenerationChunk(
message=AIMessageChunk(
content="", additional_kwargs={key: value}
id=message.id, content="", additional_kwargs={key: value}
)
)
if run_manager:

@ -8,6 +8,7 @@ from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
from langchain.callbacks.base import AsyncCallbackHandler
from tests.unit_tests.llms.fake_chat_model import GenericFakeChatModel
from tests.unit_tests.stubs import AnyStr
def test_generic_fake_chat_model_invoke() -> None:
@ -15,11 +16,11 @@ def test_generic_fake_chat_model_invoke() -> None:
infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")])
model = GenericFakeChatModel(messages=infinite_cycle)
response = model.invoke("meow")
assert response == AIMessage(content="hello")
assert response == AIMessage(content="hello", id=AnyStr())
response = model.invoke("kitty")
assert response == AIMessage(content="goodbye")
assert response == AIMessage(content="goodbye", id=AnyStr())
response = model.invoke("meow")
assert response == AIMessage(content="hello")
assert response == AIMessage(content="hello", id=AnyStr())
async def test_generic_fake_chat_model_ainvoke() -> None:
@ -27,11 +28,11 @@ async def test_generic_fake_chat_model_ainvoke() -> None:
infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")])
model = GenericFakeChatModel(messages=infinite_cycle)
response = await model.ainvoke("meow")
assert response == AIMessage(content="hello")
assert response == AIMessage(content="hello", id=AnyStr())
response = await model.ainvoke("kitty")
assert response == AIMessage(content="goodbye")
assert response == AIMessage(content="goodbye", id=AnyStr())
response = await model.ainvoke("meow")
assert response == AIMessage(content="hello")
assert response == AIMessage(content="hello", id=AnyStr())
async def test_generic_fake_chat_model_stream() -> None:
@ -44,16 +45,16 @@ async def test_generic_fake_chat_model_stream() -> None:
model = GenericFakeChatModel(messages=infinite_cycle)
chunks = [chunk async for chunk in model.astream("meow")]
assert chunks == [
AIMessageChunk(content="hello"),
AIMessageChunk(content=" "),
AIMessageChunk(content="goodbye"),
AIMessageChunk(content="hello", id=AnyStr()),
AIMessageChunk(content=" ", id=AnyStr()),
AIMessageChunk(content="goodbye", id=AnyStr()),
]
chunks = [chunk for chunk in model.stream("meow")]
assert chunks == [
AIMessageChunk(content="hello"),
AIMessageChunk(content=" "),
AIMessageChunk(content="goodbye"),
AIMessageChunk(content="hello", id=AnyStr()),
AIMessageChunk(content=" ", id=AnyStr()),
AIMessageChunk(content="goodbye", id=AnyStr()),
]
# Test streaming of additional kwargs.
@ -62,11 +63,12 @@ async def test_generic_fake_chat_model_stream() -> None:
model = GenericFakeChatModel(messages=cycle([message]))
chunks = [chunk async for chunk in model.astream("meow")]
assert chunks == [
AIMessageChunk(content="", additional_kwargs={"foo": 42}),
AIMessageChunk(content="", additional_kwargs={"bar": 24}),
AIMessageChunk(content="", additional_kwargs={"foo": 42}, id=AnyStr()),
AIMessageChunk(content="", additional_kwargs={"bar": 24}, id=AnyStr()),
]
message = AIMessage(
id="a1",
content="",
additional_kwargs={
"function_call": {
@ -81,18 +83,22 @@ async def test_generic_fake_chat_model_stream() -> None:
assert chunks == [
AIMessageChunk(
content="", additional_kwargs={"function_call": {"name": "move_file"}}
content="",
additional_kwargs={"function_call": {"name": "move_file"}},
id="a1",
),
AIMessageChunk(
id="a1",
content="",
additional_kwargs={
"function_call": {"arguments": '{\n "source_path": "foo"'}
},
),
AIMessageChunk(
content="", additional_kwargs={"function_call": {"arguments": ","}}
id="a1", content="", additional_kwargs={"function_call": {"arguments": ","}}
),
AIMessageChunk(
id="a1",
content="",
additional_kwargs={
"function_call": {"arguments": '\n "destination_path": "bar"\n}'}
@ -108,6 +114,7 @@ async def test_generic_fake_chat_model_stream() -> None:
accumulate_chunks += chunk
assert accumulate_chunks == AIMessageChunk(
id="a1",
content="",
additional_kwargs={
"function_call": {
@ -128,9 +135,9 @@ async def test_generic_fake_chat_model_astream_log() -> None:
]
final = log_patches[-1]
assert final.state["streamed_output"] == [
AIMessageChunk(content="hello"),
AIMessageChunk(content=" "),
AIMessageChunk(content="goodbye"),
AIMessageChunk(content="hello", id=AnyStr()),
AIMessageChunk(content=" ", id=AnyStr()),
AIMessageChunk(content="goodbye", id=AnyStr()),
]
@ -178,8 +185,8 @@ async def test_callback_handlers() -> None:
# New model
results = list(model.stream("meow", {"callbacks": [MyCustomAsyncHandler(tokens)]}))
assert results == [
AIMessageChunk(content="hello"),
AIMessageChunk(content=" "),
AIMessageChunk(content="goodbye"),
AIMessageChunk(content="hello", id=AnyStr()),
AIMessageChunk(content=" ", id=AnyStr()),
AIMessageChunk(content="goodbye", id=AnyStr()),
]
assert tokens == ["hello", " ", "goodbye"]

@ -0,0 +1,6 @@
from typing import Any
class AnyStr(str):
def __eq__(self, other: Any) -> bool:
return isinstance(other, str)
Loading…
Cancel
Save