Fix OpenAIFunctionsAgent function call message content retrieving (#10488)

`langchain.agents.openai_functions[_multi]_agent._parse_ai_message()`
incorrectly extracts AI message content, thus LLM response ("thoughts")
is lost and can't be logged or processed by callbacks.

This PR fixes function call message content retrieving.
This commit is contained in:
Sergey Kozlov 2023-09-14 05:19:25 +06:00 committed by GitHub
parent 2dc3c64386
commit 0a0276bcdb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 168 additions and 2 deletions

View File

@ -127,7 +127,7 @@ def _parse_ai_message(message: BaseMessage) -> Union[AgentAction, AgentFinish]:
else: else:
tool_input = _tool_input tool_input = _tool_input
content_msg = "responded: {content}\n" if message.content else "\n" content_msg = f"responded: {message.content}\n" if message.content else "\n"
return _FunctionsAgentAction( return _FunctionsAgentAction(
tool=function_name, tool=function_name,

View File

@ -129,7 +129,7 @@ def _parse_ai_message(message: BaseMessage) -> Union[List[AgentAction], AgentFin
else: else:
tool_input = _tool_input tool_input = _tool_input
content_msg = "responded: {content}\n" if message.content else "\n" content_msg = f"responded: {message.content}\n" if message.content else "\n"
log = f"\nInvoking: `{function_name}` with `{tool_input}`\n{content_msg}\n" log = f"\nInvoking: `{function_name}` with `{tool_input}`\n{content_msg}\n"
_tool = _FunctionsAgentAction( _tool = _FunctionsAgentAction(
tool=function_name, tool=function_name,

View File

@ -0,0 +1,76 @@
import pytest
from langchain.agents.openai_functions_agent.base import (
_FunctionsAgentAction,
_parse_ai_message,
)
from langchain.schema import AgentFinish, OutputParserException
from langchain.schema.messages import AIMessage, SystemMessage
# Test: _parse_ai_message() function.
class TestParseAIMessage:
# Test: Pass Non-AIMessage.
def test_not_an_ai(self) -> None:
err = f"Expected an AI message got {str(SystemMessage)}"
with pytest.raises(TypeError, match=err):
_parse_ai_message(SystemMessage(content="x"))
# Test: Model response (not a function call).
def test_model_response(self) -> None:
msg = AIMessage(content="Model response.")
result = _parse_ai_message(msg)
assert isinstance(result, AgentFinish)
assert result.return_values == {"output": "Model response."}
assert result.log == "Model response."
# Test: Model response with a function call.
def test_func_call(self) -> None:
msg = AIMessage(
content="LLM thoughts.",
additional_kwargs={
"function_call": {"name": "foo", "arguments": '{"param": 42}'}
},
)
result = _parse_ai_message(msg)
assert isinstance(result, _FunctionsAgentAction)
assert result.tool == "foo"
assert result.tool_input == {"param": 42}
assert result.log == (
"\nInvoking: `foo` with `{'param': 42}`\nresponded: LLM thoughts.\n\n"
)
assert result.message_log == [msg]
# Test: Model response with a function call (old style tools).
def test_func_call_oldstyle(self) -> None:
msg = AIMessage(
content="LLM thoughts.",
additional_kwargs={
"function_call": {"name": "foo", "arguments": '{"__arg1": "42"}'}
},
)
result = _parse_ai_message(msg)
assert isinstance(result, _FunctionsAgentAction)
assert result.tool == "foo"
assert result.tool_input == "42"
assert result.log == (
"\nInvoking: `foo` with `42`\nresponded: LLM thoughts.\n\n"
)
assert result.message_log == [msg]
# Test: Invalid function call args.
def test_func_call_invalid(self) -> None:
msg = AIMessage(
content="LLM thoughts.",
additional_kwargs={"function_call": {"name": "foo", "arguments": "{42]"}},
)
err = (
"Could not parse tool input: {'name': 'foo', 'arguments': '{42]'} "
"because the `arguments` is not valid JSON."
)
with pytest.raises(OutputParserException, match=err):
_parse_ai_message(msg)

View File

@ -0,0 +1,90 @@
import json
import pytest
from langchain.agents.openai_functions_multi_agent.base import (
_FunctionsAgentAction,
_parse_ai_message,
)
from langchain.schema import AgentFinish, OutputParserException
from langchain.schema.messages import AIMessage, SystemMessage
# Test: _parse_ai_message() function.
class TestParseAIMessage:
# Test: Pass Non-AIMessage.
def test_not_an_ai(self) -> None:
err = f"Expected an AI message got {str(SystemMessage)}"
with pytest.raises(TypeError, match=err):
_parse_ai_message(SystemMessage(content="x"))
# Test: Model response (not a function call).
def test_model_response(self) -> None:
msg = AIMessage(content="Model response.")
result = _parse_ai_message(msg)
assert isinstance(result, AgentFinish)
assert result.return_values == {"output": "Model response."}
assert result.log == "Model response."
# Test: Model response with a function call.
def test_func_call(self) -> None:
act = json.dumps([{"action_name": "foo", "action": {"param": 42}}])
msg = AIMessage(
content="LLM thoughts.",
additional_kwargs={
"function_call": {"name": "foo", "arguments": f'{{"actions": {act}}}'}
},
)
result = _parse_ai_message(msg)
assert isinstance(result, list)
assert len(result) == 1
action = result[0]
assert isinstance(action, _FunctionsAgentAction)
assert action.tool == "foo"
assert action.tool_input == {"param": 42}
assert action.log == (
"\nInvoking: `foo` with `{'param': 42}`\nresponded: LLM thoughts.\n\n"
)
assert action.message_log == [msg]
# Test: Model response with a function call (old style tools).
def test_func_call_oldstyle(self) -> None:
act = json.dumps([{"action_name": "foo", "action": {"__arg1": "42"}}])
msg = AIMessage(
content="LLM thoughts.",
additional_kwargs={
"function_call": {"name": "foo", "arguments": f'{{"actions": {act}}}'}
},
)
result = _parse_ai_message(msg)
assert isinstance(result, list)
assert len(result) == 1
action = result[0]
assert isinstance(action, _FunctionsAgentAction)
assert action.tool == "foo"
assert action.tool_input == "42"
assert action.log == (
"\nInvoking: `foo` with `42`\nresponded: LLM thoughts.\n\n"
)
assert action.message_log == [msg]
# Test: Invalid function call args.
def test_func_call_invalid(self) -> None:
msg = AIMessage(
content="LLM thoughts.",
additional_kwargs={"function_call": {"name": "foo", "arguments": "{42]"}},
)
err = (
"Could not parse tool input: {'name': 'foo', 'arguments': '{42]'} "
"because the `arguments` is not valid JSON."
)
with pytest.raises(OutputParserException, match=err):
_parse_ai_message(msg)