core[minor], ...: add tool calls message (#18947)

core[minor], langchain[patch], openai[minor], anthropic[minor], fireworks[minor], groq[minor], mistralai[minor]

```python
class ToolCall(TypedDict):
    name: str
    args: Dict[str, Any]
    id: Optional[str]

class InvalidToolCall(TypedDict):
    name: Optional[str]
    args: Optional[str]
    id: Optional[str]
    error: Optional[str]

class ToolCallChunk(TypedDict):
    name: Optional[str]
    args: Optional[str]
    id: Optional[str]
    index: Optional[int]


class AIMessage(BaseMessage):
    ...
    tool_calls: List[ToolCall] = []
    invalid_tool_calls: List[InvalidToolCall] = []
    ...


class AIMessageChunk(AIMessage, BaseMessageChunk):
    ...
    tool_call_chunks: List[ToolCallChunk] = []
    ...
```
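For illustration, a minimal sketch of the new surface (the tool call id `call_abc123` is made up for the example):

```python
from langchain_core.messages import AIMessage

# Tool calls are plain TypedDicts attached to the message, so they can be
# passed as dicts and read with key access.
msg = AIMessage(
    content="",
    tool_calls=[{"name": "multiply", "args": {"x": 2, "y": 3}, "id": "call_abc123"}],
)
assert msg.tool_calls[0]["name"] == "multiply"
assert msg.invalid_tool_calls == []
```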
Important considerations:
- Parsing logic occurs within different providers;
- ~Changing output type is a breaking change for anyone doing explicit type checking;~
- ~LangSmith rendering will need to be updated: https://github.com/langchain-ai/langchainplus/pull/3561~
- ~LangServe will need to be updated~
- Adding chunks:
  - ~AIMessage + ToolCallsMessage = ToolCallsMessage if either has non-null .tool_calls.~
  - Tool call chunks are appended, merging when they have equal values of `index` (see the first sketch after this list).
  - additional_kwargs accumulate the normal way.
- During streaming:
  - ~Messages can change types (e.g., from AIMessageChunk to AIToolCallsMessageChunk)~
  - Output parsers parse additional_kwargs (during .invoke they read off tool calls; see the second sketch after this list).
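A minimal sketch of chunk merging, adapted from the `ToolCallChunk` docstring in this diff (the id `call_1` is made up): string fields of chunks with equal `index` are concatenated, and `.tool_calls` is re-derived from the accumulated partial JSON:

```python
from langchain_core.messages import AIMessageChunk
from langchain_core.messages.tool import ToolCallChunk

left = AIMessageChunk(
    content="",
    tool_call_chunks=[ToolCallChunk(name="foo", args='{"a":', id="call_1", index=0)],
)
right = AIMessageChunk(
    content="",
    tool_call_chunks=[ToolCallChunk(name=None, args="1}", id=None, index=0)],
)
merged = left + right
# The two chunks share index=0, so they merge into one; `args` accumulates
# as a JSON substring and is re-parsed into a complete ToolCall.
assert merged.tool_call_chunks == [
    ToolCallChunk(name="foo", args='{"a":1}', id="call_1", index=0)
]
assert merged.tool_calls == [{"name": "foo", "args": {"a": 1}, "id": "call_1"}]
```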
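And a sketch of the output-parser path, assuming the defaults shown in this diff (`return_id` off, `name` re-keyed to `type` for backwards compatibility; `call_1` is again made up):

```python
from langchain_core.messages import AIMessage
from langchain_core.output_parsers.openai_tools import JsonOutputToolsParser
from langchain_core.outputs import ChatGeneration

msg = AIMessage(
    content="",
    tool_calls=[{"name": "multiply", "args": {"x": 2, "y": 3}, "id": "call_1"}],
)
# On .invoke, the parser reads tool calls directly off the message rather
# than re-parsing additional_kwargs.
assert JsonOutputToolsParser().parse_result([ChatGeneration(message=msg)]) == [
    {"type": "multiply", "args": {"x": 2, "y": 3}}
]
```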

Packages outside of `partners/`:
- https://github.com/langchain-ai/langchain-cohere/pull/7
- https://github.com/langchain-ai/langchain-google/pull/123/files

---------

Co-authored-by: Chester Curme <chester.curme@gmail.com>

@@ -0,0 +1,423 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "c48812ed-35bd-4fbe-9a2c-6c7335e5645e",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/chestercurme/repos/langchain/libs/core/langchain_core/_api/beta_decorator.py:87: LangChainBetaWarning: The function `bind_tools` is in beta. It is actively being worked on, so the API may change.\n",
" warn_beta(\n"
]
}
],
"source": [
"from langchain_anthropic import ChatAnthropic\n",
"from langchain_core.runnables import ConfigurableField\n",
"from langchain_core.tools import tool\n",
"from langchain_openai import ChatOpenAI\n",
"\n",
"\n",
"@tool\n",
"def multiply(x: float, y: float) -> float:\n",
" \"\"\"Multiply 'x' times 'y'.\"\"\"\n",
" return x * y\n",
"\n",
"\n",
"@tool\n",
"def exponentiate(x: float, y: float) -> float:\n",
" \"\"\"Raise 'x' to the 'y'.\"\"\"\n",
" return x**y\n",
"\n",
"\n",
"@tool\n",
"def add(x: float, y: float) -> float:\n",
" \"\"\"Add 'x' and 'y'.\"\"\"\n",
" return x + y\n",
"\n",
"\n",
"tools = [multiply, exponentiate, add]\n",
"\n",
"gpt35 = ChatOpenAI(model=\"gpt-3.5-turbo-0125\", temperature=0).bind_tools(tools)\n",
"claude3 = ChatAnthropic(model=\"claude-3-sonnet-20240229\").bind_tools(tools)\n",
"llm_with_tools = gpt35.configurable_alternatives(\n",
" ConfigurableField(id=\"llm\"), default_key=\"gpt35\", claude3=claude3\n",
")"
]
},
{
"cell_type": "markdown",
"id": "4719ebdb-ad50-468e-9b30-fb5fb086e140",
"metadata": {},
"source": [
"# AgentExecutor"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "b98feaa5-8c2d-4125-9519-67114a1fef31",
"metadata": {},
"outputs": [],
"source": [
"from typing import List, Tuple, Union\n",
"\n",
"from langchain.agents import AgentExecutor\n",
"from langchain.agents.output_parsers.openai_tools import OpenAIToolAgentAction\n",
"from langchain_core.agents import AgentFinish, _convert_agent_action_to_messages\n",
"from langchain_core.messages import (\n",
" AIMessage,\n",
" BaseMessage,\n",
" ToolMessage,\n",
")\n",
"from langchain_core.prompts import ChatPromptTemplate\n",
"from langchain_core.runnables import RunnablePassthrough\n",
"\n",
"\n",
"def actions_observations_to_messages(\n",
" steps: List[Tuple[OpenAIToolAgentAction, str]],\n",
") -> List[BaseMessage]:\n",
" messages = []\n",
" for action, observation in steps:\n",
" messages.extend([m for m in action.message_log if m not in messages])\n",
" messages.append(ToolMessage(observation, tool_call_id=action.tool_call_id))\n",
" return messages\n",
"\n",
"\n",
"def messages_to_action(\n",
" msg: AIMessage,\n",
") -> Union[List[OpenAIToolAgentAction], AgentFinish]:\n",
" if isinstance(msg, AIMessage) and msg.tool_calls is not None:\n",
" actions = []\n",
" for tool_call in msg.tool_calls:\n",
" actions.append(\n",
" OpenAIToolAgentAction(\n",
" tool=tool_call.name,\n",
" tool_input=tool_call.args,\n",
" tool_call_id=tool_call.id,\n",
" message_log=[msg],\n",
" log=\"\",\n",
" )\n",
" )\n",
" return actions\n",
" else:\n",
" return AgentFinish(return_values={\"output\": msg.content}, log=\"\")\n",
"\n",
"\n",
"prompt = ChatPromptTemplate.from_messages(\n",
" [\n",
" (\"system\", \"You're a helpful assistant with access to tools\"),\n",
" (\"human\", \"{input}\"),\n",
" (\"placeholder\", \"{agent_scratchpad}\"),\n",
" ]\n",
")\n",
"\n",
"agent = (\n",
" RunnablePassthrough.assign(\n",
" agent_scratchpad=lambda x: actions_observations_to_messages(\n",
" x[\"intermediate_steps\"]\n",
" ),\n",
" )\n",
" | prompt\n",
" | llm_with_tools\n",
" | messages_to_action\n",
")\n",
"\n",
"agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "b4c0fc7a-80bb-4bb8-a87b-7388291ae8b6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m\u001b[0m\u001b[33;1m\u001b[1;3m300.03770462067547\u001b[0m\u001b[32;1m\u001b[1;3m\u001b[0m\u001b[38;5;200m\u001b[1;3m-900.8841\u001b[0m\u001b[32;1m\u001b[1;3m\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"{'input': \"what's 3 plus 5 raised to the 2.743. also what's 17.24 - 918.1241\",\n",
" 'output': 'The result of \\\\(3 + 5^{2.743}\\\\) is approximately 300.04, and the result of \\\\(17.24 - 918.1241\\\\) is approximately -900.88.'}"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"agent_executor.invoke(\n",
" {\"input\": \"what's 3 plus 5 raised to the 2.743. also what's 17.24 - 918.1241\"}\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "41a3a3c8-185d-4861-b6f0-7592668feb62",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/chestercurme/repos/langchain/libs/partners/anthropic/langchain_anthropic/chat_models.py:336: UserWarning: stream: Tool use is not yet supported in streaming mode.\n",
" warnings.warn(\"stream: Tool use is not yet supported in streaming mode.\")\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[32;1m\u001b[1;3m\u001b[0m\u001b[33;1m\u001b[1;3m82.65606421491815\u001b[0m"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/chestercurme/repos/langchain/libs/partners/anthropic/langchain_anthropic/chat_models.py:336: UserWarning: stream: Tool use is not yet supported in streaming mode.\n",
" warnings.warn(\"stream: Tool use is not yet supported in streaming mode.\")\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[32;1m\u001b[1;3m\u001b[0m\u001b[38;5;200m\u001b[1;3m85.65606421491815\u001b[0m"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/chestercurme/repos/langchain/libs/partners/anthropic/langchain_anthropic/chat_models.py:336: UserWarning: stream: Tool use is not yet supported in streaming mode.\n",
" warnings.warn(\"stream: Tool use is not yet supported in streaming mode.\")\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[32;1m\u001b[1;3m\u001b[0m\u001b[38;5;200m\u001b[1;3m-900.8841\u001b[0m"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/chestercurme/repos/langchain/libs/partners/anthropic/langchain_anthropic/chat_models.py:336: UserWarning: stream: Tool use is not yet supported in streaming mode.\n",
" warnings.warn(\"stream: Tool use is not yet supported in streaming mode.\")\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[32;1m\u001b[1;3m\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"{'input': \"what's 3 plus 5 raised to the 2.743. also what's 17.24 - 918.1241\",\n",
" 'output': 'Therefore, 17.24 - 918.1241 = -900.8841'}"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"agent_executor = AgentExecutor(\n",
" agent=agent.with_config(configurable={\"llm\": \"claude3\"}), tools=tools, verbose=True\n",
")\n",
"agent_executor.invoke(\n",
" {\"input\": \"what's 3 plus 5 raised to the 2.743. also what's 17.24 - 918.1241\"},\n",
")"
]
},
{
"cell_type": "markdown",
"id": "9c186263-1b98-4cb2-b6d1-71f65eb0d811",
"metadata": {},
"source": [
"# LangGraph"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "28fc2c60-7dbc-428a-8983-1a6a15ea30d2",
"metadata": {},
"outputs": [],
"source": [
"import operator\n",
"from typing import Annotated, Sequence, TypedDict\n",
"\n",
"from langchain_core.messages import AIMessage, BaseMessage, HumanMessage\n",
"from langchain_core.runnables import RunnableLambda\n",
"from langgraph.graph import END, StateGraph\n",
"\n",
"\n",
"class AgentState(TypedDict):\n",
" messages: Annotated[Sequence[BaseMessage], operator.add]\n",
"\n",
"\n",
"def should_continue(state):\n",
" return \"continue\" if state[\"messages\"][-1].tool_calls is not None else \"end\"\n",
"\n",
"\n",
"def call_model(state, config):\n",
" return {\"messages\": [llm_with_tools.invoke(state[\"messages\"], config=config)]}\n",
"\n",
"\n",
"def _invoke_tool(tool_call):\n",
" tool = {tool.name: tool for tool in tools}[tool_call.name]\n",
" return ToolMessage(tool.invoke(tool_call.args), tool_call_id=tool_call.id)\n",
"\n",
"\n",
"tool_executor = RunnableLambda(_invoke_tool)\n",
"\n",
"\n",
"def call_tools(state):\n",
" last_message = state[\"messages\"][-1]\n",
" return {\"messages\": tool_executor.batch(last_message.tool_calls)}\n",
"\n",
"\n",
"workflow = StateGraph(AgentState)\n",
"workflow.add_node(\"agent\", call_model)\n",
"workflow.add_node(\"action\", call_tools)\n",
"workflow.set_entry_point(\"agent\")\n",
"workflow.add_conditional_edges(\n",
" \"agent\",\n",
" should_continue,\n",
" {\n",
" \"continue\": \"action\",\n",
" \"end\": END,\n",
" },\n",
")\n",
"workflow.add_edge(\"action\", \"agent\")\n",
"graph = workflow.compile()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "24463798-74e6-4c39-8092-7a1524d83225",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'messages': [HumanMessage(content=\"what's 3 plus 5 raised to the 2.743. also what's 17.24 - 918.1241\"),\n",
" AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_kbBUUeqK75fZZqDTvu8aif7Z', 'function': {'arguments': '{\"x\": 8, \"y\": 2.743}', 'name': 'exponentiate'}, 'type': 'function'}, {'id': 'call_pBD8daSyXidXnrIyG4vG5C9O', 'function': {'arguments': '{\"x\": 17.24, \"y\": -918.1241}', 'name': 'add'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 58, 'prompt_tokens': 168, 'total_tokens': 226}, 'model_name': 'gpt-3.5-turbo-0125', 'system_fingerprint': 'fp_b28b39ffa8', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-8e1d9687-611c-4c8e-9fcd-ef6e48bd22a6-0', tool_calls=[ToolCall(name='exponentiate', args={'x': 8, 'y': 2.743}, id='call_kbBUUeqK75fZZqDTvu8aif7Z'), ToolCall(name='add', args={'x': 17.24, 'y': -918.1241}, id='call_pBD8daSyXidXnrIyG4vG5C9O')]),\n",
" ToolMessage(content='300.03770462067547', tool_call_id='call_kbBUUeqK75fZZqDTvu8aif7Z'),\n",
" ToolMessage(content='-900.8841', tool_call_id='call_pBD8daSyXidXnrIyG4vG5C9O'),\n",
" AIMessage(content='The result of \\\\(3 + 5^{2.743}\\\\) is approximately 300.04, and the result of \\\\(17.24 - 918.1241\\\\) is approximately -900.88.', response_metadata={'token_usage': {'completion_tokens': 44, 'prompt_tokens': 251, 'total_tokens': 295}, 'model_name': 'gpt-3.5-turbo-0125', 'system_fingerprint': 'fp_b28b39ffa8', 'finish_reason': 'stop', 'logprobs': None}, id='run-47fe5cbc-3f25-44c3-85b2-6540c3054a77-0')]}"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"graph.invoke(\n",
" {\n",
" \"messages\": [\n",
" HumanMessage(\n",
" \"what's 3 plus 5 raised to the 2.743. also what's 17.24 - 918.1241\"\n",
" )\n",
" ]\n",
" }\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "073c074e-d722-42e0-85ec-c62c079207e4",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'messages': [HumanMessage(content=\"what's 3 plus 5 raised to the 2.743. also what's 17.24 - 918.1241\"),\n",
" AIMessage(content=[{'text': \"Okay, let's break this down into steps:\", 'type': 'text'}, {'id': 'toolu_01DJkSDpB8ztmJx2DLNbc3eW', 'input': {'x': 5, 'y': 2.743}, 'name': 'exponentiate', 'type': 'tool_use'}], response_metadata={'id': 'msg_01KuVNohyJr24cPhJkY3XVtt', 'model': 'claude-3-sonnet-20240229', 'stop_reason': 'tool_use', 'stop_sequence': None, 'usage': {'input_tokens': 450, 'output_tokens': 84}}, id='run-336cdfb6-0fe4-4d7a-9946-9f01c2eb41ae-0', tool_calls=[ToolCall(name='exponentiate', args={'x': 5, 'y': 2.743}, id='toolu_01DJkSDpB8ztmJx2DLNbc3eW', index=1)]),\n",
" ToolMessage(content='82.65606421491815', tool_call_id='toolu_01DJkSDpB8ztmJx2DLNbc3eW'),\n",
" AIMessage(content=[{'text': 'To get 5 raised to the 2.743 power.', 'type': 'text'}, {'id': 'toolu_01MKQqnDw5CtyuKjQP8YG1FX', 'input': {'x': 3, 'y': 82.65606421491815}, 'name': 'add', 'type': 'tool_use'}], response_metadata={'id': 'msg_01UBsKkvA4StUR4NEvoFFFep', 'model': 'claude-3-sonnet-20240229', 'stop_reason': 'tool_use', 'stop_sequence': None, 'usage': {'input_tokens': 552, 'output_tokens': 91}}, id='run-9d25b7bd-58aa-47dd-933f-15459b24b2c2-0', tool_calls=[ToolCall(name='add', args={'x': 3, 'y': 82.65606421491815}, id='toolu_01MKQqnDw5CtyuKjQP8YG1FX', index=1)]),\n",
" ToolMessage(content='85.65606421491815', tool_call_id='toolu_01MKQqnDw5CtyuKjQP8YG1FX'),\n",
" AIMessage(content=[{'text': 'So 3 plus 5 raised to the 2.743 power is 85.656.\\n\\nFor the second part:', 'type': 'text'}, {'id': 'toolu_019Wb2zPouCR3dw2bSKvCRUL', 'input': {'x': 17.24, 'y': -918.1241}, 'name': 'add', 'type': 'tool_use'}], response_metadata={'id': 'msg_01Y2H2L8FWcDtVkCtuosie2P', 'model': 'claude-3-sonnet-20240229', 'stop_reason': 'tool_use', 'stop_sequence': None, 'usage': {'input_tokens': 661, 'output_tokens': 105}}, id='run-e553c1e3-24ba-4e1b-93ba-6f1985932db4-0', tool_calls=[ToolCall(name='add', args={'x': 17.24, 'y': -918.1241}, id='toolu_019Wb2zPouCR3dw2bSKvCRUL', index=1)]),\n",
" ToolMessage(content='-900.8841', tool_call_id='toolu_019Wb2zPouCR3dw2bSKvCRUL'),\n",
" AIMessage(content='Therefore, 17.24 - 918.1241 = -900.8841', response_metadata={'id': 'msg_01Q14dqvaCD2eA4zwrUvxTcF', 'model': 'claude-3-sonnet-20240229', 'stop_reason': 'end_turn', 'stop_sequence': None, 'usage': {'input_tokens': 782, 'output_tokens': 24}}, id='run-f6b6e525-2df6-4617-9bb3-b39d5cc963a9-0')]}"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"graph.invoke(\n",
" {\n",
" \"messages\": [\n",
" HumanMessage(\n",
" \"what's 3 plus 5 raised to the 2.743. also what's 17.24 - 918.1241\"\n",
" )\n",
" ]\n",
" },\n",
" config={\"configurable\": {\"llm\": \"claude3\"}},\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

@@ -15,7 +15,13 @@
""" # noqa: E501
from langchain_core.messages.ai import AIMessage, AIMessageChunk
from langchain_core.messages.ai import (
AIMessage,
AIMessageChunk,
InvalidToolCall,
ToolCall,
ToolCallChunk,
)
from langchain_core.messages.base import (
BaseMessage,
BaseMessageChunk,
@@ -50,9 +56,12 @@ __all__ = [
"FunctionMessageChunk",
"HumanMessage",
"HumanMessageChunk",
"InvalidToolCall",
"MessageLikeRepresentation",
"SystemMessage",
"SystemMessageChunk",
"ToolCall",
"ToolCallChunk",
"ToolMessage",
"ToolMessageChunk",
"_message_from_dict",

@@ -1,3 +1,4 @@
import warnings
from typing import Any, List, Literal
from langchain_core.messages.base import (
@@ -5,7 +6,18 @@ from langchain_core.messages.base import (
BaseMessageChunk,
merge_content,
)
from langchain_core.utils._merge import merge_dicts
from langchain_core.messages.tool import (
InvalidToolCall,
ToolCall,
ToolCallChunk,
default_tool_chunk_parser,
default_tool_parser,
)
from langchain_core.pydantic_v1 import root_validator
from langchain_core.utils._merge import merge_dicts, merge_lists
from langchain_core.utils.json import (
parse_partial_json,
)
class AIMessage(BaseMessage):
@@ -16,6 +28,11 @@ class AIMessage(BaseMessage):
conversation.
"""
tool_calls: List[ToolCall] = []
"""If provided, tool calls associated with the message."""
invalid_tool_calls: List[InvalidToolCall] = []
"""If provided, tool calls with parsing errors associated with the message."""
type: Literal["ai"] = "ai"
@classmethod
@@ -23,6 +40,34 @@ class AIMessage(BaseMessage):
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "messages"]
@root_validator
def _backwards_compat_tool_calls(cls, values: dict) -> dict:
raw_tool_calls = values.get("additional_kwargs", {}).get("tool_calls")
tool_calls = (
values.get("tool_calls")
or values.get("invalid_tool_calls")
or values.get("tool_call_chunks")
)
if raw_tool_calls and not tool_calls:
warnings.warn(
"New langchain packages are available that more efficiently handle "
"tool calling. Please upgrade your packages to versions that set "
"message tool calls. e.g., `pip install --upgrade langchain-anthropic"
"`, pip install--upgrade langchain-openai`, etc."
)
try:
if issubclass(cls, AIMessageChunk): # type: ignore
values["tool_call_chunks"] = default_tool_chunk_parser(
raw_tool_calls
)
else:
tool_calls, invalid_tool_calls = default_tool_parser(raw_tool_calls)
values["tool_calls"] = tool_calls
values["invalid_tool_calls"] = invalid_tool_calls
except Exception:
pass
return values
AIMessage.update_forward_refs()
@@ -35,11 +80,48 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
# non-chunk variant.
type: Literal["AIMessageChunk"] = "AIMessageChunk" # type: ignore[assignment] # noqa: E501
tool_call_chunks: List[ToolCallChunk] = []
"""If provided, tool call chunks associated with the message."""
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "messages"]
@root_validator()
def init_tool_calls(cls, values: dict) -> dict:
if not values["tool_call_chunks"]:
values["tool_calls"] = []
values["invalid_tool_calls"] = []
return values
tool_calls = []
invalid_tool_calls = []
for chunk in values["tool_call_chunks"]:
try:
args_ = parse_partial_json(chunk["args"])
if isinstance(args_, dict):
tool_calls.append(
ToolCall(
name=chunk["name"] or "",
args=args_,
id=chunk["id"],
)
)
else:
raise ValueError("Malformed args.")
except Exception:
invalid_tool_calls.append(
InvalidToolCall(
name=chunk["name"],
args=chunk["args"],
id=chunk["id"],
error="Malformed args.",
)
)
values["tool_calls"] = tool_calls
values["invalid_tool_calls"] = invalid_tool_calls
return values
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
if isinstance(other, AIMessageChunk):
if self.example != other.example:
@@ -47,15 +129,41 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
"Cannot concatenate AIMessageChunks with different example values."
)
content = merge_content(self.content, other.content)
additional_kwargs = merge_dicts(
self.additional_kwargs, other.additional_kwargs
)
response_metadata = merge_dicts(
self.response_metadata, other.response_metadata
)
# Merge tool call chunks
if self.tool_call_chunks or other.tool_call_chunks:
raw_tool_calls = merge_lists(
self.tool_call_chunks,
other.tool_call_chunks,
)
if raw_tool_calls:
tool_call_chunks = [
ToolCallChunk(
name=rtc.get("name"),
args=rtc.get("args"),
index=rtc.get("index"),
id=rtc.get("id"),
)
for rtc in raw_tool_calls
]
else:
tool_call_chunks = []
else:
tool_call_chunks = []
return self.__class__(
example=self.example,
content=merge_content(self.content, other.content),
additional_kwargs=merge_dicts(
self.additional_kwargs, other.additional_kwargs
),
response_metadata=merge_dicts(
self.response_metadata, other.response_metadata
),
content=content,
additional_kwargs=additional_kwargs,
tool_call_chunks=tool_call_chunks,
response_metadata=response_metadata,
id=self.id,
)

@@ -1,4 +1,7 @@
from typing import Any, List, Literal
import json
from typing import Any, Dict, List, Literal, Optional, Tuple
from typing_extensions import TypedDict
from langchain_core.messages.base import (
BaseMessage,
@@ -61,3 +64,112 @@ class ToolMessageChunk(ToolMessage, BaseMessageChunk):
)
return super().__add__(other)
class ToolCall(TypedDict):
"""A call to a tool.
Attributes:
name: (str) the name of the tool to be called
args: (dict) the arguments to the tool call
id: (str) if provided, an identifier associated with the tool call
"""
name: str
args: Dict[str, Any]
id: Optional[str]
class ToolCallChunk(TypedDict):
"""A chunk of a tool call (e.g., as part of a stream).
When merging ToolCallChunks (e.g., via AIMessageChunk.__add__),
all string attributes are concatenated. Chunks are only merged if their
values of `index` are equal and not None.
Example:
.. code-block:: python
left_chunks = [ToolCallChunk(name="foo", args='{"a":', index=0)]
right_chunks = [ToolCallChunk(name=None, args='1}', index=0)]
(
AIMessageChunk(content="", tool_call_chunks=left_chunks)
+ AIMessageChunk(content="", tool_call_chunks=right_chunks)
).tool_call_chunks == [ToolCallChunk(name='foo', args='{"a":1}', index=0)]
Attributes:
name: (str) if provided, a substring of the name of the tool to be called
args: (str) if provided, a JSON substring of the arguments to the tool call
id: (str) if provided, a substring of an identifier for the tool call
index: (int) if provided, the index of the tool call in a sequence
"""
name: Optional[str]
args: Optional[str]
id: Optional[str]
index: Optional[int]
class InvalidToolCall(TypedDict):
"""Allowance for errors made by LLM.
Here we add an `error` key to surface errors made during generation
(e.g., invalid JSON arguments.)
"""
name: Optional[str]
args: Optional[str]
id: Optional[str]
error: Optional[str]
def default_tool_parser(
raw_tool_calls: List[dict],
) -> Tuple[List[ToolCall], List[InvalidToolCall]]:
"""Best-effort parsing of tools."""
tool_calls = []
invalid_tool_calls = []
for tool_call in raw_tool_calls:
if "function" not in tool_call:
continue
else:
function_name = tool_call["function"]["name"]
try:
function_args = json.loads(tool_call["function"]["arguments"])
parsed = ToolCall(
name=function_name or "",
args=function_args or {},
id=tool_call.get("id"),
)
tool_calls.append(parsed)
except json.JSONDecodeError:
invalid_tool_calls.append(
InvalidToolCall(
name=function_name,
args=tool_call["function"]["arguments"],
id=tool_call.get("id"),
error="Malformed args.",
)
)
return tool_calls, invalid_tool_calls
def default_tool_chunk_parser(raw_tool_calls: List[dict]) -> List[ToolCallChunk]:
"""Best-effort parsing of tool chunks."""
tool_call_chunks = []
for tool_call in raw_tool_calls:
if "function" not in tool_call:
function_args = None
function_name = None
else:
function_args = tool_call["function"]["arguments"]
function_name = tool_call["function"]["name"]
parsed = ToolCallChunk(
name=function_name,
args=function_args,
id=tool_call.get("id"),
index=tool_call.get("index"),
)
tool_call_chunks.append(parsed)
return tool_call_chunks

@ -1,6 +1,9 @@
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
from langchain_core.messages.ai import AIMessage, AIMessageChunk
from langchain_core.messages.ai import (
AIMessage,
AIMessageChunk,
)
from langchain_core.messages.base import (
BaseMessage,
BaseMessageChunk,
@@ -119,8 +122,11 @@ def message_chunk_to_message(chunk: BaseMessageChunk) -> BaseMessage:
if not isinstance(chunk, BaseMessageChunk):
return chunk
# chunk classes always have the equivalent non-chunk class as their first parent
ignore_keys = ["type"]
if isinstance(chunk, AIMessageChunk):
ignore_keys.append("tool_call_chunks")
return chunk.__class__.__mro__[1](
**{k: v for k, v in chunk.__dict__.items() if k != "type"}
**{k: v for k, v in chunk.__dict__.items() if k not in ignore_keys}
)

@@ -1,9 +1,8 @@
from __future__ import annotations
import json
import re
from json import JSONDecodeError
from typing import Any, Callable, List, Optional, Type, TypeVar, Union
from typing import Any, List, Optional, Type, TypeVar, Union
import jsonpatch # type: ignore[import]
import pydantic # pydantic: ignore
@@ -12,6 +11,11 @@ from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers.format_instructions import JSON_FORMAT_INSTRUCTIONS
from langchain_core.output_parsers.transform import BaseCumulativeTransformOutputParser
from langchain_core.outputs import Generation
from langchain_core.utils.json import (
parse_and_check_json_markdown,
parse_json_markdown,
parse_partial_json,
)
from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION
if PYDANTIC_MAJOR_VERSION < 2:
@@ -26,182 +30,6 @@ else:
TBaseModel = TypeVar("TBaseModel", bound=PydanticBaseModel)
def _replace_new_line(match: re.Match[str]) -> str:
value = match.group(2)
value = re.sub(r"\n", r"\\n", value)
value = re.sub(r"\r", r"\\r", value)
value = re.sub(r"\t", r"\\t", value)
value = re.sub(r'(?<!\\)"', r"\"", value)
return match.group(1) + value + match.group(3)
def _custom_parser(multiline_string: str) -> str:
"""
The LLM response for `action_input` may be a multiline
string containing unescaped newlines, tabs or quotes. This function
replaces those characters with their escaped counterparts.
(newlines in JSON must be double-escaped: `\\n`)
"""
if isinstance(multiline_string, (bytes, bytearray)):
multiline_string = multiline_string.decode()
multiline_string = re.sub(
r'("action_input"\:\s*")(.*?)(")',
_replace_new_line,
multiline_string,
flags=re.DOTALL,
)
return multiline_string
# Adapted from https://github.com/KillianLucas/open-interpreter/blob/5b6080fae1f8c68938a1e4fa8667e3744084ee21/interpreter/utils/parse_partial_json.py
# MIT License
def parse_partial_json(s: str, *, strict: bool = False) -> Any:
"""Parse a JSON string that may be missing closing braces.
Args:
s: The JSON string to parse.
strict: Whether to use strict parsing. Defaults to False.
Returns:
The parsed JSON object as a Python dictionary.
"""
# Attempt to parse the string as-is.
try:
return json.loads(s, strict=strict)
except json.JSONDecodeError:
pass
# Initialize variables.
new_s = ""
stack = []
is_inside_string = False
escaped = False
# Process each character in the string one at a time.
for char in s:
if is_inside_string:
if char == '"' and not escaped:
is_inside_string = False
elif char == "\n" and not escaped:
char = "\\n" # Replace the newline character with the escape sequence.
elif char == "\\":
escaped = not escaped
else:
escaped = False
else:
if char == '"':
is_inside_string = True
escaped = False
elif char == "{":
stack.append("}")
elif char == "[":
stack.append("]")
elif char == "}" or char == "]":
if stack and stack[-1] == char:
stack.pop()
else:
# Mismatched closing character; the input is malformed.
return None
# Append the processed character to the new string.
new_s += char
# If we're still inside a string at the end of processing,
# we need to close the string.
if is_inside_string:
new_s += '"'
# Try to parse mods of string until we succeed or run out of characters.
while new_s:
final_s = new_s
# Close any remaining open structures in the reverse
# order that they were opened.
for closing_char in reversed(stack):
final_s += closing_char
# Attempt to parse the modified string as JSON.
try:
return json.loads(final_s, strict=strict)
except json.JSONDecodeError:
# If we still can't parse the string as JSON,
# try removing the last character
new_s = new_s[:-1]
# If we got here, we ran out of characters to remove
# and still couldn't parse the string as JSON, so return the parse error
# for the original string.
return json.loads(s, strict=strict)
def parse_json_markdown(
json_string: str, *, parser: Callable[[str], Any] = parse_partial_json
) -> dict:
"""
Parse a JSON string from a Markdown string.
Args:
json_string: The Markdown string.
Returns:
The parsed JSON object as a Python dictionary.
"""
try:
return _parse_json(json_string, parser=parser)
except json.JSONDecodeError:
# Try to find JSON string within triple backticks
match = re.search(r"```(json)?(.*)", json_string, re.DOTALL)
# If no match found, assume the entire string is a JSON string
if match is None:
json_str = json_string
else:
# If match found, use the content within the backticks
json_str = match.group(2)
return _parse_json(json_str, parser=parser)
def _parse_json(
json_str: str, *, parser: Callable[[str], Any] = parse_partial_json
) -> dict:
# Strip whitespace and newlines from the start and end
json_str = json_str.strip().strip("`")
# handle newlines and other special characters inside the returned value
json_str = _custom_parser(json_str)
# Parse the JSON string into a Python dictionary
return parser(json_str)
def parse_and_check_json_markdown(text: str, expected_keys: List[str]) -> dict:
"""
Parse a JSON string from a Markdown string and check that it
contains the expected keys.
Args:
text: The Markdown string.
expected_keys: The expected keys in the JSON string.
Returns:
The parsed JSON object as a Python dictionary.
"""
try:
json_obj = parse_json_markdown(text)
except json.JSONDecodeError as e:
raise OutputParserException(f"Got invalid JSON object. Error: {e}")
for key in expected_keys:
if key not in json_obj:
raise OutputParserException(
f"Got invalid return object. Expected key `{key}` "
f"to be present, but got {json_obj}"
)
return json_obj
class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
"""Parse the output of an LLM call to a JSON object.
@@ -267,3 +95,5 @@ class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
# For backwards compatibility
SimpleJsonOutputParser = JsonOutputParser
parse_partial_json = parse_partial_json
parse_and_check_json_markdown = parse_and_check_json_markdown

@@ -1,13 +1,89 @@
import copy
import json
from json import JSONDecodeError
from typing import Any, List, Type
from typing import Any, Dict, List, Optional, Type
from langchain_core.exceptions import OutputParserException
from langchain_core.messages import AIMessage, InvalidToolCall
from langchain_core.output_parsers import BaseCumulativeTransformOutputParser
from langchain_core.output_parsers.json import parse_partial_json
from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.pydantic_v1 import BaseModel, ValidationError
from langchain_core.utils.json import parse_partial_json
def parse_tool_call(
raw_tool_call: Dict[str, Any],
*,
partial: bool = False,
strict: bool = False,
return_id: bool = True,
) -> Optional[Dict[str, Any]]:
"""Parse a single tool call."""
if "function" not in raw_tool_call:
return None
if partial:
try:
function_args = parse_partial_json(
raw_tool_call["function"]["arguments"], strict=strict
)
except (JSONDecodeError, TypeError): # None args raise TypeError
return None
else:
try:
function_args = json.loads(
raw_tool_call["function"]["arguments"], strict=strict
)
except JSONDecodeError as e:
raise OutputParserException(
f"Function {raw_tool_call['function']['name']} arguments:\n\n"
f"{raw_tool_call['function']['arguments']}\n\nare not valid JSON. "
f"Received JSONDecodeError {e}"
)
parsed = {
"name": raw_tool_call["function"]["name"] or "",
"args": function_args or {},
}
if return_id:
parsed["id"] = raw_tool_call["id"]
return parsed
def make_invalid_tool_call(
raw_tool_call: Dict[str, Any],
error_msg: Optional[str],
) -> InvalidToolCall:
"""Create an InvalidToolCall from a raw tool call."""
return InvalidToolCall(
name=raw_tool_call["function"]["name"],
args=raw_tool_call["function"]["arguments"],
id=raw_tool_call.get("id"),
error=error_msg,
)
def parse_tool_calls(
raw_tool_calls: List[dict],
*,
partial: bool = False,
strict: bool = False,
return_id: bool = True,
) -> List[dict]:
"""Parse a list of tool calls."""
final_tools = []
exceptions = []
for tool_call in raw_tool_calls:
try:
parsed = parse_tool_call(
tool_call, partial=partial, strict=strict, return_id=return_id
)
if parsed:
final_tools.append(parsed)
except OutputParserException as e:
exceptions.append(str(e))
continue
if exceptions:
raise OutputParserException("\n\n".join(exceptions))
return final_tools
class JsonOutputToolsParser(BaseCumulativeTransformOutputParser[Any]):
@@ -40,47 +116,29 @@ class JsonOutputToolsParser(BaseCumulativeTransformOutputParser[Any]):
"This output parser can only be used with a chat generation."
)
message = generation.message
try:
tool_calls = copy.deepcopy(message.additional_kwargs["tool_calls"])
except KeyError:
return []
final_tools = []
exceptions = []
for tool_call in tool_calls:
if "function" not in tool_call:
continue
if partial:
try:
function_args = parse_partial_json(
tool_call["function"]["arguments"], strict=self.strict
)
except JSONDecodeError:
continue
else:
try:
function_args = json.loads(
tool_call["function"]["arguments"], strict=self.strict
)
except JSONDecodeError as e:
exceptions.append(
f"Function {tool_call['function']['name']} arguments:\n\n"
f"{tool_call['function']['arguments']}\n\nare not valid JSON. "
f"Received JSONDecodeError {e}"
)
continue
parsed = {
"type": tool_call["function"]["name"],
"args": function_args,
}
if self.return_id:
parsed["id"] = tool_call["id"]
final_tools.append(parsed)
if exceptions:
raise OutputParserException("\n\n".join(exceptions))
if isinstance(message, AIMessage) and message.tool_calls:
tool_calls = [dict(tc) for tc in message.tool_calls]
for tool_call in tool_calls:
if not self.return_id:
_ = tool_call.pop("id")
else:
try:
raw_tool_calls = copy.deepcopy(message.additional_kwargs["tool_calls"])
except KeyError:
return []
tool_calls = parse_tool_calls(
raw_tool_calls,
partial=partial,
strict=self.strict,
return_id=self.return_id,
)
# for backwards compatibility
for tc in tool_calls:
tc["type"] = tc.pop("name")
if self.first_tool_only:
return final_tools[0] if final_tools else None
return final_tools
return tool_calls[0] if tool_calls else None
return tool_calls
def parse(self, text: str) -> Any:
raise NotImplementedError()

@@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Any, Dict
from typing import Any, Dict, List, Optional
def merge_dicts(left: Dict[str, Any], right: Dict[str, Any]) -> Dict[str, Any]:
@@ -33,22 +33,7 @@ def merge_dicts(left: Dict[str, Any], right: Dict[str, Any]) -> Dict[str, Any]:
elif isinstance(merged[right_k], dict):
merged[right_k] = merge_dicts(merged[right_k], right_v)
elif isinstance(merged[right_k], list):
merged[right_k] = merged[right_k].copy()
for e in right_v:
if isinstance(e, dict) and "index" in e and isinstance(e["index"], int):
to_merge = [
i
for i, e_left in enumerate(merged[right_k])
if e_left["index"] == e["index"]
]
if to_merge:
merged[right_k][to_merge[0]] = merge_dicts(
merged[right_k][to_merge[0]], e
)
else:
merged[right_k] = merged[right_k] + [e]
else:
merged[right_k] = merged[right_k] + [e]
merged[right_k] = merge_lists(merged[right_k], right_v)
elif merged[right_k] == right_v:
continue
else:
@@ -57,3 +42,27 @@ def merge_dicts(left: Dict[str, Any], right: Dict[str, Any]) -> Dict[str, Any]:
f"value has unsupported type {type(merged[right_k])}."
)
return merged
def merge_lists(left: Optional[List], right: Optional[List]) -> Optional[List]:
"""Add two lists, handling None."""
if left is None and right is None:
return None
elif left is None or right is None:
return left or right
else:
merged = left.copy()
for e in right:
if isinstance(e, dict) and "index" in e and isinstance(e["index"], int):
to_merge = [
i
for i, e_left in enumerate(merged)
if e_left["index"] == e["index"]
]
if to_merge:
merged[to_merge[0]] = merge_dicts(merged[to_merge[0]], e)
else:
merged = merged + [e]
else:
merged = merged + [e]
return merged

@@ -0,0 +1,185 @@
from __future__ import annotations
import json
import re
from typing import Any, Callable, List
from langchain_core.exceptions import OutputParserException
def _replace_new_line(match: re.Match[str]) -> str:
value = match.group(2)
value = re.sub(r"\n", r"\\n", value)
value = re.sub(r"\r", r"\\r", value)
value = re.sub(r"\t", r"\\t", value)
value = re.sub(r'(?<!\\)"', r"\"", value)
return match.group(1) + value + match.group(3)
def _custom_parser(multiline_string: str) -> str:
"""
The LLM response for `action_input` may be a multiline
string containing unescaped newlines, tabs or quotes. This function
replaces those characters with their escaped counterparts.
(newlines in JSON must be double-escaped: `\\n`)
"""
if isinstance(multiline_string, (bytes, bytearray)):
multiline_string = multiline_string.decode()
multiline_string = re.sub(
r'("action_input"\:\s*")(.*?)(")',
_replace_new_line,
multiline_string,
flags=re.DOTALL,
)
return multiline_string
# Adapted from https://github.com/KillianLucas/open-interpreter/blob/5b6080fae1f8c68938a1e4fa8667e3744084ee21/interpreter/utils/parse_partial_json.py
# MIT License
def parse_partial_json(s: str, *, strict: bool = False) -> Any:
"""Parse a JSON string that may be missing closing braces.
Args:
s: The JSON string to parse.
strict: Whether to use strict parsing. Defaults to False.
Returns:
The parsed JSON object as a Python dictionary.
"""
# Attempt to parse the string as-is.
try:
return json.loads(s, strict=strict)
except json.JSONDecodeError:
pass
# Initialize variables.
new_s = ""
stack = []
is_inside_string = False
escaped = False
# Process each character in the string one at a time.
for char in s:
if is_inside_string:
if char == '"' and not escaped:
is_inside_string = False
elif char == "\n" and not escaped:
char = "\\n" # Replace the newline character with the escape sequence.
elif char == "\\":
escaped = not escaped
else:
escaped = False
else:
if char == '"':
is_inside_string = True
escaped = False
elif char == "{":
stack.append("}")
elif char == "[":
stack.append("]")
elif char == "}" or char == "]":
if stack and stack[-1] == char:
stack.pop()
else:
# Mismatched closing character; the input is malformed.
return None
# Append the processed character to the new string.
new_s += char
# If we're still inside a string at the end of processing,
# we need to close the string.
if is_inside_string:
new_s += '"'
# Try to parse mods of string until we succeed or run out of characters.
while new_s:
final_s = new_s
# Close any remaining open structures in the reverse
# order that they were opened.
for closing_char in reversed(stack):
final_s += closing_char
# Attempt to parse the modified string as JSON.
try:
return json.loads(final_s, strict=strict)
except json.JSONDecodeError:
# If we still can't parse the string as JSON,
# try removing the last character
new_s = new_s[:-1]
# If we got here, we ran out of characters to remove
# and still couldn't parse the string as JSON, so return the parse error
# for the original string.
return json.loads(s, strict=strict)
def parse_json_markdown(
json_string: str, *, parser: Callable[[str], Any] = parse_partial_json
) -> dict:
"""
Parse a JSON string from a Markdown string.
Args:
json_string: The Markdown string.
Returns:
The parsed JSON object as a Python dictionary.
"""
try:
return _parse_json(json_string, parser=parser)
except json.JSONDecodeError:
# Try to find JSON string within triple backticks
match = re.search(r"```(json)?(.*)", json_string, re.DOTALL)
# If no match found, assume the entire string is a JSON string
if match is None:
json_str = json_string
else:
# If match found, use the content within the backticks
json_str = match.group(2)
return _parse_json(json_str, parser=parser)
def _parse_json(
json_str: str, *, parser: Callable[[str], Any] = parse_partial_json
) -> dict:
# Strip whitespace and newlines from the start and end
json_str = json_str.strip().strip("`")
# handle newlines and other special characters inside the returned value
json_str = _custom_parser(json_str)
# Parse the JSON string into a Python dictionary
return parser(json_str)
def parse_and_check_json_markdown(text: str, expected_keys: List[str]) -> dict:
"""
Parse a JSON string from a Markdown string and check that it
contains the expected keys.
Args:
text: The Markdown string.
expected_keys: The expected keys in the JSON string.
Returns:
The parsed JSON object as a Python dictionary.
"""
try:
json_obj = parse_json_markdown(text)
except json.JSONDecodeError as e:
raise OutputParserException(f"Got invalid JSON object. Error: {e}")
for key in expected_keys:
if key not in json_obj:
raise OutputParserException(
f"Got invalid return object. Expected key `{key}` "
f"to be present, but got {json_obj}"
)
return json_obj

@@ -14,8 +14,11 @@ EXPECTED_ALL = [
"FunctionMessageChunk",
"HumanMessage",
"HumanMessageChunk",
"InvalidToolCall",
"SystemMessage",
"SystemMessageChunk",
"ToolCall",
"ToolCallChunk",
"ToolMessage",
"ToolMessageChunk",
"convert_to_messages",

@@ -5,11 +5,10 @@ import pytest
from langchain_core.output_parsers.json import (
SimpleJsonOutputParser,
parse_json_markdown,
parse_partial_json,
)
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.utils.function_calling import convert_to_openai_function
from langchain_core.utils.json import parse_json_markdown, parse_partial_json
GOOD_JSON = """```json
{

@@ -1,6 +1,6 @@
from typing import Any, AsyncIterator, Iterator, List
from langchain_core.messages import AIMessageChunk, BaseMessage
from langchain_core.messages import AIMessageChunk, BaseMessage, ToolCallChunk
from langchain_core.output_parsers.openai_tools import (
JsonOutputKeyToolsParser,
JsonOutputToolsParser,
@@ -300,6 +300,28 @@ STREAMED_MESSAGES: list = [
]
STREAMED_MESSAGES_WITH_TOOL_CALLS = []
for message in STREAMED_MESSAGES:
if message.additional_kwargs:
STREAMED_MESSAGES_WITH_TOOL_CALLS.append(
AIMessageChunk(
content=message.content,
additional_kwargs=message.additional_kwargs,
tool_call_chunks=[
ToolCallChunk(
name=chunk["function"].get("name"),
args=chunk["function"].get("arguments"),
id=chunk.get("id"),
index=chunk["index"],
)
for chunk in message.additional_kwargs["tool_calls"]
],
)
)
else:
STREAMED_MESSAGES_WITH_TOOL_CALLS.append(message)
EXPECTED_STREAMED_JSON = [
{},
{"names": ["suz"]},
@@ -330,101 +352,118 @@
]
def test_partial_json_output_parser() -> None:
def _get_iter(use_tool_calls: bool = False) -> Any:
if use_tool_calls:
list_to_iter = STREAMED_MESSAGES_WITH_TOOL_CALLS
else:
list_to_iter = STREAMED_MESSAGES
def input_iter(_: Any) -> Iterator[BaseMessage]:
for msg in STREAMED_MESSAGES:
for msg in list_to_iter:
yield msg
chain = input_iter | JsonOutputToolsParser()
return input_iter
actual = list(chain.stream(None))
expected: list = [[]] + [
[{"type": "NameCollector", "args": chunk}] for chunk in EXPECTED_STREAMED_JSON
]
assert actual == expected
def _get_aiter(use_tool_calls: bool = False) -> Any:
if use_tool_calls:
list_to_iter = STREAMED_MESSAGES_WITH_TOOL_CALLS
else:
list_to_iter = STREAMED_MESSAGES
async def test_partial_json_output_parser_async() -> None:
async def input_iter(_: Any) -> AsyncIterator[BaseMessage]:
for token in STREAMED_MESSAGES:
yield token
for msg in list_to_iter:
yield msg
chain = input_iter | JsonOutputToolsParser()
return input_iter
actual = [p async for p in chain.astream(None)]
expected: list = [[]] + [
[{"type": "NameCollector", "args": chunk}] for chunk in EXPECTED_STREAMED_JSON
]
assert actual == expected
def test_partial_json_output_parser() -> None:
for use_tool_calls in [False, True]:
input_iter = _get_iter(use_tool_calls)
chain = input_iter | JsonOutputToolsParser()
actual = list(chain.stream(None))
expected: list = [[]] + [
[{"type": "NameCollector", "args": chunk}]
for chunk in EXPECTED_STREAMED_JSON
]
assert actual == expected
async def test_partial_json_output_parser_async() -> None:
for use_tool_calls in [False, True]:
input_iter = _get_aiter(use_tool_calls)
chain = input_iter | JsonOutputToolsParser()
actual = [p async for p in chain.astream(None)]
expected: list = [[]] + [
[{"type": "NameCollector", "args": chunk}]
for chunk in EXPECTED_STREAMED_JSON
]
assert actual == expected
def test_partial_json_output_parser_return_id() -> None:
def input_iter(_: Any) -> Iterator[BaseMessage]:
for msg in STREAMED_MESSAGES:
yield msg
chain = input_iter | JsonOutputToolsParser(return_id=True)
def test_partial_json_output_parser_return_id() -> None:
for use_tool_calls in [False, True]:
input_iter = _get_iter(use_tool_calls)
chain = input_iter | JsonOutputToolsParser(return_id=True)
actual = list(chain.stream(None))
expected: list = [[]] + [
[
{
"type": "NameCollector",
"args": chunk,
"id": "call_OwL7f5PEPJTYzw9sQlNJtCZl",
}
actual = list(chain.stream(None))
expected: list = [[]] + [
[
{
"type": "NameCollector",
"args": chunk,
"id": "call_OwL7f5PEPJTYzw9sQlNJtCZl",
}
]
for chunk in EXPECTED_STREAMED_JSON
]
for chunk in EXPECTED_STREAMED_JSON
]
assert actual == expected
assert actual == expected
def test_partial_json_output_key_parser() -> None:
def input_iter(_: Any) -> Iterator[BaseMessage]:
for msg in STREAMED_MESSAGES:
yield msg
chain = input_iter | JsonOutputKeyToolsParser(key_name="NameCollector")
for use_tool_calls in [False, True]:
input_iter = _get_iter(use_tool_calls)
chain = input_iter | JsonOutputKeyToolsParser(key_name="NameCollector")
actual = list(chain.stream(None))
expected: list = [[]] + [[chunk] for chunk in EXPECTED_STREAMED_JSON]
assert actual == expected
actual = list(chain.stream(None))
expected: list = [[]] + [[chunk] for chunk in EXPECTED_STREAMED_JSON]
assert actual == expected
async def test_partial_json_output_parser_key_async() -> None:
async def input_iter(_: Any) -> AsyncIterator[BaseMessage]:
for token in STREAMED_MESSAGES:
yield token
for use_tool_calls in [False, True]:
input_iter = _get_aiter(use_tool_calls)
chain = input_iter | JsonOutputKeyToolsParser(key_name="NameCollector")
chain = input_iter | JsonOutputKeyToolsParser(key_name="NameCollector")
actual = [p async for p in chain.astream(None)]
expected: list = [[]] + [[chunk] for chunk in EXPECTED_STREAMED_JSON]
assert actual == expected
actual = [p async for p in chain.astream(None)]
expected: list = [[]] + [[chunk] for chunk in EXPECTED_STREAMED_JSON]
assert actual == expected
def test_partial_json_output_key_parser_first_only() -> None:
def input_iter(_: Any) -> Iterator[BaseMessage]:
for msg in STREAMED_MESSAGES:
yield msg
for use_tool_calls in [False, True]:
input_iter = _get_iter(use_tool_calls)
chain = input_iter | JsonOutputKeyToolsParser(
key_name="NameCollector", first_tool_only=True
)
chain = input_iter | JsonOutputKeyToolsParser(
key_name="NameCollector", first_tool_only=True
)
assert list(chain.stream(None)) == EXPECTED_STREAMED_JSON
assert list(chain.stream(None)) == EXPECTED_STREAMED_JSON
async def test_partial_json_output_parser_key_async_first_only() -> None:
async def input_iter(_: Any) -> AsyncIterator[BaseMessage]:
for token in STREAMED_MESSAGES:
yield token
for use_tool_calls in [False, True]:
input_iter = _get_aiter(use_tool_calls)
chain = input_iter | JsonOutputKeyToolsParser(
key_name="NameCollector", first_tool_only=True
)
chain = input_iter | JsonOutputKeyToolsParser(
key_name="NameCollector", first_tool_only=True
)
assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON
assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON
class Person(BaseModel):
@@ -458,26 +497,24 @@ EXPECTED_STREAMED_PYDANTIC = [
def test_partial_pydantic_output_parser() -> None:
def input_iter(_: Any) -> Iterator[BaseMessage]:
for msg in STREAMED_MESSAGES:
yield msg
for use_tool_calls in [False, True]:
input_iter = _get_iter(use_tool_calls)
chain = input_iter | PydanticToolsParser(
tools=[NameCollector], first_tool_only=True
)
chain = input_iter | PydanticToolsParser(
tools=[NameCollector], first_tool_only=True
)
actual = list(chain.stream(None))
assert actual == EXPECTED_STREAMED_PYDANTIC
actual = list(chain.stream(None))
assert actual == EXPECTED_STREAMED_PYDANTIC
async def test_partial_pydantic_output_parser_async() -> None:
async def input_iter(_: Any) -> AsyncIterator[BaseMessage]:
for token in STREAMED_MESSAGES:
yield token
for use_tool_calls in [False, True]:
input_iter = _get_aiter(use_tool_calls)
chain = input_iter | PydanticToolsParser(
tools=[NameCollector], first_tool_only=True
)
chain = input_iter | PydanticToolsParser(
tools=[NameCollector], first_tool_only=True
)
actual = [p async for p in chain.astream(None)]
assert actual == EXPECTED_STREAMED_PYDANTIC
actual = [p async for p in chain.astream(None)]
assert actual == EXPECTED_STREAMED_PYDANTIC

@@ -5299,6 +5299,15 @@
'title': 'Id',
'type': 'string',
}),
'invalid_tool_calls': dict({
'default': list([
]),
'items': dict({
'$ref': '#/definitions/InvalidToolCall',
}),
'title': 'Invalid Tool Calls',
'type': 'array',
}),
'name': dict({
'title': 'Name',
'type': 'string',
@@ -5307,6 +5316,15 @@
'title': 'Response Metadata',
'type': 'object',
}),
'tool_calls': dict({
'default': list([
]),
'items': dict({
'$ref': '#/definitions/ToolCall',
}),
'title': 'Tool Calls',
'type': 'array',
}),
'type': dict({
'default': 'ai',
'enum': list([
@@ -5545,6 +5563,34 @@
'title': 'HumanMessage',
'type': 'object',
}),
'InvalidToolCall': dict({
'properties': dict({
'args': dict({
'title': 'Args',
'type': 'string',
}),
'error': dict({
'title': 'Error',
'type': 'string',
}),
'id': dict({
'title': 'Id',
'type': 'string',
}),
'name': dict({
'title': 'Name',
'type': 'string',
}),
}),
'required': list([
'name',
'args',
'id',
'error',
]),
'title': 'InvalidToolCall',
'type': 'object',
}),
'StringPromptValue': dict({
'description': 'String prompt value.',
'properties': dict({
@@ -5625,6 +5671,29 @@
'title': 'SystemMessage',
'type': 'object',
}),
'ToolCall': dict({
'properties': dict({
'args': dict({
'title': 'Args',
'type': 'object',
}),
'id': dict({
'title': 'Id',
'type': 'string',
}),
'name': dict({
'title': 'Name',
'type': 'string',
}),
}),
'required': list([
'name',
'args',
'id',
]),
'title': 'ToolCall',
'type': 'object',
}),
'ToolMessage': dict({
'description': 'Message for passing the result of executing a tool back to a model.',
'properties': dict({
@@ -5765,6 +5834,15 @@
'title': 'Id',
'type': 'string',
}),
'invalid_tool_calls': dict({
'default': list([
]),
'items': dict({
'$ref': '#/definitions/InvalidToolCall',
}),
'title': 'Invalid Tool Calls',
'type': 'array',
}),
'name': dict({
'title': 'Name',
'type': 'string',
@@ -5773,6 +5851,15 @@
'title': 'Response Metadata',
'type': 'object',
}),
'tool_calls': dict({
'default': list([
]),
'items': dict({
'$ref': '#/definitions/ToolCall',
}),
'title': 'Tool Calls',
'type': 'array',
}),
'type': dict({
'default': 'ai',
'enum': list([
@@ -6011,6 +6098,34 @@
'title': 'HumanMessage',
'type': 'object',
}),
'InvalidToolCall': dict({
'properties': dict({
'args': dict({
'title': 'Args',
'type': 'string',
}),
'error': dict({
'title': 'Error',
'type': 'string',
}),
'id': dict({
'title': 'Id',
'type': 'string',
}),
'name': dict({
'title': 'Name',
'type': 'string',
}),
}),
'required': list([
'name',
'args',
'id',
'error',
]),
'title': 'InvalidToolCall',
'type': 'object',
}),
'StringPromptValue': dict({
'description': 'String prompt value.',
'properties': dict({
@@ -6091,6 +6206,29 @@
'title': 'SystemMessage',
'type': 'object',
}),
'ToolCall': dict({
'properties': dict({
'args': dict({
'title': 'Args',
'type': 'object',
}),
'id': dict({
'title': 'Id',
'type': 'string',
}),
'name': dict({
'title': 'Name',
'type': 'string',
}),
}),
'required': list([
'name',
'args',
'id',
]),
'title': 'ToolCall',
'type': 'object',
}),
'ToolMessage': dict({
'description': 'Message for passing the result of executing a tool back to a model.',
'properties': dict({
@@ -6215,6 +6353,15 @@
'title': 'Id',
'type': 'string',
}),
'invalid_tool_calls': dict({
'default': list([
]),
'items': dict({
'$ref': '#/definitions/InvalidToolCall',
}),
'title': 'Invalid Tool Calls',
'type': 'array',
}),
'name': dict({
'title': 'Name',
'type': 'string',
@@ -6223,6 +6370,15 @@
'title': 'Response Metadata',
'type': 'object',
}),
'tool_calls': dict({
'default': list([
]),
'items': dict({
'$ref': '#/definitions/ToolCall',
}),
'title': 'Tool Calls',
'type': 'array',
}),
'type': dict({
'default': 'ai',
'enum': list([
@@ -6414,6 +6570,34 @@
'title': 'HumanMessage',
'type': 'object',
}),
'InvalidToolCall': dict({
'properties': dict({
'args': dict({
'title': 'Args',
'type': 'string',
}),
'error': dict({
'title': 'Error',
'type': 'string',
}),
'id': dict({
'title': 'Id',
'type': 'string',
}),
'name': dict({
'title': 'Name',
'type': 'string',
}),
}),
'required': list([
'name',
'args',
'id',
'error',
]),
'title': 'InvalidToolCall',
'type': 'object',
}),
'SystemMessage': dict({
'description': '''
Message for priming AI behavior, usually passed in as the first of a sequence
@@ -6472,6 +6656,29 @@
'title': 'SystemMessage',
'type': 'object',
}),
'ToolCall': dict({
'properties': dict({
'args': dict({
'title': 'Args',
'type': 'object',
}),
'id': dict({
'title': 'Id',
'type': 'string',
}),
'name': dict({
'title': 'Name',
'type': 'string',
}),
}),
'required': list([
'name',
'args',
'id',
]),
'title': 'ToolCall',
'type': 'object',
}),
'ToolMessage': dict({
'description': 'Message for passing the result of executing a tool back to a model.',
'properties': dict({
@@ -6584,6 +6791,15 @@
'title': 'Id',
'type': 'string',
}),
'invalid_tool_calls': dict({
'default': list([
]),
'items': dict({
'$ref': '#/definitions/InvalidToolCall',
}),
'title': 'Invalid Tool Calls',
'type': 'array',
}),
'name': dict({
'title': 'Name',
'type': 'string',
@@ -6592,6 +6808,15 @@
'title': 'Response Metadata',
'type': 'object',
}),
'tool_calls': dict({
'default': list([
]),
'items': dict({
'$ref': '#/definitions/ToolCall',
}),
'title': 'Tool Calls',
'type': 'array',
}),
'type': dict({
'default': 'ai',
'enum': list([
@@ -6830,6 +7055,34 @@
'title': 'HumanMessage',
'type': 'object',
}),
'InvalidToolCall': dict({
'properties': dict({
'args': dict({
'title': 'Args',
'type': 'string',
}),
'error': dict({
'title': 'Error',
'type': 'string',
}),
'id': dict({
'title': 'Id',
'type': 'string',
}),
'name': dict({
'title': 'Name',
'type': 'string',
}),
}),
'required': list([
'name',
'args',
'id',
'error',
]),
'title': 'InvalidToolCall',
'type': 'object',
}),
'StringPromptValue': dict({
'description': 'String prompt value.',
'properties': dict({
@@ -6910,6 +7163,29 @@
'title': 'SystemMessage',
'type': 'object',
}),
'ToolCall': dict({
'properties': dict({
'args': dict({
'title': 'Args',
'type': 'object',
}),
'id': dict({
'title': 'Id',
'type': 'string',
}),
'name': dict({
'title': 'Name',
'type': 'string',
}),
}),
'required': list([
'name',
'args',
'id',
]),
'title': 'ToolCall',
'type': 'object',
}),
'ToolMessage': dict({
'description': 'Message for passing the result of executing a tool back to a model.',
'properties': dict({
@@ -7022,6 +7298,15 @@
'title': 'Id',
'type': 'string',
}),
'invalid_tool_calls': dict({
'default': list([
]),
'items': dict({
'$ref': '#/definitions/InvalidToolCall',
}),
'title': 'Invalid Tool Calls',
'type': 'array',
}),
'name': dict({
'title': 'Name',
'type': 'string',
@@ -7030,6 +7315,15 @@
'title': 'Response Metadata',
'type': 'object',
}),
'tool_calls': dict({
'default': list([
]),
'items': dict({
'$ref': '#/definitions/ToolCall',
}),
'title': 'Tool Calls',
'type': 'array',
}),
'type': dict({
'default': 'ai',
'enum': list([
@@ -7268,6 +7562,34 @@
'title': 'HumanMessage',
'type': 'object',
}),
'InvalidToolCall': dict({
'properties': dict({
'args': dict({
'title': 'Args',
'type': 'string',
}),
'error': dict({
'title': 'Error',
'type': 'string',
}),
'id': dict({
'title': 'Id',
'type': 'string',
}),
'name': dict({
'title': 'Name',
'type': 'string',
}),
}),
'required': list([
'name',
'args',
'id',
'error',
]),
'title': 'InvalidToolCall',
'type': 'object',
}),
'StringPromptValue': dict({
'description': 'String prompt value.',
'properties': dict({
@@ -7348,6 +7670,29 @@
'title': 'SystemMessage',
'type': 'object',
}),
'ToolCall': dict({
'properties': dict({
'args': dict({
'title': 'Args',
'type': 'object',
}),
'id': dict({
'title': 'Id',
'type': 'string',
}),
'name': dict({
'title': 'Name',
'type': 'string',
}),
}),
'required': list([
'name',
'args',
'id',
]),
'title': 'ToolCall',
'type': 'object',
}),
'ToolMessage': dict({
'description': 'Message for passing the result of executing a tool back to a model.',
'properties': dict({
@ -7452,6 +7797,15 @@
'title': 'Id',
'type': 'string',
}),
'invalid_tool_calls': dict({
'default': list([
]),
'items': dict({
'$ref': '#/definitions/InvalidToolCall',
}),
'title': 'Invalid Tool Calls',
'type': 'array',
}),
'name': dict({
'title': 'Name',
'type': 'string',
@ -7460,6 +7814,15 @@
'title': 'Response Metadata',
'type': 'object',
}),
'tool_calls': dict({
'default': list([
]),
'items': dict({
'$ref': '#/definitions/ToolCall',
}),
'title': 'Tool Calls',
'type': 'array',
}),
'type': dict({
'default': 'ai',
'enum': list([
@ -7698,6 +8061,34 @@
'title': 'HumanMessage',
'type': 'object',
}),
'InvalidToolCall': dict({
'properties': dict({
'args': dict({
'title': 'Args',
'type': 'string',
}),
'error': dict({
'title': 'Error',
'type': 'string',
}),
'id': dict({
'title': 'Id',
'type': 'string',
}),
'name': dict({
'title': 'Name',
'type': 'string',
}),
}),
'required': list([
'name',
'args',
'id',
'error',
]),
'title': 'InvalidToolCall',
'type': 'object',
}),
'PromptTemplateOutput': dict({
'anyOf': list([
dict({
@ -7789,6 +8180,29 @@
'title': 'SystemMessage',
'type': 'object',
}),
'ToolCall': dict({
'properties': dict({
'args': dict({
'title': 'Args',
'type': 'object',
}),
'id': dict({
'title': 'Id',
'type': 'string',
}),
'name': dict({
'title': 'Name',
'type': 'string',
}),
}),
'required': list([
'name',
'args',
'id',
]),
'title': 'ToolCall',
'type': 'object',
}),
'ToolMessage': dict({
'description': 'Message for passing the result of executing a tool back to a model.',
'properties': dict({
@ -7920,6 +8334,15 @@
'title': 'Id',
'type': 'string',
}),
'invalid_tool_calls': dict({
'default': list([
]),
'items': dict({
'$ref': '#/definitions/InvalidToolCall',
}),
'title': 'Invalid Tool Calls',
'type': 'array',
}),
'name': dict({
'title': 'Name',
'type': 'string',
@ -7928,6 +8351,15 @@
'title': 'Response Metadata',
'type': 'object',
}),
'tool_calls': dict({
'default': list([
]),
'items': dict({
'$ref': '#/definitions/ToolCall',
}),
'title': 'Tool Calls',
'type': 'array',
}),
'type': dict({
'default': 'ai',
'enum': list([
@ -8119,6 +8551,34 @@
'title': 'HumanMessage',
'type': 'object',
}),
'InvalidToolCall': dict({
'properties': dict({
'args': dict({
'title': 'Args',
'type': 'string',
}),
'error': dict({
'title': 'Error',
'type': 'string',
}),
'id': dict({
'title': 'Id',
'type': 'string',
}),
'name': dict({
'title': 'Name',
'type': 'string',
}),
}),
'required': list([
'name',
'args',
'id',
'error',
]),
'title': 'InvalidToolCall',
'type': 'object',
}),
'SystemMessage': dict({
'description': '''
Message for priming AI behavior, usually passed in as the first of a sequence
@ -8177,6 +8637,29 @@
'title': 'SystemMessage',
'type': 'object',
}),
'ToolCall': dict({
'properties': dict({
'args': dict({
'title': 'Args',
'type': 'object',
}),
'id': dict({
'title': 'Id',
'type': 'string',
}),
'name': dict({
'title': 'Name',
'type': 'string',
}),
}),
'required': list([
'name',
'args',
'id',
]),
'title': 'ToolCall',
'type': 'object',
}),
'ToolMessage': dict({
'description': 'Message for passing the result of executing a tool back to a model.',
'properties': dict({

@ -206,6 +206,27 @@ def test_graph_sequence_map(snapshot: SnapshotAssertion) -> None:
{"$ref": "#/definitions/ToolMessage"},
],
"definitions": {
"ToolCall": {
"title": "ToolCall",
"type": "object",
"properties": {
"name": {"title": "Name", "type": "string"},
"args": {"title": "Args", "type": "object"},
"id": {"title": "Id", "type": "string"},
},
"required": ["name", "args", "id"],
},
"InvalidToolCall": {
"title": "InvalidToolCall",
"type": "object",
"properties": {
"name": {"title": "Name", "type": "string"},
"args": {"title": "Args", "type": "string"},
"id": {"title": "Id", "type": "string"},
"error": {"title": "Error", "type": "string"},
},
"required": ["name", "args", "id", "error"],
},
"AIMessage": {
"title": "AIMessage",
"description": "Message from an AI.",
@ -240,13 +261,25 @@ def test_graph_sequence_map(snapshot: SnapshotAssertion) -> None:
"enum": ["ai"],
"type": "string",
},
"id": {"title": "Id", "type": "string"},
"name": {"title": "Name", "type": "string"},
"id": {"title": "Id", "type": "string"},
"example": {
"title": "Example",
"default": False,
"type": "boolean",
},
"tool_calls": {
"title": "Tool Calls",
"default": [],
"type": "array",
"items": {"$ref": "#/definitions/ToolCall"},
},
"invalid_tool_calls": {
"title": "Invalid Tool Calls",
"default": [],
"type": "array",
"items": {"$ref": "#/definitions/InvalidToolCall"},
},
},
"required": ["content"],
},
@ -284,8 +317,8 @@ def test_graph_sequence_map(snapshot: SnapshotAssertion) -> None:
"enum": ["human"],
"type": "string",
},
"id": {"title": "Id", "type": "string"},
"name": {"title": "Name", "type": "string"},
"id": {"title": "Id", "type": "string"},
"example": {
"title": "Example",
"default": False,
@ -328,8 +361,8 @@ def test_graph_sequence_map(snapshot: SnapshotAssertion) -> None:
"enum": ["chat"],
"type": "string",
},
"id": {"title": "Id", "type": "string"},
"name": {"title": "Name", "type": "string"},
"id": {"title": "Id", "type": "string"},
"role": {"title": "Role", "type": "string"},
},
"required": ["content", "role"],
@ -368,8 +401,8 @@ def test_graph_sequence_map(snapshot: SnapshotAssertion) -> None:
"enum": ["system"],
"type": "string",
},
"id": {"title": "Id", "type": "string"},
"name": {"title": "Name", "type": "string"},
"id": {"title": "Id", "type": "string"},
},
"required": ["content"],
},
@ -407,8 +440,8 @@ def test_graph_sequence_map(snapshot: SnapshotAssertion) -> None:
"enum": ["function"],
"type": "string",
},
"id": {"title": "Id", "type": "string"},
"name": {"title": "Name", "type": "string"},
"id": {"title": "Id", "type": "string"},
},
"required": ["content", "name"],
},
@ -446,8 +479,8 @@ def test_graph_sequence_map(snapshot: SnapshotAssertion) -> None:
"enum": ["tool"],
"type": "string",
},
"id": {"title": "Id", "type": "string"},
"name": {"title": "Name", "type": "string"},
"id": {"title": "Id", "type": "string"},
"tool_call_id": {
"title": "Tool Call Id",
"type": "string",

@ -357,6 +357,27 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
}
},
"definitions": {
"ToolCall": {
"title": "ToolCall",
"type": "object",
"properties": {
"name": {"title": "Name", "type": "string"},
"args": {"title": "Args", "type": "object"},
"id": {"title": "Id", "type": "string"},
},
"required": ["name", "args", "id"],
},
"InvalidToolCall": {
"title": "InvalidToolCall",
"type": "object",
"properties": {
"name": {"title": "Name", "type": "string"},
"args": {"title": "Args", "type": "string"},
"id": {"title": "Id", "type": "string"},
"error": {"title": "Error", "type": "string"},
},
"required": ["name", "args", "id", "error"],
},
"AIMessage": {
"title": "AIMessage",
"description": "Message from an AI.",
@ -388,13 +409,25 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
"enum": ["ai"],
"type": "string",
},
"id": {"title": "Id", "type": "string"},
"name": {"title": "Name", "type": "string"},
"id": {"title": "Id", "type": "string"},
"example": {
"title": "Example",
"default": False,
"type": "boolean",
},
"tool_calls": {
"title": "Tool Calls",
"default": [],
"type": "array",
"items": {"$ref": "#/definitions/ToolCall"},
},
"invalid_tool_calls": {
"title": "Invalid Tool Calls",
"default": [],
"type": "array",
"items": {"$ref": "#/definitions/InvalidToolCall"},
},
},
"required": ["content"],
},
@ -429,8 +462,8 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
"enum": ["human"],
"type": "string",
},
"id": {"title": "Id", "type": "string"},
"name": {"title": "Name", "type": "string"},
"id": {"title": "Id", "type": "string"},
"example": {
"title": "Example",
"default": False,
@ -470,8 +503,8 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
"enum": ["chat"],
"type": "string",
},
"id": {"title": "Id", "type": "string"},
"name": {"title": "Name", "type": "string"},
"id": {"title": "Id", "type": "string"},
"role": {"title": "Role", "type": "string"},
},
"required": ["content", "role"],
@ -507,8 +540,8 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
"enum": ["system"],
"type": "string",
},
"id": {"title": "Id", "type": "string"},
"name": {"title": "Name", "type": "string"},
"id": {"title": "Id", "type": "string"},
},
"required": ["content"],
},
@ -543,8 +576,8 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
"enum": ["function"],
"type": "string",
},
"id": {"title": "Id", "type": "string"},
"name": {"title": "Name", "type": "string"},
"id": {"title": "Id", "type": "string"},
},
"required": ["content", "name"],
},
@ -579,8 +612,8 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
"enum": ["tool"],
"type": "string",
},
"id": {"title": "Id", "type": "string"},
"name": {"title": "Name", "type": "string"},
"id": {"title": "Id", "type": "string"},
"tool_call_id": {"title": "Tool Call Id", "type": "string"},
},
"required": ["content", "tool_call_id"],

@ -13,6 +13,8 @@ from langchain_core.messages import (
HumanMessage,
HumanMessageChunk,
SystemMessage,
ToolCall,
ToolCallChunk,
ToolMessage,
convert_to_messages,
get_buffer_string,
@ -20,6 +22,7 @@ from langchain_core.messages import (
messages_from_dict,
messages_to_dict,
)
from langchain_core.utils._merge import merge_lists
def test_message_chunks() -> None:
@ -68,6 +71,55 @@ def test_message_chunks() -> None:
)
), "MessageChunk + MessageChunk should be a MessageChunk with merged additional_kwargs" # noqa: E501
# Test tool calls
assert (
AIMessageChunk(
content="",
tool_call_chunks=[ToolCallChunk(name="tool1", args="", id="1", index=0)],
)
+ AIMessageChunk(
content="",
tool_call_chunks=[
ToolCallChunk(name=None, args='{"arg1": "val', id=None, index=0)
],
)
+ AIMessageChunk(
content="",
tool_call_chunks=[ToolCallChunk(name=None, args='ue}"', id=None, index=0)],
)
) == AIMessageChunk(
content="",
tool_call_chunks=[
ToolCallChunk(name="tool1", args='{"arg1": "value}"', id="1", index=0)
],
)
assert (
AIMessageChunk(
content="",
tool_call_chunks=[ToolCallChunk(name="tool1", args="", id="1", index=0)],
)
+ AIMessageChunk(
content="",
tool_call_chunks=[ToolCallChunk(name="tool1", args="a", id=None, index=1)],
)
# Don't merge if `index` field does not match.
) == AIMessageChunk(
content="",
tool_call_chunks=[
ToolCallChunk(name="tool1", args="", id="1", index=0),
ToolCallChunk(name="tool1", args="a", id=None, index=1),
],
)
ai_msg_chunk = AIMessageChunk(content="")
tool_calls_msg_chunk = AIMessageChunk(
content="",
tool_call_chunks=[ToolCallChunk(name="tool1", args="a", id=None, index=1)],
)
assert ai_msg_chunk + tool_calls_msg_chunk == tool_calls_msg_chunk
assert tool_calls_msg_chunk + ai_msg_chunk == tool_calls_msg_chunk
def test_chat_message_chunks() -> None:
assert ChatMessageChunk(role="User", content="I am", id="ai4") + ChatMessageChunk(
@ -128,6 +180,7 @@ class TestGetBufferString(unittest.TestCase):
self.func_msg = FunctionMessage(name="func", content="function")
self.tool_msg = ToolMessage(tool_call_id="tool_id", content="tool")
self.chat_msg = ChatMessage(role="Chat", content="chat")
self.tool_calls_msg = AIMessage(content="tool")
def test_empty_input(self) -> None:
self.assertEqual(get_buffer_string([]), "")
@ -163,6 +216,7 @@ class TestGetBufferString(unittest.TestCase):
self.func_msg,
self.tool_msg,
self.chat_msg,
self.tool_calls_msg,
]
expected_output = "\n".join(
[
@ -172,6 +226,7 @@ class TestGetBufferString(unittest.TestCase):
"Function: function",
"Tool: tool",
"Chat: chat",
"AI: tool",
]
)
self.assertEqual(
@ -192,6 +247,19 @@ def test_multiple_msg() -> None:
]
assert messages_from_dict(messages_to_dict(msgs)) == msgs
# Test with tool calls
msgs = [
AIMessage(
content="",
tool_calls=[ToolCall(name="a", args={"b": 1}, id=None)],
),
AIMessage(
content="",
tool_calls=[ToolCall(name="c", args={"c": 2}, id=None)],
),
]
assert messages_from_dict(messages_to_dict(msgs)) == msgs
def test_multiple_msg_with_name() -> None:
human_msg = HumanMessage(
@ -222,6 +290,30 @@ def test_message_chunk_to_message() -> None:
FunctionMessageChunk(name="hello", content="I am")
) == FunctionMessage(name="hello", content="I am")
chunk = AIMessageChunk(
content="I am",
tool_call_chunks=[
ToolCallChunk(name="tool1", args='{"a": 1}', id="1", index=0),
ToolCallChunk(name="tool2", args='{"b": ', id="2", index=0),
ToolCallChunk(name="tool3", args=None, id="3", index=0),
ToolCallChunk(name="tool4", args="abc", id="4", index=0),
],
)
expected = AIMessage(
content="I am",
tool_calls=[
{"name": "tool1", "args": {"a": 1}, "id": "1"},
{"name": "tool2", "args": {}, "id": "2"},
],
invalid_tool_calls=[
{"name": "tool3", "args": None, "id": "3", "error": "Malformed args."},
{"name": "tool4", "args": "abc", "id": "4", "error": "Malformed args."},
],
)
assert message_chunk_to_message(chunk) == expected
assert AIMessage(**expected.dict()) == expected
assert AIMessageChunk(**chunk.dict()) == chunk
def test_tool_calls_merge() -> None:
chunks: List[dict] = [
@ -542,3 +634,35 @@ def test_message_name_chat(MessageClass: Type) -> None:
msg3 = MessageClass(content="foo", role="user")
assert msg3.name is None
def test_merge_tool_calls() -> None:
tool_call_1 = ToolCallChunk(name="tool1", args="", id="1", index=0)
tool_call_2 = ToolCallChunk(name=None, args='{"arg1": "val', id=None, index=0)
tool_call_3 = ToolCallChunk(name=None, args='ue}"', id=None, index=0)
merged = merge_lists([tool_call_1], [tool_call_2])
assert merged is not None
assert merged == [{"name": "tool1", "args": '{"arg1": "val', "id": "1", "index": 0}]
merged = merge_lists(merged, [tool_call_3])
assert merged is not None
assert merged == [
{"name": "tool1", "args": '{"arg1": "value}"', "id": "1", "index": 0}
]
left = ToolCallChunk(name="tool1", args='{"arg1": "value1"}', id="1", index=None)
right = ToolCallChunk(name="tool2", args='{"arg2": "value2"}', id="1", index=None)
merged = merge_lists([left], [right])
assert merged is not None
assert len(merged) == 2
left = ToolCallChunk(name="tool1", args='{"arg1": "value1"}', id=None, index=None)
right = ToolCallChunk(name="tool1", args='{"arg2": "value2"}', id=None, index=None)
merged = merge_lists([left], [right])
assert merged is not None
assert len(merged) == 2
left = ToolCallChunk(name="tool1", args='{"arg1": "value1"}', id="1", index=0)
right = ToolCallChunk(name="tool2", args='{"arg2": "value2"}', id=None, index=1)
merged = merge_lists([left], [right])
assert merged is not None
assert len(merged) == 2

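A minimal sketch of the chunk-merge semantics the tests above exercise, using only the `AIMessageChunk` and `ToolCallChunk` APIs added in this PR (values are hypothetical): chunks sharing an `index` merge into a single tool call chunk, concatenating their partial-JSON `args` strings.

```python
from langchain_core.messages import AIMessageChunk, ToolCallChunk

# First delta carries the tool name and id; later deltas stream the args.
first = AIMessageChunk(
    content="",
    tool_call_chunks=[ToolCallChunk(name="multiply", args="", id="1", index=0)],
)
second = AIMessageChunk(
    content="",
    tool_call_chunks=[ToolCallChunk(name=None, args='{"x": 2, ', id=None, index=0)],
)
third = AIMessageChunk(
    content="",
    tool_call_chunks=[ToolCallChunk(name=None, args='"y": 3}', id=None, index=0)],
)

merged = first + second + third
# Equal `index` values collapse into one chunk with concatenated args.
assert merged == AIMessageChunk(
    content="",
    tool_call_chunks=[
        ToolCallChunk(name="multiply", args='{"x": 2, "y": 3}', id="1", index=0)
    ],
)
```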
@ -1,5 +1,7 @@
from langchain_core.output_parsers.json import (
SimpleJsonOutputParser,
)
from langchain_core.utils.json import (
parse_and_check_json_markdown,
parse_json_markdown,
parse_partial_json,

@ -1,3 +1,4 @@
import json
import os
import re
import warnings
@ -54,7 +55,7 @@ from langchain_core.utils import (
)
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_anthropic.output_parsers import ToolsOutputParser
from langchain_anthropic.output_parsers import ToolsOutputParser, extract_tool_calls
_message_type_lookups = {
"human": "user",
@ -347,7 +348,24 @@ class ChatAnthropic(BaseChatModel):
result = self._generate(
messages, stop=stop, run_manager=run_manager, **kwargs
)
yield cast(ChatGenerationChunk, result.generations[0])
message = result.generations[0].message
if isinstance(message, AIMessage) and message.tool_calls:
tool_call_chunks = [
{
"name": tool_call["name"],
"args": json.dumps(tool_call["args"]),
"id": tool_call["id"],
"index": idx,
}
for idx, tool_call in enumerate(message.tool_calls)
]
message_chunk = AIMessageChunk(
content=message.content,
tool_call_chunks=tool_call_chunks,
)
yield ChatGenerationChunk(message=message_chunk)
else:
yield cast(ChatGenerationChunk, result.generations[0])
return
with self._client.messages.stream(**params) as stream:
for text in stream.text_stream:
@ -369,7 +387,24 @@ class ChatAnthropic(BaseChatModel):
result = await self._agenerate(
messages, stop=stop, run_manager=run_manager, **kwargs
)
yield cast(ChatGenerationChunk, result.generations[0])
message = result.generations[0].message
if isinstance(message, AIMessage) and message.tool_calls:
tool_call_chunks = [
{
"name": tool_call["name"],
"args": json.dumps(tool_call["args"]),
"id": tool_call["id"],
"index": idx,
}
for idx, tool_call in enumerate(message.tool_calls)
]
message_chunk = AIMessageChunk(
content=message.content,
tool_call_chunks=tool_call_chunks,
)
yield ChatGenerationChunk(message=message_chunk)
else:
yield cast(ChatGenerationChunk, result.generations[0])
return
async with self._async_client.messages.stream(**params) as stream:
async for text in stream.text_stream:
@ -386,6 +421,12 @@ class ChatAnthropic(BaseChatModel):
}
if len(content) == 1 and content[0]["type"] == "text":
msg = AIMessage(content=content[0]["text"])
elif any(block["type"] == "tool_use" for block in content):
tool_calls = extract_tool_calls(content)
msg = AIMessage(
content=content,
tool_calls=tool_calls,
)
else:
msg = AIMessage(content=content)
return ChatResult(

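The Anthropic changes above fall back to `_generate` when tools are bound, since the tools beta did not support streaming, and re-wrap the parsed tool calls as chunks. A sketch of just that conversion, with hypothetical values; note `args` is re-serialized from a dict to the JSON-string shape true streaming would produce:

```python
import json

# Hypothetical parsed tool calls from a complete (non-streamed) AIMessage.
tool_calls = [{"name": "get_weather", "args": {"location": "SF"}, "id": "toolu_01"}]

tool_call_chunks = [
    {
        "name": tool_call["name"],
        "args": json.dumps(tool_call["args"]),  # dict -> partial-JSON string
        "id": tool_call["id"],
        "index": idx,  # list position stands in for a stream index
    }
    for idx, tool_call in enumerate(tool_calls)
]
assert tool_call_chunks[0]["args"] == '{"location": "SF"}'
```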
@ -1,18 +1,11 @@
from typing import Any, List, Optional, Type, TypedDict, cast
from typing import Any, List, Optional, Type
from langchain_core.messages import BaseMessage
from langchain_core.messages import ToolCall
from langchain_core.output_parsers import BaseGenerationOutputParser
from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.pydantic_v1 import BaseModel
class _ToolCall(TypedDict):
name: str
args: dict
id: str
index: int
class ToolsOutputParser(BaseGenerationOutputParser):
first_tool_only: bool = False
args_only: bool = False
@ -33,7 +26,19 @@ class ToolsOutputParser(BaseGenerationOutputParser):
"""
if not result or not isinstance(result[0], ChatGeneration):
return None if self.first_tool_only else []
tool_calls: List = _extract_tool_calls(result[0].message)
message = result[0].message
if isinstance(message.content, str):
tool_calls: List = []
else:
content: List = message.content
_tool_calls = [dict(tc) for tc in extract_tool_calls(content)]
# Map tool call id to index
id_to_index = {
block["id"]: i
for i, block in enumerate(content)
if block["type"] == "tool_use"
}
tool_calls = [{**tc, "index": id_to_index[tc["id"]]} for tc in _tool_calls]
if self.pydantic_schemas:
tool_calls = [self._pydantic_parse(tc) for tc in tool_calls]
elif self.args_only:
@ -44,23 +49,21 @@ class ToolsOutputParser(BaseGenerationOutputParser):
if self.first_tool_only:
return tool_calls[0] if tool_calls else None
else:
return tool_calls
return [tool_call for tool_call in tool_calls]
def _pydantic_parse(self, tool_call: _ToolCall) -> BaseModel:
def _pydantic_parse(self, tool_call: dict) -> BaseModel:
cls_ = {schema.__name__: schema for schema in self.pydantic_schemas or []}[
tool_call["name"]
]
return cls_(**tool_call["args"])
def _extract_tool_calls(msg: BaseMessage) -> List[_ToolCall]:
if isinstance(msg.content, str):
return []
def extract_tool_calls(content: List[dict]) -> List[ToolCall]:
tool_calls = []
for i, block in enumerate(cast(List[dict], msg.content)):
for block in content:
if block["type"] != "tool_use":
continue
tool_calls.append(
_ToolCall(name=block["name"], args=block["input"], id=block["id"], index=i)
ToolCall(name=block["name"], args=block["input"], id=block["id"])
)
return tool_calls

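The now-module-level `extract_tool_calls` maps Anthropic `tool_use` content blocks directly to `ToolCall` dicts. A usage sketch with hypothetical content blocks:

```python
from langchain_anthropic.output_parsers import extract_tool_calls

content = [
    {"type": "text", "text": "I'll look that up."},
    {
        "type": "tool_use",
        "id": "toolu_01",
        "name": "get_weather",
        "input": {"location": "San Francisco, CA"},
    },
]
# Non-tool_use blocks are skipped; each block's `input` becomes `args`.
print(extract_tool_calls(content))
# [{'name': 'get_weather', 'args': {'location': 'San Francisco, CA'}, 'id': 'toolu_01'}]
```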
@ -1,9 +1,15 @@
"""Test ChatAnthropic chat model."""
import json
from typing import List
from langchain_core.callbacks import CallbackManager
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage, HumanMessage
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
HumanMessage,
)
from langchain_core.outputs import ChatGeneration, LLMResult
from langchain_core.prompts import ChatPromptTemplate
@ -234,6 +240,28 @@ def test_tool_use() -> None:
response = llm_with_tools.invoke("what's the weather in san francisco, ca")
assert isinstance(response, AIMessage)
assert isinstance(response.content, list)
assert isinstance(response.tool_calls, list)
assert len(response.tool_calls) == 1
tool_call = response.tool_calls[0]
assert tool_call["name"] == "get_weather"
assert isinstance(tool_call["args"], dict)
assert "location" in tool_call["args"]
# Test streaming
first = True
for chunk in llm_with_tools.stream("what's the weather in san francisco, ca"):
if first:
gathered = chunk
first = False
else:
gathered = gathered + chunk # type: ignore
assert isinstance(gathered, AIMessageChunk)
assert isinstance(gathered.tool_call_chunks, list)
assert len(gathered.tool_call_chunks) == 1
tool_call_chunk = gathered.tool_call_chunks[0]
assert tool_call_chunk["name"] == "get_weather"
assert isinstance(tool_call_chunk["args"], str)
assert "location" in json.loads(tool_call_chunk["args"])
def test_with_structured_output() -> None:

@ -56,6 +56,8 @@ from langchain_core.output_parsers.base import OutputParserLike
from langchain_core.output_parsers.openai_tools import (
JsonOutputKeyToolsParser,
PydanticToolsParser,
make_invalid_tool_call,
parse_tool_call,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
@ -94,9 +96,23 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
additional_kwargs: Dict = {}
if function_call := _dict.get("function_call"):
additional_kwargs["function_call"] = dict(function_call)
if tool_calls := _dict.get("tool_calls"):
additional_kwargs["tool_calls"] = tool_calls
return AIMessage(content=content, additional_kwargs=additional_kwargs)
tool_calls = []
invalid_tool_calls = []
if raw_tool_calls := _dict.get("tool_calls"):
additional_kwargs["tool_calls"] = raw_tool_calls
for raw_tool_call in raw_tool_calls:
try:
tool_calls.append(parse_tool_call(raw_tool_call, return_id=True))
except Exception as e:
invalid_tool_calls.append(
dict(make_invalid_tool_call(raw_tool_call, str(e)))
)
return AIMessage(
content=content,
additional_kwargs=additional_kwargs,
tool_calls=tool_calls,
invalid_tool_calls=invalid_tool_calls,
)
elif role == "system":
return SystemMessage(content=_dict.get("content", ""))
elif role == "function":
@ -174,13 +190,31 @@ def _convert_delta_to_message_chunk(
if "name" in function_call and function_call["name"] is None:
function_call["name"] = ""
additional_kwargs["function_call"] = function_call
if _dict.get("tool_calls"):
additional_kwargs["tool_calls"] = _dict["tool_calls"]
if raw_tool_calls := _dict.get("tool_calls"):
additional_kwargs["tool_calls"] = raw_tool_calls
try:
tool_call_chunks = [
{
"name": rtc["function"].get("name"),
"args": rtc["function"].get("arguments"),
"id": rtc.get("id"),
"index": rtc["index"],
}
for rtc in raw_tool_calls
]
except KeyError:
    # A malformed delta raises here; fall back to no chunks rather than
    # leaving `tool_call_chunks` unbound for the AIMessageChunk below.
    tool_call_chunks = []
else:
tool_call_chunks = []
if role == "user" or default_class == HumanMessageChunk:
return HumanMessageChunk(content=content)
elif role == "assistant" or default_class == AIMessageChunk:
return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
return AIMessageChunk(
content=content,
additional_kwargs=additional_kwargs,
tool_call_chunks=tool_call_chunks,
)
elif role == "system" or default_class == SystemMessageChunk:
return SystemMessageChunk(content=content)
elif role == "function" or default_class == FunctionMessageChunk:

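The `parse_tool_call` / `make_invalid_tool_call` helpers imported above carry the parsing logic shared across providers: well-formed arguments yield a `ToolCall`, while anything that fails JSON decoding is preserved verbatim as an `InvalidToolCall`. A sketch over hypothetical OpenAI-format raw tool calls:

```python
from langchain_core.output_parsers.openai_tools import (
    make_invalid_tool_call,
    parse_tool_call,
)

raw_tool_call = {
    "id": "call_abc123",
    "type": "function",
    "function": {"name": "GenerateUsername", "arguments": '{"name": "Sally"}'},
}
print(parse_tool_call(raw_tool_call, return_id=True))
# {'name': 'GenerateUsername', 'args': {'name': 'Sally'}, 'id': 'call_abc123'}

bad_tool_call = {
    "id": "call_def456",
    "type": "function",
    "function": {"name": "GenerateUsername", "arguments": "oops"},
}
try:
    parse_tool_call(bad_tool_call, return_id=True)
except Exception as e:
    # The raw args string and the parse error are kept for debugging/repair.
    print(make_invalid_tool_call(bad_tool_call, str(e)))
```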
@ -47,6 +47,11 @@ def test_tool_choice() -> None:
"name": "Erick",
}
assert tool_call["type"] == "function"
assert isinstance(resp.tool_calls, list)
assert len(resp.tool_calls) == 1
tool_call = resp.tool_calls[0]
assert tool_call["name"] == "MyTool"
assert tool_call["args"] == {"age": 27, "name": "Erick"}
def test_tool_choice_bool() -> None:

@ -58,6 +58,8 @@ from langchain_core.output_parsers.base import OutputParserLike
from langchain_core.output_parsers.openai_tools import (
JsonOutputKeyToolsParser,
PydanticToolsParser,
make_invalid_tool_call,
parse_tool_call,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
@ -278,9 +280,20 @@ class ChatGroq(BaseChatModel):
chat_result = self._create_chat_result(response)
generation = chat_result.generations[0]
message = generation.message
tool_call_chunks = [
{
"name": rtc["function"].get("name"),
"args": rtc["function"].get("arguments"),
"id": rtc.get("id"),
"index": rtc.get("index"),
}
for rtc in message.additional_kwargs.get("tool_calls", [])
]
chunk_ = ChatGenerationChunk(
message=AIMessageChunk(
content=message.content, additional_kwargs=message.additional_kwargs
content=message.content,
additional_kwargs=message.additional_kwargs,
tool_call_chunks=tool_call_chunks,
),
generation_info=generation.generation_info,
)
@ -338,9 +351,20 @@ class ChatGroq(BaseChatModel):
chat_result = self._create_chat_result(response)
generation = chat_result.generations[0]
message = generation.message
tool_call_chunks = [
{
"name": rtc["function"].get("name"),
"args": rtc["function"].get("arguments"),
"id": rtc.get("id"),
"index": rtc.get("index"),
}
for rtc in message.additional_kwargs.get("tool_calls", [])
]
chunk_ = ChatGenerationChunk(
message=AIMessageChunk(
content=message.content, additional_kwargs=message.additional_kwargs
content=message.content,
additional_kwargs=message.additional_kwargs,
tool_call_chunks=tool_call_chunks,
),
generation_info=generation.generation_info,
)
@ -883,9 +907,24 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
additional_kwargs: Dict = {}
if function_call := _dict.get("function_call"):
additional_kwargs["function_call"] = dict(function_call)
if tool_calls := _dict.get("tool_calls"):
additional_kwargs["tool_calls"] = tool_calls
return AIMessage(content=content, id=id_, additional_kwargs=additional_kwargs)
tool_calls = []
invalid_tool_calls = []
if raw_tool_calls := _dict.get("tool_calls"):
additional_kwargs["tool_calls"] = raw_tool_calls
for raw_tool_call in raw_tool_calls:
try:
tool_calls.append(parse_tool_call(raw_tool_call, return_id=True))
except Exception as e:
invalid_tool_calls.append(
make_invalid_tool_call(raw_tool_call, str(e))
)
return AIMessage(
content=content,
id=id_,
additional_kwargs=additional_kwargs,
tool_calls=tool_calls,
invalid_tool_calls=invalid_tool_calls,
)
elif role == "system":
return SystemMessage(content=_dict.get("content", ""))
elif role == "function":

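Groq fakes streaming here by generating once and emitting a single chunk, so raw tool calls are mapped field-for-field onto `ToolCallChunk` dicts (name, partial `arguments` string, id, provider `index`); the `+` operator on `AIMessageChunk` can then reassemble real deltas as shown in the core tests earlier. A standalone sketch of that mapping over hypothetical deltas:

```python
# Two hypothetical streamed deltas for one tool call (OpenAI wire format).
deltas = [
    {"index": 0, "id": "call_abc123", "function": {"name": "MyTool", "arguments": ""}},
    {"index": 0, "id": None, "function": {"name": None, "arguments": '{"age": 27}'}},
]

chunks = [
    {
        "name": d["function"].get("name"),
        "args": d["function"].get("arguments"),
        "id": d.get("id"),
        "index": d["index"],
    }
    for d in deltas
]
print(chunks[0])
# {'name': 'MyTool', 'args': '', 'id': 'call_abc123', 'index': 0}
```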
@ -247,6 +247,12 @@ def test_tool_choice() -> None:
}
assert tool_call["type"] == "function"
assert isinstance(resp.tool_calls, list)
assert len(resp.tool_calls) == 1
tool_call = resp.tool_calls[0]
assert tool_call["name"] == "MyTool"
assert tool_call["args"] == {"name": "Erick", "age": 27}
@pytest.mark.xfail(reason="Groq tool_choice doesn't currently force a tool call")
def test_tool_choice_bool() -> None:
@ -302,6 +308,14 @@ def test_streaming_tool_call() -> None:
}
assert tool_call["type"] == "function"
assert isinstance(chunk, AIMessageChunk)
assert isinstance(chunk.tool_call_chunks, list)
assert len(chunk.tool_call_chunks) == 1
tool_call_chunk = chunk.tool_call_chunks[0]
assert tool_call_chunk["name"] == "MyTool"
assert isinstance(tool_call_chunk["args"], str)
assert json.loads(tool_call_chunk["args"]) == {"name": "Erick", "age": 27}
@pytest.mark.xfail(reason="Groq tool_choice doesn't currently force a tool call")
async def test_astreaming_tool_call() -> None:
@ -332,6 +346,14 @@ async def test_astreaming_tool_call() -> None:
}
assert tool_call["type"] == "function"
assert isinstance(chunk, AIMessageChunk)
assert isinstance(chunk.tool_call_chunks, list)
assert len(chunk.tool_call_chunks) == 1
tool_call_chunk = chunk.tool_call_chunks[0]
assert tool_call_chunk["name"] == "MyTool"
assert isinstance(tool_call_chunk["args"], str)
assert json.loads(tool_call_chunk["args"]) == {"name": "Erick", "age": 27}
@pytest.mark.scheduled
def test_json_mode_structured_output() -> None:

@ -11,7 +11,9 @@ from langchain_core.messages import (
AIMessage,
FunctionMessage,
HumanMessage,
InvalidToolCall,
SystemMessage,
ToolCall,
)
from langchain_groq.chat_models import ChatGroq, _convert_dict_to_message
@ -56,6 +58,73 @@ def test__convert_dict_to_message_ai() -> None:
assert result == expected_output
def test__convert_dict_to_message_tool_call() -> None:
raw_tool_call = {
"id": "call_wm0JY6CdwOMZ4eTxHWUThDNz",
"function": {
"arguments": '{"name":"Sally","hair_color":"green"}',
"name": "GenerateUsername",
},
"type": "function",
}
message = {"role": "assistant", "content": None, "tool_calls": [raw_tool_call]}
result = _convert_dict_to_message(message)
expected_output = AIMessage(
content="",
additional_kwargs={"tool_calls": [raw_tool_call]},
tool_calls=[
ToolCall(
name="GenerateUsername",
args={"name": "Sally", "hair_color": "green"},
id="call_wm0JY6CdwOMZ4eTxHWUThDNz",
)
],
)
assert result == expected_output
# Test malformed tool call
raw_tool_calls = [
{
"id": "call_wm0JY6CdwOMZ4eTxHWUThDNz",
"function": {
"arguments": "oops",
"name": "GenerateUsername",
},
"type": "function",
},
{
"id": "call_abc123",
"function": {
"arguments": '{"name":"Sally","hair_color":"green"}',
"name": "GenerateUsername",
},
"type": "function",
},
]
message = {"role": "assistant", "content": None, "tool_calls": raw_tool_calls}
result = _convert_dict_to_message(message)
expected_output = AIMessage(
content="",
additional_kwargs={"tool_calls": raw_tool_calls},
invalid_tool_calls=[
InvalidToolCall(
name="GenerateUsername",
args="oops",
id="call_wm0JY6CdwOMZ4eTxHWUThDNz",
error="Function GenerateUsername arguments:\n\noops\n\nare not valid JSON. Received JSONDecodeError Expecting value: line 1 column 1 (char 0)", # noqa: E501
),
],
tool_calls=[
ToolCall(
name="GenerateUsername",
args={"name": "Sally", "hair_color": "green"},
id="call_abc123",
),
],
)
assert result == expected_output
def test__convert_dict_to_message_system() -> None:
message = {"role": "system", "content": "foo"}
result = _convert_dict_to_message(message)

@ -49,6 +49,8 @@ from langchain_core.output_parsers.base import OutputParserLike
from langchain_core.output_parsers.openai_tools import (
JsonOutputKeyToolsParser,
PydanticToolsParser,
make_invalid_tool_call,
parse_tool_call,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
@ -82,9 +84,31 @@ def _convert_mistral_chat_message_to_message(
content = cast(str, _message["content"])
additional_kwargs: Dict = {}
if tool_calls := _message.get("tool_calls"):
additional_kwargs["tool_calls"] = tool_calls
return AIMessage(content=content, additional_kwargs=additional_kwargs)
tool_calls = []
invalid_tool_calls = []
if raw_tool_calls := _message.get("tool_calls"):
additional_kwargs["tool_calls"] = raw_tool_calls
for raw_tool_call in raw_tool_calls:
try:
parsed: dict = cast(
dict, parse_tool_call(raw_tool_call, return_id=False)
)
tool_calls.append(
{
**parsed,
**{"id": None},
},
)
except Exception as e:
invalid_tool_calls.append(
dict(make_invalid_tool_call(raw_tool_call, str(e)))
)
return AIMessage(
content=content,
additional_kwargs=additional_kwargs,
tool_calls=tool_calls,
invalid_tool_calls=invalid_tool_calls,
)
async def _aiter_sse(
@ -133,9 +157,27 @@ def _convert_delta_to_message_chunk(
return HumanMessageChunk(content=content)
elif role == "assistant" or default_class == AIMessageChunk:
additional_kwargs: Dict = {}
if tool_calls := _delta.get("tool_calls"):
additional_kwargs["tool_calls"] = tool_calls
return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
if raw_tool_calls := _delta.get("tool_calls"):
additional_kwargs["tool_calls"] = raw_tool_calls
try:
tool_call_chunks = [
{
"name": rtc["function"].get("name"),
"args": rtc["function"].get("arguments"),
"id": rtc.get("id"),
"index": rtc.get("index"),
}
for rtc in raw_tool_calls
]
except KeyError:
    # A malformed delta raises here; fall back to no chunks rather than
    # leaving `tool_call_chunks` unbound below.
    tool_call_chunks = []
else:
tool_call_chunks = []
return AIMessageChunk(
content=content,
additional_kwargs=additional_kwargs,
tool_call_chunks=tool_call_chunks,
)
elif role == "system" or default_class == SystemMessageChunk:
return SystemMessageChunk(content=content)
elif role or default_class == ChatMessageChunk:
@ -163,7 +205,7 @@ def _convert_message_to_mistral_chat_message(
for tc in message.additional_kwargs["tool_calls"]
]
else:
tool_calls = None
tool_calls = []
return {
"role": "assistant",
"content": message.content,

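Mistral's API returns tool calls without ids, so the conversion above parses with `return_id=False` and pins `id` to `None` to satisfy the `ToolCall` shape. A sketch with a hypothetical raw call:

```python
from langchain_core.output_parsers.openai_tools import parse_tool_call

raw_tool_call = {
    "function": {
        "arguments": '{"name": "Sally", "hair_color": "green"}',
        "name": "GenerateUsername",
    }
}
parsed = dict(parse_tool_call(raw_tool_call, return_id=False))
tool_call = {**parsed, "id": None}  # Mistral supplies no tool-call id
print(tool_call)
# {'name': 'GenerateUsername', 'args': {'name': 'Sally', 'hair_color': 'green'}, 'id': None}
```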
@ -3,7 +3,13 @@
import json
from typing import Any
from langchain_core.messages import AIMessageChunk, HumanMessage
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
HumanMessage,
ToolCall,
ToolCallChunk,
)
from langchain_core.pydantic_v1 import BaseModel
from langchain_mistralai.chat_models import ChatMistralAI
@ -151,6 +157,22 @@ def test_streaming_structured_output() -> None:
chunk_num += 1
def test_tool_call() -> None:
llm = ChatMistralAI(model="mistral-large", temperature=0)
class Person(BaseModel):
name: str
age: int
tool_llm = llm.bind_tools([Person])
result = tool_llm.invoke("Erick, 27 years old")
assert isinstance(result, AIMessage)
assert result.tool_calls == [
ToolCall(name="Person", args={"name": "Erick", "age": 27}, id=None)
]
def test_streaming_tool_call() -> None:
llm = ChatMistralAI(model="mistral-large", temperature=0)
@ -178,6 +200,13 @@ def test_streaming_tool_call() -> None:
"age": 27,
}
assert isinstance(chunk, AIMessageChunk)
assert chunk.tool_call_chunks == [
ToolCallChunk(
name="Person", args='{"name": "Erick", "age": 27}', id=None, index=None
)
]
# where it doesn't call the tool
strm = tool_llm.stream("What is 2+2?")
acc: Any = None

@ -11,13 +11,16 @@ from langchain_core.messages import (
BaseMessage,
ChatMessage,
HumanMessage,
InvalidToolCall,
SystemMessage,
ToolCall,
)
from langchain_core.pydantic_v1 import SecretStr
from langchain_mistralai.chat_models import ( # type: ignore[import]
ChatMistralAI,
_convert_message_to_mistral_chat_message,
_convert_mistral_chat_message_to_message,
)
os.environ["MISTRAL_API_KEY"] = "foo"
@ -52,7 +55,7 @@ def test_mistralai_initialization() -> None:
),
(
AIMessage(content="Hello"),
dict(role="assistant", content="Hello", tool_calls=None),
dict(role="assistant", content="Hello", tool_calls=[]),
),
(
ChatMessage(role="assistant", content="Hello"),
@ -121,3 +124,66 @@ async def test_astream_with_callback() -> None:
chat = ChatMistralAI(callbacks=[callback])
async for token in chat.astream("Hello"):
assert callback.last_token == token.content
def test__convert_dict_to_message_tool_call() -> None:
raw_tool_call = {
"function": {
"arguments": '{"name":"Sally","hair_color":"green"}',
"name": "GenerateUsername",
},
}
message = {"role": "assistant", "content": "", "tool_calls": [raw_tool_call]}
result = _convert_mistral_chat_message_to_message(message)
expected_output = AIMessage(
content="",
additional_kwargs={"tool_calls": [raw_tool_call]},
tool_calls=[
ToolCall(
name="GenerateUsername",
args={"name": "Sally", "hair_color": "green"},
id=None,
)
],
)
assert result == expected_output
assert _convert_message_to_mistral_chat_message(expected_output) == message
# Test malformed tool call
raw_tool_calls = [
{
"function": {
"arguments": "oops",
"name": "GenerateUsername",
},
},
{
"function": {
"arguments": '{"name":"Sally","hair_color":"green"}',
"name": "GenerateUsername",
},
},
]
message = {"role": "assistant", "content": "", "tool_calls": raw_tool_calls}
result = _convert_mistral_chat_message_to_message(message)
expected_output = AIMessage(
content="",
additional_kwargs={"tool_calls": raw_tool_calls},
invalid_tool_calls=[
InvalidToolCall(
name="GenerateUsername",
args="oops",
error="Function GenerateUsername arguments:\n\noops\n\nare not valid JSON. Received JSONDecodeError Expecting value: line 1 column 1 (char 0)", # noqa: E501
id=None,
),
],
tool_calls=[
ToolCall(
name="GenerateUsername",
args={"name": "Sally", "hair_color": "green"},
id=None,
),
],
)
assert result == expected_output
assert _convert_message_to_mistral_chat_message(expected_output) == message

@ -63,6 +63,8 @@ from langchain_core.output_parsers.base import OutputParserLike
from langchain_core.output_parsers.openai_tools import (
JsonOutputKeyToolsParser,
PydanticToolsParser,
make_invalid_tool_call,
parse_tool_call,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
@ -103,10 +105,24 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
additional_kwargs: Dict = {}
if function_call := _dict.get("function_call"):
additional_kwargs["function_call"] = dict(function_call)
if tool_calls := _dict.get("tool_calls"):
additional_kwargs["tool_calls"] = tool_calls
tool_calls = []
invalid_tool_calls = []
if raw_tool_calls := _dict.get("tool_calls"):
additional_kwargs["tool_calls"] = raw_tool_calls
for raw_tool_call in raw_tool_calls:
try:
tool_calls.append(parse_tool_call(raw_tool_call, return_id=True))
except Exception as e:
invalid_tool_calls.append(
make_invalid_tool_call(raw_tool_call, str(e))
)
return AIMessage(
content=content, additional_kwargs=additional_kwargs, name=name, id=id_
content=content,
additional_kwargs=additional_kwargs,
name=name,
id=id_,
tool_calls=tool_calls,
invalid_tool_calls=invalid_tool_calls,
)
elif role == "system":
return SystemMessage(content=_dict.get("content", ""), name=name, id=id_)
@ -188,14 +204,30 @@ def _convert_delta_to_message_chunk(
if "name" in function_call and function_call["name"] is None:
function_call["name"] = ""
additional_kwargs["function_call"] = function_call
if _dict.get("tool_calls"):
additional_kwargs["tool_calls"] = _dict["tool_calls"]
tool_call_chunks = []
if raw_tool_calls := _dict.get("tool_calls"):
additional_kwargs["tool_calls"] = raw_tool_calls
try:
tool_call_chunks = [
{
"name": rtc["function"].get("name"),
"args": rtc["function"].get("arguments"),
"id": rtc.get("id"),
"index": rtc["index"],
}
for rtc in raw_tool_calls
]
except KeyError:
pass
if role == "user" or default_class == HumanMessageChunk:
return HumanMessageChunk(content=content, id=id_)
elif role == "assistant" or default_class == AIMessageChunk:
return AIMessageChunk(
content=content, additional_kwargs=additional_kwargs, id=id_
content=content,
additional_kwargs=additional_kwargs,
id=id_,
tool_call_chunks=tool_call_chunks,
)
elif role == "system" or default_class == SystemMessageChunk:
return SystemMessageChunk(content=content, id=id_)

@ -5,6 +5,7 @@ import pytest
from langchain_core.callbacks import CallbackManager
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
BaseMessageChunk,
HumanMessage,
@ -482,6 +483,28 @@ def test_tool_use() -> None:
llm_with_tool = llm.bind_tools(tools=[GenerateUsername], tool_choice=True)
msgs: List = [HumanMessage("Sally has green hair, what would her username be?")]
ai_msg = llm_with_tool.invoke(msgs)
assert isinstance(ai_msg, AIMessage)
assert isinstance(ai_msg.tool_calls, list)
assert len(ai_msg.tool_calls) == 1
tool_call = ai_msg.tool_calls[0]
assert "args" in tool_call
# Test streaming
ai_messages = llm_with_tool.stream(msgs)
first = True
for message in ai_messages:
if first:
gathered = message
first = False
else:
gathered = gathered + message # type: ignore
assert isinstance(gathered, AIMessageChunk)
assert isinstance(gathered.tool_call_chunks, list)
assert len(gathered.tool_call_chunks) == 1
tool_call_chunk = gathered.tool_call_chunks[0]
assert "args" in tool_call_chunk
tool_msg = ToolMessage(
"sally_green_hair", tool_call_id=ai_msg.additional_kwargs["tool_calls"][0]["id"]
)

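From the consumer side, the pattern these integration tests establish looks like the sketch below, assuming `llm_with_tool` is any tool-bound chat model as in the test above: `.invoke` exposes parsed dicts on `.tool_calls`, while summing streamed chunks accumulates partial-JSON strings on `.tool_call_chunks`.

```python
# Usage sketch; assumes `llm_with_tool` is a tool-bound chat model as above.
msg = llm_with_tool.invoke("Sally has green hair, what would her username be?")
print(msg.tool_calls)  # e.g. [{'name': ..., 'args': {...}, 'id': ...}]

gathered = None
for chunk in llm_with_tool.stream("Sally has green hair, what would her username be?"):
    gathered = chunk if gathered is None else gathered + chunk
print(gathered.tool_call_chunks)  # args accumulate as partial-JSON strings
```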
@ -9,7 +9,9 @@ from langchain_core.messages import (
AIMessage,
FunctionMessage,
HumanMessage,
InvalidToolCall,
SystemMessage,
ToolCall,
ToolMessage,
)
@ -98,6 +100,75 @@ def test__convert_dict_to_message_tool() -> None:
assert _convert_message_to_dict(expected_output) == message
def test__convert_dict_to_message_tool_call() -> None:
raw_tool_call = {
"id": "call_wm0JY6CdwOMZ4eTxHWUThDNz",
"function": {
"arguments": '{"name":"Sally","hair_color":"green"}',
"name": "GenerateUsername",
},
"type": "function",
}
message = {"role": "assistant", "content": None, "tool_calls": [raw_tool_call]}
result = _convert_dict_to_message(message)
expected_output = AIMessage(
content="",
additional_kwargs={"tool_calls": [raw_tool_call]},
tool_calls=[
ToolCall(
name="GenerateUsername",
args={"name": "Sally", "hair_color": "green"},
id="call_wm0JY6CdwOMZ4eTxHWUThDNz",
)
],
)
assert result == expected_output
assert _convert_message_to_dict(expected_output) == message
# Test malformed tool call
raw_tool_calls = [
{
"id": "call_wm0JY6CdwOMZ4eTxHWUThDNz",
"function": {
"arguments": "oops",
"name": "GenerateUsername",
},
"type": "function",
},
{
"id": "call_abc123",
"function": {
"arguments": '{"name":"Sally","hair_color":"green"}',
"name": "GenerateUsername",
},
"type": "function",
},
]
message = {"role": "assistant", "content": None, "tool_calls": raw_tool_calls}
result = _convert_dict_to_message(message)
expected_output = AIMessage(
content="",
additional_kwargs={"tool_calls": raw_tool_calls},
invalid_tool_calls=[
InvalidToolCall(
name="GenerateUsername",
args="oops",
id="call_wm0JY6CdwOMZ4eTxHWUThDNz",
error="Function GenerateUsername arguments:\n\noops\n\nare not valid JSON. Received JSONDecodeError Expecting value: line 1 column 1 (char 0)", # noqa: E501
),
],
tool_calls=[
ToolCall(
name="GenerateUsername",
args={"name": "Sally", "hair_color": "green"},
id="call_abc123",
),
],
)
assert result == expected_output
assert _convert_message_to_dict(expected_output) == message
@pytest.fixture
def mock_completion() -> dict:
return {
