Thread a keyword-only `partial` flag through the output parsers.

Summary of the changes in this patch:
* parse_result() in every touched parser (and aparse_result() in
  schema/output_parser.py) gains a keyword-only `partial: bool = False`.
* BaseCumulativeTransformOutputParser passes partial=True when it parses the
  accumulated chunks while streaming.
* JsonOutputFunctionsParser.parse_result(): with partial=True it keeps using
  parse_partial_json() (now passing strict=self.strict) and returns None when
  there is no "function_call"; with partial=False it uses json.loads() and
  raises OutputParserException for a missing "function_call" or for
  arguments that json.loads() rejects.
* JsonKeyOutputFunctionsParser.parse_result() forwards `partial` to its
  parent; without that, the partial=True calls made while streaming would be
  parsed as final output and raise. It returns None while the parsed result
  or the requested key is not available yet.

diff --git a/libs/langchain/langchain/agents/output_parsers/openai_functions.py b/libs/langchain/langchain/agents/output_parsers/openai_functions.py index 7a2a2702d6..9ba41963a8 100644 --- a/libs/langchain/langchain/agents/output_parsers/openai_functions.py +++ b/libs/langchain/langchain/agents/output_parsers/openai_functions.py @@ -75,7 +75,9 @@ class OpenAIFunctionsAgentOutputParser(AgentOutputParser): return_values={"output": message.content}, log=message.content ) - def parse_result(self, result: List[Generation]) -> Union[AgentAction, AgentFinish]: + def parse_result( + self, result: List[Generation], *, partial: bool = False + ) -> Union[AgentAction, AgentFinish]: if not isinstance(result[0], ChatGeneration): raise ValueError("This output parser only works on ChatGeneration output") message = result[0].message diff --git a/libs/langchain/langchain/output_parsers/openai_functions.py b/libs/langchain/langchain/output_parsers/openai_functions.py index 098f4ab372..8724035e4b 100644 --- a/libs/langchain/langchain/output_parsers/openai_functions.py +++ b/libs/langchain/langchain/output_parsers/openai_functions.py @@ -1,4 +1,5 @@ import copy +import json from typing import Any, Dict, List, Optional, Type, Union import jsonpatch @@ -22,7 +23,7 @@ class OutputFunctionsParser(BaseGenerationOutputParser[Any]): args_only: bool = True """Whether to only return the arguments to the function call.""" - def parse_result(self, result: List[Generation]) -> Any: + def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any: generation = result[0] if not isinstance(generation, ChatGeneration): raise OutputParserException( @@ -56,7 +57,7 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]): def _diff(self, prev: Optional[Any], next: Any) -> Any: return jsonpatch.make_patch(prev, next).patch - def parse_result(self, result: List[Generation]) -> Any: + def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any: if 
len(result) != 1: raise OutputParserException( f"Expected exactly one result, but got {len(result)}" @@ -69,16 +70,46 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]): message = generation.message try: function_call = message.additional_kwargs["function_call"] - except KeyError: - return None + except KeyError as exc: + if partial: + return None + else: + raise OutputParserException(f"Could not parse function call: {exc}") try: - if self.args_only: - return parse_partial_json(function_call["arguments"]) + if partial: + if self.args_only: + return parse_partial_json( + function_call["arguments"], strict=self.strict + ) + else: + return { + **function_call, + "arguments": parse_partial_json( + function_call["arguments"], strict=self.strict + ), + } else: - return { - **function_call, - "arguments": parse_partial_json(function_call["arguments"]), - } + if self.args_only: + try: + return json.loads( + function_call["arguments"], strict=self.strict + ) + except (json.JSONDecodeError, TypeError) as exc: + raise OutputParserException( + f"Could not parse function call data: {exc}" + ) + else: + try: + return { + **function_call, + "arguments": json.loads( + function_call["arguments"], strict=self.strict + ), + } + except (json.JSONDecodeError, TypeError) as exc: + raise OutputParserException( + f"Could not parse function call data: {exc}" + ) except KeyError: return None @@ -94,7 +125,10 @@ class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser): key_name: str """The name of the key to return.""" - def parse_result(self, result: List[Generation]) -> Any: - res = super().parse_result(result) - return res[self.key_name] + def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any: + res = super().parse_result(result, partial=partial) + # While streaming, the parent returns None until something is parseable. + if partial and res is None: + return None + return res.get(self.key_name) if partial else res[self.key_name] @@ -119,7 +153,7 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser): ) return values - def parse_result(self, result: List[Generation]) -> Any: + def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any: 
_result = super().parse_result(result) if self.args_only: pydantic_args = self.pydantic_schema.parse_raw(_result) # type: ignore @@ -136,6 +170,6 @@ class PydanticAttrOutputFunctionsParser(PydanticOutputFunctionsParser): attr_name: str """The name of the attribute to return.""" - def parse_result(self, result: List[Generation]) -> Any: + def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any: result = super().parse_result(result) return getattr(result, self.attr_name) diff --git a/libs/langchain/langchain/schema/output_parser.py b/libs/langchain/langchain/schema/output_parser.py index 5858730256..157c1cd5f0 100644 --- a/libs/langchain/langchain/schema/output_parser.py +++ b/libs/langchain/langchain/schema/output_parser.py @@ -34,7 +34,7 @@ class BaseLLMOutputParser(Serializable, Generic[T], ABC): """Abstract base class for parsing the outputs of a model.""" @abstractmethod - def parse_result(self, result: List[Generation]) -> T: + def parse_result(self, result: List[Generation], *, partial: bool = False) -> T: """Parse a list of candidate model Generations into a specific format. Args: @@ -45,7 +45,9 @@ class BaseLLMOutputParser(Serializable, Generic[T], ABC): Structured output. """ - async def aparse_result(self, result: List[Generation]) -> T: + async def aparse_result( + self, result: List[Generation], *, partial: bool = False + ) -> T: """Parse a list of candidate model Generations into a specific format. Args: @@ -205,7 +207,7 @@ class BaseOutputParser(BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T] run_type="parser", ) - def parse_result(self, result: List[Generation]) -> T: + def parse_result(self, result: List[Generation], *, partial: bool = False) -> T: """Parse a list of candidate model Generations into a specific format. The return value is parsed from only the first Generation in the result, which @@ -231,7 +233,9 @@ class BaseOutputParser(BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T] Structured output. 
""" - async def aparse_result(self, result: List[Generation]) -> T: + async def aparse_result( + self, result: List[Generation], *, partial: bool = False + ) -> T: """Parse a list of candidate model Generations into a specific format. The return value is parsed from only the first Generation in the result, which @@ -365,7 +369,7 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]): else: acc_gen += chunk_gen - parsed = self.parse_result([acc_gen]) + parsed = self.parse_result([acc_gen], partial=True) if parsed is not None and parsed != prev_parsed: if self.diff: yield self._diff(prev_parsed, parsed) @@ -393,7 +397,7 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]): else: acc_gen += chunk_gen - parsed = self.parse_result([acc_gen]) + parsed = self.parse_result([acc_gen], partial=True) if parsed is not None and parsed != prev_parsed: if self.diff: yield self._diff(prev_parsed, parsed)