Fix tool_calls message merge (#14613)

<!-- Thank you for contributing to LangChain!

Replace this entire comment with:
  - **Description:** a description of the change, 
  - **Issue:** the issue # it fixes (if applicable),
  - **Dependencies:** any dependencies required for this change,
- **Tag maintainer:** for a quicker response, tag the relevant
maintainer (see below),
- **Twitter handle:** we announce bigger features on Twitter. If your PR
gets announced, and you'd like a mention, we'll gladly shout you out!

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` to check this
locally.

See contribution guidelines for more information on how to write/run
tests, lint, etc:

https://github.com/langchain-ai/langchain/blob/master/.github/CONTRIBUTING.md

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in `docs/extras`
directory.

If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17.
 -->
This commit is contained in:
Nuno Campos 2023-12-13 12:37:40 -08:00 committed by GitHub
parent 405d111da6
commit a16f4a318f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 240 additions and 2 deletions

View File

@ -98,8 +98,12 @@ class BaseMessageChunk(BaseMessage):
merged[k] = v
elif merged[k] is None and v:
merged[k] = v
elif v is None:
continue
elif merged[k] == v:
continue
elif type(merged[k]) != type(v):
raise ValueError(
raise TypeError(
f'additional_kwargs["{k}"] already exists in this message,'
" but with a different type."
)
@ -107,8 +111,17 @@ class BaseMessageChunk(BaseMessage):
merged[k] += v
elif isinstance(merged[k], dict):
merged[k] = self._merge_kwargs_dict(merged[k], v)
elif isinstance(merged[k], list):
merged[k] = merged[k].copy()
for i, e in enumerate(v):
if isinstance(e, dict) and isinstance(e.get("index"), int):
i = e["index"]
if i < len(merged[k]):
merged[k][i] = self._merge_kwargs_dict(merged[k][i], e)
else:
merged[k] = merged[k] + [e]
else:
raise ValueError(
raise TypeError(
f"Additional kwargs key {k} already exists in this message."
)
return merged

View File

@ -1,4 +1,5 @@
import unittest
from typing import List
import pytest
@ -203,3 +204,227 @@ def test_message_chunk_to_message() -> None:
assert message_chunk_to_message(
FunctionMessageChunk(name="hello", content="I am")
) == FunctionMessage(name="hello", content="I am")
def test_tool_calls_merge() -> None:
    """Check that streamed tool-call deltas add up to complete tool calls.

    The chunk sequence contains two tool calls delivered piece by piece.
    Each call starts with a chunk carrying its id, function name and type
    with an empty arguments string, followed by chunks that carry only a
    fragment of the arguments string (id, name and type are ``None``).
    A plain empty-content chunk sits at both ends of the sequence.

    Adding all chunks together must produce one ``tool_calls`` entry per
    ``index``, with the argument fragments concatenated in arrival order.
    """

    def wrap(tool_call: dict) -> dict:
        # Keyword arguments for a chunk holding exactly one tool-call delta.
        return dict(content="", additional_kwargs={"tool_calls": [tool_call]})

    def opening(position: int, call_id: str) -> dict:
        # First delta of a call: identifies it, arguments still empty.
        return wrap(
            {
                "index": position,
                "id": call_id,
                "function": {"arguments": "", "name": "person"},
                "type": "function",
            }
        )

    def fragment(position: int, text: str) -> dict:
        # Follow-up delta: only a slice of the arguments string is set.
        return wrap(
            {
                "index": position,
                "id": None,
                "function": {"arguments": text, "name": None},
                "type": None,
            }
        )

    first_call_pieces = ('{"na', 'me": ', '"jane"', ', "a', 'ge": ', "2}")
    second_call_pieces = ('{"na', 'me": ', '"bob",', ' "ag', 'e": 3', "}")

    chunks: List[dict] = [
        dict(content=""),
        opening(0, "call_CwGAsESnXehQEjiAIWzinlva"),
        *(fragment(0, piece) for piece in first_call_pieces),
        opening(1, "call_zXSIylHvc5x3JUAPcHZR5GZI"),
        *(fragment(1, piece) for piece in second_call_pieces),
        dict(content=""),
    ]

    # Fold the stream left to right, exactly as a consumer of a stream would.
    messages = [AIMessageChunk(**kwargs) for kwargs in chunks]
    merged = messages[0]
    for message in messages[1:]:
        merged = merged + message

    expected_tool_calls = [
        {
            "index": 0,
            "id": "call_CwGAsESnXehQEjiAIWzinlva",
            "function": {
                "arguments": '{"name": "jane", "age": 2}',
                "name": "person",
            },
            "type": "function",
        },
        {
            "index": 1,
            "id": "call_zXSIylHvc5x3JUAPcHZR5GZI",
            "function": {
                "arguments": '{"name": "bob", "age": 3}',
                "name": "person",
            },
            "type": "function",
        },
    ]
    assert merged == AIMessageChunk(
        content="",
        additional_kwargs={"tool_calls": expected_tool_calls},
    )