From eb9b334a6b78af0d56b0e01fe6a0e36a7dba91b8 Mon Sep 17 00:00:00 2001 From: Tal <73430602+treisfeld@users.noreply.github.com> Date: Fri, 12 Jan 2024 07:52:36 +0200 Subject: [PATCH] Enable customizing the output parser of `OpenAIFunctionsAgent` (#15827) - **Description:** This PR defines the output parser of OpenAIFunctionsAgent as an attribute, enabling customization and subclassing of the parser logic. - **Issue:** Subclassing is currently impossible as the `OpenAIFunctionsAgentOutputParser` class is hard coded into the `plan` and `aplan` methods - **Dependencies:** None --------- Co-authored-by: Harrison Chase --- .../langchain/agents/openai_functions_agent/base.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/libs/langchain/langchain/agents/openai_functions_agent/base.py b/libs/langchain/langchain/agents/openai_functions_agent/base.py index 9e6c6e0cfd..f388f5cae2 100644 --- a/libs/langchain/langchain/agents/openai_functions_agent/base.py +++ b/libs/langchain/langchain/agents/openai_functions_agent/base.py @@ -1,5 +1,5 @@ """Module implements an agent that uses OpenAI's APIs function enabled API.""" -from typing import Any, List, Optional, Sequence, Tuple, Union +from typing import Any, List, Optional, Sequence, Tuple, Type, Union from langchain_community.tools.convert_to_openai import format_tool_to_openai_function from langchain_core._api import deprecated @@ -47,6 +47,9 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent): llm: BaseLanguageModel tools: Sequence[BaseTool] prompt: BasePromptTemplate + output_parser: Type[ + OpenAIFunctionsAgentOutputParser + ] = OpenAIFunctionsAgentOutputParser def get_allowed_tools(self) -> List[str]: """Get allowed tools.""" @@ -105,9 +108,7 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent): messages, callbacks=callbacks, ) - agent_decision = OpenAIFunctionsAgentOutputParser._parse_ai_message( - predicted_message - ) + agent_decision = self.output_parser._parse_ai_message(predicted_message) return agent_decision async def aplan( @@ -136,9 +137,7 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent): predicted_message = await self.llm.apredict_messages( messages, functions=self.functions, callbacks=callbacks ) - agent_decision = OpenAIFunctionsAgentOutputParser._parse_ai_message( - predicted_message - ) + agent_decision = self.output_parser._parse_ai_message(predicted_message) return agent_decision def return_stopped_response(