Fix OpenAIFunctionsAgent function call message content retrieving (#10488)

`langchain.agents.openai_functions[_multi]_agent._parse_ai_message()`
incorrectly extracts the AI message content, so the LLM response ("thoughts")
is lost and can't be logged or processed by callbacks.

This PR fixes retrieval of the function call message content.
Sergey Kozlov 1 year ago committed by GitHub
parent 2dc3c64386
commit 0a0276bcdb
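For context, a minimal standalone sketch of the bug (the `message_content` variable below is illustrative, standing in for `message.content`): without the `f` prefix the string is a plain literal, so the `{content}` placeholder is never interpolated and the model's "thoughts" never reach the agent action log.

message_content = "LLM thoughts."  # stand-in for message.content

# Before the fix: a plain string literal, so "{content}" stays verbatim
# and the model's text is dropped from the log.
buggy_msg = "responded: {content}\n" if message_content else "\n"

# After the fix: an f-string interpolates the actual message content.
fixed_msg = f"responded: {message_content}\n" if message_content else "\n"

print(buggy_msg)  # -> responded: {content}
print(fixed_msg)  # -> responded: LLM thoughts.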

@@ -127,7 +127,7 @@ def _parse_ai_message(message: BaseMessage) -> Union[AgentAction, AgentFinish]:
         else:
             tool_input = _tool_input

-        content_msg = "responded: {content}\n" if message.content else "\n"
+        content_msg = f"responded: {message.content}\n" if message.content else "\n"

         return _FunctionsAgentAction(
             tool=function_name,

@@ -129,7 +129,7 @@ def _parse_ai_message(message: BaseMessage) -> Union[List[AgentAction], AgentFinish]:
             else:
                 tool_input = _tool_input

-            content_msg = "responded: {content}\n" if message.content else "\n"
+            content_msg = f"responded: {message.content}\n" if message.content else "\n"
             log = f"\nInvoking: `{function_name}` with `{tool_input}`\n{content_msg}\n"
             _tool = _FunctionsAgentAction(
                 tool=function_name,

@@ -0,0 +1,76 @@
import pytest

from langchain.agents.openai_functions_agent.base import (
    _FunctionsAgentAction,
    _parse_ai_message,
)
from langchain.schema import AgentFinish, OutputParserException
from langchain.schema.messages import AIMessage, SystemMessage


# Test: _parse_ai_message() function.
class TestParseAIMessage:
    # Test: Pass Non-AIMessage.
    def test_not_an_ai(self) -> None:
        err = f"Expected an AI message got {str(SystemMessage)}"
        with pytest.raises(TypeError, match=err):
            _parse_ai_message(SystemMessage(content="x"))

    # Test: Model response (not a function call).
    def test_model_response(self) -> None:
        msg = AIMessage(content="Model response.")
        result = _parse_ai_message(msg)

        assert isinstance(result, AgentFinish)
        assert result.return_values == {"output": "Model response."}
        assert result.log == "Model response."

    # Test: Model response with a function call.
    def test_func_call(self) -> None:
        msg = AIMessage(
            content="LLM thoughts.",
            additional_kwargs={
                "function_call": {"name": "foo", "arguments": '{"param": 42}'}
            },
        )
        result = _parse_ai_message(msg)

        assert isinstance(result, _FunctionsAgentAction)
        assert result.tool == "foo"
        assert result.tool_input == {"param": 42}
        assert result.log == (
            "\nInvoking: `foo` with `{'param': 42}`\nresponded: LLM thoughts.\n\n"
        )
        assert result.message_log == [msg]

    # Test: Model response with a function call (old style tools).
    def test_func_call_oldstyle(self) -> None:
        msg = AIMessage(
            content="LLM thoughts.",
            additional_kwargs={
                "function_call": {"name": "foo", "arguments": '{"__arg1": "42"}'}
            },
        )
        result = _parse_ai_message(msg)

        assert isinstance(result, _FunctionsAgentAction)
        assert result.tool == "foo"
        assert result.tool_input == "42"
        assert result.log == (
            "\nInvoking: `foo` with `42`\nresponded: LLM thoughts.\n\n"
        )
        assert result.message_log == [msg]

    # Test: Invalid function call args.
    def test_func_call_invalid(self) -> None:
        msg = AIMessage(
            content="LLM thoughts.",
            additional_kwargs={"function_call": {"name": "foo", "arguments": "{42]"}},
        )
        err = (
            "Could not parse tool input: {'name': 'foo', 'arguments': '{42]'} "
            "because the `arguments` is not valid JSON."
        )
        with pytest.raises(OutputParserException, match=err):
            _parse_ai_message(msg)

@@ -0,0 +1,90 @@
import json

import pytest

from langchain.agents.openai_functions_multi_agent.base import (
    _FunctionsAgentAction,
    _parse_ai_message,
)
from langchain.schema import AgentFinish, OutputParserException
from langchain.schema.messages import AIMessage, SystemMessage


# Test: _parse_ai_message() function.
class TestParseAIMessage:
    # Test: Pass Non-AIMessage.
    def test_not_an_ai(self) -> None:
        err = f"Expected an AI message got {str(SystemMessage)}"
        with pytest.raises(TypeError, match=err):
            _parse_ai_message(SystemMessage(content="x"))

    # Test: Model response (not a function call).
    def test_model_response(self) -> None:
        msg = AIMessage(content="Model response.")
        result = _parse_ai_message(msg)

        assert isinstance(result, AgentFinish)
        assert result.return_values == {"output": "Model response."}
        assert result.log == "Model response."

    # Test: Model response with a function call.
    def test_func_call(self) -> None:
        act = json.dumps([{"action_name": "foo", "action": {"param": 42}}])
        msg = AIMessage(
            content="LLM thoughts.",
            additional_kwargs={
                "function_call": {"name": "foo", "arguments": f'{{"actions": {act}}}'}
            },
        )
        result = _parse_ai_message(msg)

        assert isinstance(result, list)
        assert len(result) == 1

        action = result[0]
        assert isinstance(action, _FunctionsAgentAction)
        assert action.tool == "foo"
        assert action.tool_input == {"param": 42}
        assert action.log == (
            "\nInvoking: `foo` with `{'param': 42}`\nresponded: LLM thoughts.\n\n"
        )
        assert action.message_log == [msg]

    # Test: Model response with a function call (old style tools).
    def test_func_call_oldstyle(self) -> None:
        act = json.dumps([{"action_name": "foo", "action": {"__arg1": "42"}}])
        msg = AIMessage(
            content="LLM thoughts.",
            additional_kwargs={
                "function_call": {"name": "foo", "arguments": f'{{"actions": {act}}}'}
            },
        )
        result = _parse_ai_message(msg)

        assert isinstance(result, list)
        assert len(result) == 1

        action = result[0]
        assert isinstance(action, _FunctionsAgentAction)
        assert action.tool == "foo"
        assert action.tool_input == "42"
        assert action.log == (
            "\nInvoking: `foo` with `42`\nresponded: LLM thoughts.\n\n"
        )
        assert action.message_log == [msg]

    # Test: Invalid function call args.
    def test_func_call_invalid(self) -> None:
        msg = AIMessage(
            content="LLM thoughts.",
            additional_kwargs={"function_call": {"name": "foo", "arguments": "{42]"}},
        )
        err = (
            "Could not parse tool input: {'name': 'foo', 'arguments': '{42]'} "
            "because the `arguments` is not valid JSON."
        )
        with pytest.raises(OutputParserException, match=err):
            _parse_ai_message(msg)