diff --git a/libs/langchain/langchain/agents/openai_functions_agent/base.py b/libs/langchain/langchain/agents/openai_functions_agent/base.py index 52aa91b7c8..c75f14a341 100644 --- a/libs/langchain/langchain/agents/openai_functions_agent/base.py +++ b/libs/langchain/langchain/agents/openai_functions_agent/base.py @@ -1,6 +1,5 @@ """Module implements an agent that uses OpenAI's APIs function enabled API.""" import json -from dataclasses import dataclass from json import JSONDecodeError from typing import Any, List, Optional, Sequence, Tuple, Union @@ -21,6 +20,7 @@ from langchain.schema import ( BasePromptTemplate, OutputParserException, ) +from langchain.schema.agent import AgentActionMessageLog from langchain.schema.language_model import BaseLanguageModel from langchain.schema.messages import ( AIMessage, @@ -31,10 +31,8 @@ from langchain.schema.messages import ( from langchain.tools import BaseTool from langchain.tools.convert_to_openai import format_tool_to_openai_function - -@dataclass -class _FunctionsAgentAction(AgentAction): - message_log: List[BaseMessage] +# For backwards compatibility +_FunctionsAgentAction = AgentActionMessageLog def _convert_agent_action_to_messages( @@ -51,7 +49,7 @@ def _convert_agent_action_to_messages( AIMessage that corresponds to the original tool invocation. """ if isinstance(agent_action, _FunctionsAgentAction): - return agent_action.message_log + [ + return list(agent_action.message_log) + [ _create_function_message(agent_action, observation) ] else: diff --git a/libs/langchain/langchain/agents/openai_functions_multi_agent/base.py b/libs/langchain/langchain/agents/openai_functions_multi_agent/base.py index 7469f30389..5849cf9718 100644 --- a/libs/langchain/langchain/agents/openai_functions_multi_agent/base.py +++ b/libs/langchain/langchain/agents/openai_functions_multi_agent/base.py @@ -1,6 +1,5 @@ """Module implements an agent that uses OpenAI's APIs function enabled API.""" import json -from dataclasses import dataclass from json import JSONDecodeError from typing import Any, List, Optional, Sequence, Tuple, Union @@ -21,6 +20,7 @@ from langchain.schema import ( BasePromptTemplate, OutputParserException, ) +from langchain.schema.agent import AgentActionMessageLog from langchain.schema.language_model import BaseLanguageModel from langchain.schema.messages import ( AIMessage, @@ -30,10 +30,8 @@ from langchain.schema.messages import ( ) from langchain.tools import BaseTool - -@dataclass -class _FunctionsAgentAction(AgentAction): - message_log: List[BaseMessage] +# For backwards compatibility +_FunctionsAgentAction = AgentActionMessageLog def _convert_agent_action_to_messages( @@ -50,7 +48,7 @@ def _convert_agent_action_to_messages( AIMessage that corresponds to the original tool invocation. """ if isinstance(agent_action, _FunctionsAgentAction): - return agent_action.message_log + [ + return list(agent_action.message_log) + [ _create_function_message(agent_action, observation) ] else: diff --git a/libs/langchain/langchain/schema/agent.py b/libs/langchain/langchain/schema/agent.py index 5d58993cb5..deb292ba3c 100644 --- a/libs/langchain/langchain/schema/agent.py +++ b/libs/langchain/langchain/schema/agent.py @@ -1,11 +1,12 @@ from __future__ import annotations -from dataclasses import dataclass -from typing import NamedTuple, Union +from typing import Any, Sequence, Union +from langchain.load.serializable import Serializable +from langchain.schema.messages import BaseMessage -@dataclass -class AgentAction: + +class AgentAction(Serializable): """A full description of an action for an ActionAgent to execute.""" tool: str @@ -13,13 +14,57 @@ class AgentAction: tool_input: Union[str, dict] """The input to pass in to the Tool.""" log: str - """Additional information to log about the action.""" + """Additional information to log about the action. + This log can be used in a few ways. First, it can be used to audit + what exactly the LLM predicted to lead to this (tool, tool_input). + Second, it can be used in future iterations to show the LLMs prior + thoughts. This is useful when (tool, tool_input) does not contain + full information about the LLM prediction (for example, any `thought` + before the tool/tool_input).""" + + def __init__( + self, tool: str, tool_input: Union[str, dict], log: str, **kwargs: Any + ): + super().__init__(tool=tool, tool_input=tool_input, log=log, **kwargs) + + @property + def lc_serializable(self) -> bool: + """ + Return whether or not the class is serializable. + """ + return True + +class AgentActionMessageLog(AgentAction): + message_log: Sequence[BaseMessage] + """Similar to log, this can be used to pass along extra + information about what exact messages were predicted by the LLM + before parsing out the (tool, tool_input). This is again useful + if (tool, tool_input) cannot be used to fully recreate the LLM + prediction, and you need that LLM prediction (for future agent iteration). + Compared to `log`, this is useful when the underlying LLM is a + ChatModel (and therefore returns messages rather than a string).""" -class AgentFinish(NamedTuple): + +class AgentFinish(Serializable): """The final return value of an ActionAgent.""" return_values: dict """Dictionary of return values.""" log: str - """Additional information to log about the return value""" + """Additional information to log about the return value. + This is used to pass along the full LLM prediction, not just the parsed out + return value. For example, if the full LLM prediction was + `Final Answer: 2` you may want to just return `2` as a return value, but pass + along the full string as a `log` (for debugging or observability purposes). + """ + + def __init__(self, return_values: dict, log: str, **kwargs: Any): + super().__init__(return_values=return_values, log=log, **kwargs) + + @property + def lc_serializable(self) -> bool: + """ + Return whether or not the class is serializable. + """ + return True