mirror of https://github.com/hwchase17/langchain
core[minor], ...: add tool calls message (#18947)
core[minor], langchain[patch], openai[minor], anthropic[minor], fireworks[minor], groq[minor], mistralai[minor] ```python class ToolCall(TypedDict): name: str args: Dict[str, Any] id: Optional[str] class InvalidToolCall(TypedDict): name: Optional[str] args: Optional[str] id: Optional[str] error: Optional[str] class ToolCallChunk(TypedDict): name: Optional[str] args: Optional[str] id: Optional[str] index: Optional[int] class AIMessage(BaseMessage): ... tool_calls: List[ToolCall] = [] invalid_tool_calls: List[InvalidToolCall] = [] ... class AIMessageChunk(AIMessage, BaseMessageChunk): ... tool_call_chunks: Optional[List[ToolCallChunk]] = None ... ``` Important considerations: - Parsing logic occurs within different providers; - ~Changing output type is a breaking change for anyone doing explicit type checking;~ - ~Langsmith rendering will need to be updated: https://github.com/langchain-ai/langchainplus/pull/3561~ - ~Langserve will need to be updated~ - Adding chunks: - ~AIMessage + ToolCallsMessage = ToolCallsMessage if either has non-null .tool_calls.~ - Tool call chunks are appended, merging when having equal values of `index`. - additional_kwargs accumulate the normal way. - During streaming: - ~Messages can change types (e.g., from AIMessageChunk to AIToolCallsMessageChunk)~ - Output parsers parse additional_kwargs (during .invoke they read off tool calls). Packages outside of `partners/`: - https://github.com/langchain-ai/langchain-cohere/pull/7 - https://github.com/langchain-ai/langchain-google/pull/123/files --------- Co-authored-by: Chester Curme <chester.curme@gmail.com>pull/11274/merge
parent
00552918ac
commit
9514bc4d67
@ -0,0 +1,423 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "c48812ed-35bd-4fbe-9a2c-6c7335e5645e",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/Users/chestercurme/repos/langchain/libs/core/langchain_core/_api/beta_decorator.py:87: LangChainBetaWarning: The function `bind_tools` is in beta. It is actively being worked on, so the API may change.\n",
|
||||
" warn_beta(\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain_anthropic import ChatAnthropic\n",
|
||||
"from langchain_core.runnables import ConfigurableField\n",
|
||||
"from langchain_core.tools import tool\n",
|
||||
"from langchain_openai import ChatOpenAI\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"@tool\n",
|
||||
"def multiply(x: float, y: float) -> float:\n",
|
||||
" \"\"\"Multiply 'x' times 'y'.\"\"\"\n",
|
||||
" return x * y\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"@tool\n",
|
||||
"def exponentiate(x: float, y: float) -> float:\n",
|
||||
" \"\"\"Raise 'x' to the 'y'.\"\"\"\n",
|
||||
" return x**y\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"@tool\n",
|
||||
"def add(x: float, y: float) -> float:\n",
|
||||
" \"\"\"Add 'x' and 'y'.\"\"\"\n",
|
||||
" return x + y\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"tools = [multiply, exponentiate, add]\n",
|
||||
"\n",
|
||||
"gpt35 = ChatOpenAI(model=\"gpt-3.5-turbo-0125\", temperature=0).bind_tools(tools)\n",
|
||||
"claude3 = ChatAnthropic(model=\"claude-3-sonnet-20240229\").bind_tools(tools)\n",
|
||||
"llm_with_tools = gpt35.configurable_alternatives(\n",
|
||||
" ConfigurableField(id=\"llm\"), default_key=\"gpt35\", claude3=claude3\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4719ebdb-ad50-468e-9b30-fb5fb086e140",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# AgentExecutor"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "b98feaa5-8c2d-4125-9519-67114a1fef31",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from typing import List, Tuple, Union\n",
|
||||
"\n",
|
||||
"from langchain.agents import AgentExecutor\n",
|
||||
"from langchain.agents.output_parsers.openai_tools import OpenAIToolAgentAction\n",
|
||||
"from langchain_core.agents import AgentFinish, _convert_agent_action_to_messages\n",
|
||||
"from langchain_core.messages import (\n",
|
||||
" AIMessage,\n",
|
||||
" BaseMessage,\n",
|
||||
" ToolMessage,\n",
|
||||
")\n",
|
||||
"from langchain_core.prompts import ChatPromptTemplate\n",
|
||||
"from langchain_core.runnables import RunnablePassthrough\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def actions_observations_to_messages(\n",
|
||||
" steps: List[Tuple[OpenAIToolAgentAction, str]],\n",
|
||||
") -> List[BaseMessage]:\n",
|
||||
" messages = []\n",
|
||||
" for action, observation in steps:\n",
|
||||
" messages.extend([m for m in action.message_log if m not in messages])\n",
|
||||
" messages.append(ToolMessage(observation, tool_call_id=action.tool_call_id))\n",
|
||||
" return messages\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def messages_to_action(\n",
|
||||
" msg: AIMessage,\n",
|
||||
") -> Union[List[OpenAIToolAgentAction], AgentFinish]:\n",
|
||||
" if isinstance(msg, AIMessage) and msg.tool_calls is not None:\n",
|
||||
" actions = []\n",
|
||||
" for tool_call in msg.tool_calls:\n",
|
||||
" actions.append(\n",
|
||||
" OpenAIToolAgentAction(\n",
|
||||
" tool=tool_call.name,\n",
|
||||
" tool_input=tool_call.args,\n",
|
||||
" tool_call_id=tool_call.id,\n",
|
||||
" message_log=[msg],\n",
|
||||
" log=\"\",\n",
|
||||
" )\n",
|
||||
" )\n",
|
||||
" return actions\n",
|
||||
" else:\n",
|
||||
" return AgentFinish(return_values={\"output\": msg.content}, log=\"\")\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"prompt = ChatPromptTemplate.from_messages(\n",
|
||||
" [\n",
|
||||
" (\"system\", \"You're a helpful assistant with access to tools\"),\n",
|
||||
" (\"human\", \"{input}\"),\n",
|
||||
" (\"placeholder\", \"{agent_scratchpad}\"),\n",
|
||||
" ]\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"agent = (\n",
|
||||
" RunnablePassthrough.assign(\n",
|
||||
" agent_scratchpad=lambda x: actions_observations_to_messages(\n",
|
||||
" x[\"intermediate_steps\"]\n",
|
||||
" ),\n",
|
||||
" )\n",
|
||||
" | prompt\n",
|
||||
" | llm_with_tools\n",
|
||||
" | messages_to_action\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "b4c0fc7a-80bb-4bb8-a87b-7388291ae8b6",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
||||
"\u001b[32;1m\u001b[1;3m\u001b[0m\u001b[33;1m\u001b[1;3m300.03770462067547\u001b[0m\u001b[32;1m\u001b[1;3m\u001b[0m\u001b[38;5;200m\u001b[1;3m-900.8841\u001b[0m\u001b[32;1m\u001b[1;3m\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"{'input': \"what's 3 plus 5 raised to the 2.743. also what's 17.24 - 918.1241\",\n",
|
||||
" 'output': 'The result of \\\\(3 + 5^{2.743}\\\\) is approximately 300.04, and the result of \\\\(17.24 - 918.1241\\\\) is approximately -900.88.'}"
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"agent_executor.invoke(\n",
|
||||
" {\"input\": \"what's 3 plus 5 raised to the 2.743. also what's 17.24 - 918.1241\"}\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "41a3a3c8-185d-4861-b6f0-7592668feb62",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/Users/chestercurme/repos/langchain/libs/partners/anthropic/langchain_anthropic/chat_models.py:336: UserWarning: stream: Tool use is not yet supported in streaming mode.\n",
|
||||
" warnings.warn(\"stream: Tool use is not yet supported in streaming mode.\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\u001b[32;1m\u001b[1;3m\u001b[0m\u001b[33;1m\u001b[1;3m82.65606421491815\u001b[0m"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/Users/chestercurme/repos/langchain/libs/partners/anthropic/langchain_anthropic/chat_models.py:336: UserWarning: stream: Tool use is not yet supported in streaming mode.\n",
|
||||
" warnings.warn(\"stream: Tool use is not yet supported in streaming mode.\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\u001b[32;1m\u001b[1;3m\u001b[0m\u001b[38;5;200m\u001b[1;3m85.65606421491815\u001b[0m"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/Users/chestercurme/repos/langchain/libs/partners/anthropic/langchain_anthropic/chat_models.py:336: UserWarning: stream: Tool use is not yet supported in streaming mode.\n",
|
||||
" warnings.warn(\"stream: Tool use is not yet supported in streaming mode.\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\u001b[32;1m\u001b[1;3m\u001b[0m\u001b[38;5;200m\u001b[1;3m-900.8841\u001b[0m"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/Users/chestercurme/repos/langchain/libs/partners/anthropic/langchain_anthropic/chat_models.py:336: UserWarning: stream: Tool use is not yet supported in streaming mode.\n",
|
||||
" warnings.warn(\"stream: Tool use is not yet supported in streaming mode.\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\u001b[32;1m\u001b[1;3m\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"{'input': \"what's 3 plus 5 raised to the 2.743. also what's 17.24 - 918.1241\",\n",
|
||||
" 'output': 'Therefore, 17.24 - 918.1241 = -900.8841'}"
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"agent_executor = AgentExecutor(\n",
|
||||
" agent=agent.with_config(configurable={\"llm\": \"claude3\"}), tools=tools, verbose=True\n",
|
||||
")\n",
|
||||
"agent_executor.invoke(\n",
|
||||
" {\"input\": \"what's 3 plus 5 raised to the 2.743. also what's 17.24 - 918.1241\"},\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "9c186263-1b98-4cb2-b6d1-71f65eb0d811",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# LangGraph"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "28fc2c60-7dbc-428a-8983-1a6a15ea30d2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import operator\n",
|
||||
"from typing import Annotated, Sequence, TypedDict\n",
|
||||
"\n",
|
||||
"from langchain_core.messages import AIMessage, BaseMessage, HumanMessage\n",
|
||||
"from langchain_core.runnables import RunnableLambda\n",
|
||||
"from langgraph.graph import END, StateGraph\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class AgentState(TypedDict):\n",
|
||||
" messages: Annotated[Sequence[BaseMessage], operator.add]\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def should_continue(state):\n",
|
||||
" return \"continue\" if state[\"messages\"][-1].tool_calls is not None else \"end\"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def call_model(state, config):\n",
|
||||
" return {\"messages\": [llm_with_tools.invoke(state[\"messages\"], config=config)]}\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def _invoke_tool(tool_call):\n",
|
||||
" tool = {tool.name: tool for tool in tools}[tool_call.name]\n",
|
||||
" return ToolMessage(tool.invoke(tool_call.args), tool_call_id=tool_call.id)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"tool_executor = RunnableLambda(_invoke_tool)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def call_tools(state):\n",
|
||||
" last_message = state[\"messages\"][-1]\n",
|
||||
" return {\"messages\": tool_executor.batch(last_message.tool_calls)}\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"workflow = StateGraph(AgentState)\n",
|
||||
"workflow.add_node(\"agent\", call_model)\n",
|
||||
"workflow.add_node(\"action\", call_tools)\n",
|
||||
"workflow.set_entry_point(\"agent\")\n",
|
||||
"workflow.add_conditional_edges(\n",
|
||||
" \"agent\",\n",
|
||||
" should_continue,\n",
|
||||
" {\n",
|
||||
" \"continue\": \"action\",\n",
|
||||
" \"end\": END,\n",
|
||||
" },\n",
|
||||
")\n",
|
||||
"workflow.add_edge(\"action\", \"agent\")\n",
|
||||
"graph = workflow.compile()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "24463798-74e6-4c39-8092-7a1524d83225",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"{'messages': [HumanMessage(content=\"what's 3 plus 5 raised to the 2.743. also what's 17.24 - 918.1241\"),\n",
|
||||
" AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_kbBUUeqK75fZZqDTvu8aif7Z', 'function': {'arguments': '{\"x\": 8, \"y\": 2.743}', 'name': 'exponentiate'}, 'type': 'function'}, {'id': 'call_pBD8daSyXidXnrIyG4vG5C9O', 'function': {'arguments': '{\"x\": 17.24, \"y\": -918.1241}', 'name': 'add'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 58, 'prompt_tokens': 168, 'total_tokens': 226}, 'model_name': 'gpt-3.5-turbo-0125', 'system_fingerprint': 'fp_b28b39ffa8', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-8e1d9687-611c-4c8e-9fcd-ef6e48bd22a6-0', tool_calls=[ToolCall(name='exponentiate', args={'x': 8, 'y': 2.743}, id='call_kbBUUeqK75fZZqDTvu8aif7Z'), ToolCall(name='add', args={'x': 17.24, 'y': -918.1241}, id='call_pBD8daSyXidXnrIyG4vG5C9O')]),\n",
|
||||
" ToolMessage(content='300.03770462067547', tool_call_id='call_kbBUUeqK75fZZqDTvu8aif7Z'),\n",
|
||||
" ToolMessage(content='-900.8841', tool_call_id='call_pBD8daSyXidXnrIyG4vG5C9O'),\n",
|
||||
" AIMessage(content='The result of \\\\(3 + 5^{2.743}\\\\) is approximately 300.04, and the result of \\\\(17.24 - 918.1241\\\\) is approximately -900.88.', response_metadata={'token_usage': {'completion_tokens': 44, 'prompt_tokens': 251, 'total_tokens': 295}, 'model_name': 'gpt-3.5-turbo-0125', 'system_fingerprint': 'fp_b28b39ffa8', 'finish_reason': 'stop', 'logprobs': None}, id='run-47fe5cbc-3f25-44c3-85b2-6540c3054a77-0')]}"
|
||||
]
|
||||
},
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"graph.invoke(\n",
|
||||
" {\n",
|
||||
" \"messages\": [\n",
|
||||
" HumanMessage(\n",
|
||||
" \"what's 3 plus 5 raised to the 2.743. also what's 17.24 - 918.1241\"\n",
|
||||
" )\n",
|
||||
" ]\n",
|
||||
" }\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "073c074e-d722-42e0-85ec-c62c079207e4",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"{'messages': [HumanMessage(content=\"what's 3 plus 5 raised to the 2.743. also what's 17.24 - 918.1241\"),\n",
|
||||
" AIMessage(content=[{'text': \"Okay, let's break this down into steps:\", 'type': 'text'}, {'id': 'toolu_01DJkSDpB8ztmJx2DLNbc3eW', 'input': {'x': 5, 'y': 2.743}, 'name': 'exponentiate', 'type': 'tool_use'}], response_metadata={'id': 'msg_01KuVNohyJr24cPhJkY3XVtt', 'model': 'claude-3-sonnet-20240229', 'stop_reason': 'tool_use', 'stop_sequence': None, 'usage': {'input_tokens': 450, 'output_tokens': 84}}, id='run-336cdfb6-0fe4-4d7a-9946-9f01c2eb41ae-0', tool_calls=[ToolCall(name='exponentiate', args={'x': 5, 'y': 2.743}, id='toolu_01DJkSDpB8ztmJx2DLNbc3eW', index=1)]),\n",
|
||||
" ToolMessage(content='82.65606421491815', tool_call_id='toolu_01DJkSDpB8ztmJx2DLNbc3eW'),\n",
|
||||
" AIMessage(content=[{'text': 'To get 5 raised to the 2.743 power.', 'type': 'text'}, {'id': 'toolu_01MKQqnDw5CtyuKjQP8YG1FX', 'input': {'x': 3, 'y': 82.65606421491815}, 'name': 'add', 'type': 'tool_use'}], response_metadata={'id': 'msg_01UBsKkvA4StUR4NEvoFFFep', 'model': 'claude-3-sonnet-20240229', 'stop_reason': 'tool_use', 'stop_sequence': None, 'usage': {'input_tokens': 552, 'output_tokens': 91}}, id='run-9d25b7bd-58aa-47dd-933f-15459b24b2c2-0', tool_calls=[ToolCall(name='add', args={'x': 3, 'y': 82.65606421491815}, id='toolu_01MKQqnDw5CtyuKjQP8YG1FX', index=1)]),\n",
|
||||
" ToolMessage(content='85.65606421491815', tool_call_id='toolu_01MKQqnDw5CtyuKjQP8YG1FX'),\n",
|
||||
" AIMessage(content=[{'text': 'So 3 plus 5 raised to the 2.743 power is 85.656.\\n\\nFor the second part:', 'type': 'text'}, {'id': 'toolu_019Wb2zPouCR3dw2bSKvCRUL', 'input': {'x': 17.24, 'y': -918.1241}, 'name': 'add', 'type': 'tool_use'}], response_metadata={'id': 'msg_01Y2H2L8FWcDtVkCtuosie2P', 'model': 'claude-3-sonnet-20240229', 'stop_reason': 'tool_use', 'stop_sequence': None, 'usage': {'input_tokens': 661, 'output_tokens': 105}}, id='run-e553c1e3-24ba-4e1b-93ba-6f1985932db4-0', tool_calls=[ToolCall(name='add', args={'x': 17.24, 'y': -918.1241}, id='toolu_019Wb2zPouCR3dw2bSKvCRUL', index=1)]),\n",
|
||||
" ToolMessage(content='-900.8841', tool_call_id='toolu_019Wb2zPouCR3dw2bSKvCRUL'),\n",
|
||||
" AIMessage(content='Therefore, 17.24 - 918.1241 = -900.8841', response_metadata={'id': 'msg_01Q14dqvaCD2eA4zwrUvxTcF', 'model': 'claude-3-sonnet-20240229', 'stop_reason': 'end_turn', 'stop_sequence': None, 'usage': {'input_tokens': 782, 'output_tokens': 24}}, id='run-f6b6e525-2df6-4617-9bb3-b39d5cc963a9-0')]}"
|
||||
]
|
||||
},
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"graph.invoke(\n",
|
||||
" {\n",
|
||||
" \"messages\": [\n",
|
||||
" HumanMessage(\n",
|
||||
" \"what's 3 plus 5 raised to the 2.743. also what's 17.24 - 918.1241\"\n",
|
||||
" )\n",
|
||||
" ]\n",
|
||||
" },\n",
|
||||
" config={\"configurable\": {\"llm\": \"claude3\"}},\n",
|
||||
")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.4"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
@ -0,0 +1,185 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import Any, Callable, List
|
||||
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
|
||||
|
||||
def _replace_new_line(match: re.Match[str]) -> str:
|
||||
value = match.group(2)
|
||||
value = re.sub(r"\n", r"\\n", value)
|
||||
value = re.sub(r"\r", r"\\r", value)
|
||||
value = re.sub(r"\t", r"\\t", value)
|
||||
value = re.sub(r'(?<!\\)"', r"\"", value)
|
||||
|
||||
return match.group(1) + value + match.group(3)
|
||||
|
||||
|
||||
def _custom_parser(multiline_string: str) -> str:
|
||||
"""
|
||||
The LLM response for `action_input` may be a multiline
|
||||
string containing unescaped newlines, tabs or quotes. This function
|
||||
replaces those characters with their escaped counterparts.
|
||||
(newlines in JSON must be double-escaped: `\\n`)
|
||||
"""
|
||||
if isinstance(multiline_string, (bytes, bytearray)):
|
||||
multiline_string = multiline_string.decode()
|
||||
|
||||
multiline_string = re.sub(
|
||||
r'("action_input"\:\s*")(.*?)(")',
|
||||
_replace_new_line,
|
||||
multiline_string,
|
||||
flags=re.DOTALL,
|
||||
)
|
||||
|
||||
return multiline_string
|
||||
|
||||
|
||||
# Adapted from https://github.com/KillianLucas/open-interpreter/blob/5b6080fae1f8c68938a1e4fa8667e3744084ee21/interpreter/utils/parse_partial_json.py
|
||||
# MIT License
|
||||
|
||||
|
||||
def parse_partial_json(s: str, *, strict: bool = False) -> Any:
|
||||
"""Parse a JSON string that may be missing closing braces.
|
||||
|
||||
Args:
|
||||
s: The JSON string to parse.
|
||||
strict: Whether to use strict parsing. Defaults to False.
|
||||
|
||||
Returns:
|
||||
The parsed JSON object as a Python dictionary.
|
||||
"""
|
||||
# Attempt to parse the string as-is.
|
||||
try:
|
||||
return json.loads(s, strict=strict)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Initialize variables.
|
||||
new_s = ""
|
||||
stack = []
|
||||
is_inside_string = False
|
||||
escaped = False
|
||||
|
||||
# Process each character in the string one at a time.
|
||||
for char in s:
|
||||
if is_inside_string:
|
||||
if char == '"' and not escaped:
|
||||
is_inside_string = False
|
||||
elif char == "\n" and not escaped:
|
||||
char = "\\n" # Replace the newline character with the escape sequence.
|
||||
elif char == "\\":
|
||||
escaped = not escaped
|
||||
else:
|
||||
escaped = False
|
||||
else:
|
||||
if char == '"':
|
||||
is_inside_string = True
|
||||
escaped = False
|
||||
elif char == "{":
|
||||
stack.append("}")
|
||||
elif char == "[":
|
||||
stack.append("]")
|
||||
elif char == "}" or char == "]":
|
||||
if stack and stack[-1] == char:
|
||||
stack.pop()
|
||||
else:
|
||||
# Mismatched closing character; the input is malformed.
|
||||
return None
|
||||
|
||||
# Append the processed character to the new string.
|
||||
new_s += char
|
||||
|
||||
# If we're still inside a string at the end of processing,
|
||||
# we need to close the string.
|
||||
if is_inside_string:
|
||||
new_s += '"'
|
||||
|
||||
# Try to parse mods of string until we succeed or run out of characters.
|
||||
while new_s:
|
||||
final_s = new_s
|
||||
|
||||
# Close any remaining open structures in the reverse
|
||||
# order that they were opened.
|
||||
for closing_char in reversed(stack):
|
||||
final_s += closing_char
|
||||
|
||||
# Attempt to parse the modified string as JSON.
|
||||
try:
|
||||
return json.loads(final_s, strict=strict)
|
||||
except json.JSONDecodeError:
|
||||
# If we still can't parse the string as JSON,
|
||||
# try removing the last character
|
||||
new_s = new_s[:-1]
|
||||
|
||||
# If we got here, we ran out of characters to remove
|
||||
# and still couldn't parse the string as JSON, so return the parse error
|
||||
# for the original string.
|
||||
return json.loads(s, strict=strict)
|
||||
|
||||
|
||||
def parse_json_markdown(
|
||||
json_string: str, *, parser: Callable[[str], Any] = parse_partial_json
|
||||
) -> dict:
|
||||
"""
|
||||
Parse a JSON string from a Markdown string.
|
||||
|
||||
Args:
|
||||
json_string: The Markdown string.
|
||||
|
||||
Returns:
|
||||
The parsed JSON object as a Python dictionary.
|
||||
"""
|
||||
try:
|
||||
return _parse_json(json_string, parser=parser)
|
||||
except json.JSONDecodeError:
|
||||
# Try to find JSON string within triple backticks
|
||||
match = re.search(r"```(json)?(.*)", json_string, re.DOTALL)
|
||||
|
||||
# If no match found, assume the entire string is a JSON string
|
||||
if match is None:
|
||||
json_str = json_string
|
||||
else:
|
||||
# If match found, use the content within the backticks
|
||||
json_str = match.group(2)
|
||||
return _parse_json(json_str, parser=parser)
|
||||
|
||||
|
||||
def _parse_json(
|
||||
json_str: str, *, parser: Callable[[str], Any] = parse_partial_json
|
||||
) -> dict:
|
||||
# Strip whitespace and newlines from the start and end
|
||||
json_str = json_str.strip().strip("`")
|
||||
|
||||
# handle newlines and other special characters inside the returned value
|
||||
json_str = _custom_parser(json_str)
|
||||
|
||||
# Parse the JSON string into a Python dictionary
|
||||
return parser(json_str)
|
||||
|
||||
|
||||
def parse_and_check_json_markdown(text: str, expected_keys: List[str]) -> dict:
|
||||
"""
|
||||
Parse a JSON string from a Markdown string and check that it
|
||||
contains the expected keys.
|
||||
|
||||
Args:
|
||||
text: The Markdown string.
|
||||
expected_keys: The expected keys in the JSON string.
|
||||
|
||||
Returns:
|
||||
The parsed JSON object as a Python dictionary.
|
||||
"""
|
||||
try:
|
||||
json_obj = parse_json_markdown(text)
|
||||
except json.JSONDecodeError as e:
|
||||
raise OutputParserException(f"Got invalid JSON object. Error: {e}")
|
||||
for key in expected_keys:
|
||||
if key not in json_obj:
|
||||
raise OutputParserException(
|
||||
f"Got invalid return object. Expected key `{key}` "
|
||||
f"to be present, but got {json_obj}"
|
||||
)
|
||||
return json_obj
|
Loading…
Reference in New Issue