mirror of
https://github.com/hwchase17/langchain
synced 2024-11-04 06:00:26 +00:00
Fix OpenAIFunctionsAgent function call message content retrieving (#10488)
`langchain.agents.openai_functions[_multi]_agent._parse_ai_message()` incorrectly extracts AI message content, thus LLM response ("thoughts") is lost and can't be logged or processed by callbacks. This PR fixes function call message content retrieving.
This commit is contained in:
parent
2dc3c64386
commit
0a0276bcdb
@ -127,7 +127,7 @@ def _parse_ai_message(message: BaseMessage) -> Union[AgentAction, AgentFinish]:
|
|||||||
else:
|
else:
|
||||||
tool_input = _tool_input
|
tool_input = _tool_input
|
||||||
|
|
||||||
content_msg = "responded: {content}\n" if message.content else "\n"
|
content_msg = f"responded: {message.content}\n" if message.content else "\n"
|
||||||
|
|
||||||
return _FunctionsAgentAction(
|
return _FunctionsAgentAction(
|
||||||
tool=function_name,
|
tool=function_name,
|
||||||
|
@ -129,7 +129,7 @@ def _parse_ai_message(message: BaseMessage) -> Union[List[AgentAction], AgentFin
|
|||||||
else:
|
else:
|
||||||
tool_input = _tool_input
|
tool_input = _tool_input
|
||||||
|
|
||||||
content_msg = "responded: {content}\n" if message.content else "\n"
|
content_msg = f"responded: {message.content}\n" if message.content else "\n"
|
||||||
log = f"\nInvoking: `{function_name}` with `{tool_input}`\n{content_msg}\n"
|
log = f"\nInvoking: `{function_name}` with `{tool_input}`\n{content_msg}\n"
|
||||||
_tool = _FunctionsAgentAction(
|
_tool = _FunctionsAgentAction(
|
||||||
tool=function_name,
|
tool=function_name,
|
||||||
|
@ -0,0 +1,76 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from langchain.agents.openai_functions_agent.base import (
|
||||||
|
_FunctionsAgentAction,
|
||||||
|
_parse_ai_message,
|
||||||
|
)
|
||||||
|
from langchain.schema import AgentFinish, OutputParserException
|
||||||
|
from langchain.schema.messages import AIMessage, SystemMessage
|
||||||
|
|
||||||
|
|
||||||
|
# Test: _parse_ai_message() function.
|
||||||
|
class TestParseAIMessage:
|
||||||
|
# Test: Pass Non-AIMessage.
|
||||||
|
def test_not_an_ai(self) -> None:
|
||||||
|
err = f"Expected an AI message got {str(SystemMessage)}"
|
||||||
|
with pytest.raises(TypeError, match=err):
|
||||||
|
_parse_ai_message(SystemMessage(content="x"))
|
||||||
|
|
||||||
|
# Test: Model response (not a function call).
|
||||||
|
def test_model_response(self) -> None:
|
||||||
|
msg = AIMessage(content="Model response.")
|
||||||
|
result = _parse_ai_message(msg)
|
||||||
|
|
||||||
|
assert isinstance(result, AgentFinish)
|
||||||
|
assert result.return_values == {"output": "Model response."}
|
||||||
|
assert result.log == "Model response."
|
||||||
|
|
||||||
|
# Test: Model response with a function call.
|
||||||
|
def test_func_call(self) -> None:
|
||||||
|
msg = AIMessage(
|
||||||
|
content="LLM thoughts.",
|
||||||
|
additional_kwargs={
|
||||||
|
"function_call": {"name": "foo", "arguments": '{"param": 42}'}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
result = _parse_ai_message(msg)
|
||||||
|
|
||||||
|
assert isinstance(result, _FunctionsAgentAction)
|
||||||
|
assert result.tool == "foo"
|
||||||
|
assert result.tool_input == {"param": 42}
|
||||||
|
assert result.log == (
|
||||||
|
"\nInvoking: `foo` with `{'param': 42}`\nresponded: LLM thoughts.\n\n"
|
||||||
|
)
|
||||||
|
assert result.message_log == [msg]
|
||||||
|
|
||||||
|
# Test: Model response with a function call (old style tools).
|
||||||
|
def test_func_call_oldstyle(self) -> None:
|
||||||
|
msg = AIMessage(
|
||||||
|
content="LLM thoughts.",
|
||||||
|
additional_kwargs={
|
||||||
|
"function_call": {"name": "foo", "arguments": '{"__arg1": "42"}'}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
result = _parse_ai_message(msg)
|
||||||
|
|
||||||
|
assert isinstance(result, _FunctionsAgentAction)
|
||||||
|
assert result.tool == "foo"
|
||||||
|
assert result.tool_input == "42"
|
||||||
|
assert result.log == (
|
||||||
|
"\nInvoking: `foo` with `42`\nresponded: LLM thoughts.\n\n"
|
||||||
|
)
|
||||||
|
assert result.message_log == [msg]
|
||||||
|
|
||||||
|
# Test: Invalid function call args.
|
||||||
|
def test_func_call_invalid(self) -> None:
|
||||||
|
msg = AIMessage(
|
||||||
|
content="LLM thoughts.",
|
||||||
|
additional_kwargs={"function_call": {"name": "foo", "arguments": "{42]"}},
|
||||||
|
)
|
||||||
|
|
||||||
|
err = (
|
||||||
|
"Could not parse tool input: {'name': 'foo', 'arguments': '{42]'} "
|
||||||
|
"because the `arguments` is not valid JSON."
|
||||||
|
)
|
||||||
|
with pytest.raises(OutputParserException, match=err):
|
||||||
|
_parse_ai_message(msg)
|
@ -0,0 +1,90 @@
|
|||||||
|
import json
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from langchain.agents.openai_functions_multi_agent.base import (
|
||||||
|
_FunctionsAgentAction,
|
||||||
|
_parse_ai_message,
|
||||||
|
)
|
||||||
|
from langchain.schema import AgentFinish, OutputParserException
|
||||||
|
from langchain.schema.messages import AIMessage, SystemMessage
|
||||||
|
|
||||||
|
|
||||||
|
# Test: _parse_ai_message() function.
|
||||||
|
class TestParseAIMessage:
|
||||||
|
# Test: Pass Non-AIMessage.
|
||||||
|
def test_not_an_ai(self) -> None:
|
||||||
|
err = f"Expected an AI message got {str(SystemMessage)}"
|
||||||
|
with pytest.raises(TypeError, match=err):
|
||||||
|
_parse_ai_message(SystemMessage(content="x"))
|
||||||
|
|
||||||
|
# Test: Model response (not a function call).
|
||||||
|
def test_model_response(self) -> None:
|
||||||
|
msg = AIMessage(content="Model response.")
|
||||||
|
result = _parse_ai_message(msg)
|
||||||
|
|
||||||
|
assert isinstance(result, AgentFinish)
|
||||||
|
assert result.return_values == {"output": "Model response."}
|
||||||
|
assert result.log == "Model response."
|
||||||
|
|
||||||
|
# Test: Model response with a function call.
|
||||||
|
def test_func_call(self) -> None:
|
||||||
|
act = json.dumps([{"action_name": "foo", "action": {"param": 42}}])
|
||||||
|
|
||||||
|
msg = AIMessage(
|
||||||
|
content="LLM thoughts.",
|
||||||
|
additional_kwargs={
|
||||||
|
"function_call": {"name": "foo", "arguments": f'{{"actions": {act}}}'}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
result = _parse_ai_message(msg)
|
||||||
|
|
||||||
|
assert isinstance(result, list)
|
||||||
|
assert len(result) == 1
|
||||||
|
|
||||||
|
action = result[0]
|
||||||
|
assert isinstance(action, _FunctionsAgentAction)
|
||||||
|
assert action.tool == "foo"
|
||||||
|
assert action.tool_input == {"param": 42}
|
||||||
|
assert action.log == (
|
||||||
|
"\nInvoking: `foo` with `{'param': 42}`\nresponded: LLM thoughts.\n\n"
|
||||||
|
)
|
||||||
|
assert action.message_log == [msg]
|
||||||
|
|
||||||
|
# Test: Model response with a function call (old style tools).
|
||||||
|
def test_func_call_oldstyle(self) -> None:
|
||||||
|
act = json.dumps([{"action_name": "foo", "action": {"__arg1": "42"}}])
|
||||||
|
|
||||||
|
msg = AIMessage(
|
||||||
|
content="LLM thoughts.",
|
||||||
|
additional_kwargs={
|
||||||
|
"function_call": {"name": "foo", "arguments": f'{{"actions": {act}}}'}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
result = _parse_ai_message(msg)
|
||||||
|
|
||||||
|
assert isinstance(result, list)
|
||||||
|
assert len(result) == 1
|
||||||
|
|
||||||
|
action = result[0]
|
||||||
|
assert isinstance(action, _FunctionsAgentAction)
|
||||||
|
assert action.tool == "foo"
|
||||||
|
assert action.tool_input == "42"
|
||||||
|
assert action.log == (
|
||||||
|
"\nInvoking: `foo` with `42`\nresponded: LLM thoughts.\n\n"
|
||||||
|
)
|
||||||
|
assert action.message_log == [msg]
|
||||||
|
|
||||||
|
# Test: Invalid function call args.
|
||||||
|
def test_func_call_invalid(self) -> None:
|
||||||
|
msg = AIMessage(
|
||||||
|
content="LLM thoughts.",
|
||||||
|
additional_kwargs={"function_call": {"name": "foo", "arguments": "{42]"}},
|
||||||
|
)
|
||||||
|
|
||||||
|
err = (
|
||||||
|
"Could not parse tool input: {'name': 'foo', 'arguments': '{42]'} "
|
||||||
|
"because the `arguments` is not valid JSON."
|
||||||
|
)
|
||||||
|
with pytest.raises(OutputParserException, match=err):
|
||||||
|
_parse_ai_message(msg)
|
Loading…
Reference in New Issue
Block a user