Keep exceptions when not in streaming mode

pull/11193/head
Nuno Campos 10 months ago
parent 1f30e25681
commit aa8b4120a8

@ -75,7 +75,9 @@ class OpenAIFunctionsAgentOutputParser(AgentOutputParser):
return_values={"output": message.content}, log=message.content
)
def parse_result(self, result: List[Generation]) -> Union[AgentAction, AgentFinish]:
def parse_result(
self, result: List[Generation], *, partial: bool = False
) -> Union[AgentAction, AgentFinish]:
if not isinstance(result[0], ChatGeneration):
raise ValueError("This output parser only works on ChatGeneration output")
message = result[0].message

@ -1,4 +1,5 @@
import copy
import json
from typing import Any, Dict, List, Optional, Type, Union
import jsonpatch
@ -22,7 +23,7 @@ class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
args_only: bool = True
"""Whether to only return the arguments to the function call."""
def parse_result(self, result: List[Generation]) -> Any:
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
generation = result[0]
if not isinstance(generation, ChatGeneration):
raise OutputParserException(
@ -56,7 +57,7 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
def _diff(self, prev: Optional[Any], next: Any) -> Any:
return jsonpatch.make_patch(prev, next).patch
def parse_result(self, result: List[Generation]) -> Any:
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
if len(result) != 1:
raise OutputParserException(
f"Expected exactly one result, but got {len(result)}"
@ -69,16 +70,46 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
message = generation.message
try:
function_call = message.additional_kwargs["function_call"]
except KeyError:
return None
except KeyError as exc:
if partial:
return None
else:
raise OutputParserException(f"Could not parse function call: {exc}")
try:
if self.args_only:
return parse_partial_json(function_call["arguments"])
if partial:
if self.args_only:
return parse_partial_json(
function_call["arguments"], strict=self.strict
)
else:
return {
**function_call,
"arguments": parse_partial_json(
function_call["arguments"], strict=self.strict
),
}
else:
return {
**function_call,
"arguments": parse_partial_json(function_call["arguments"]),
}
if self.args_only:
try:
return json.loads(
function_call["arguments"], strict=self.strict
)
except (json.JSONDecodeError, TypeError) as exc:
raise OutputParserException(
f"Could not parse function call data: {exc}"
)
else:
try:
return {
**function_call,
"arguments": json.loads(
function_call["arguments"], strict=self.strict
),
}
except (json.JSONDecodeError, TypeError) as exc:
raise OutputParserException(
f"Could not parse function call data: {exc}"
)
except KeyError:
return None
@ -94,7 +125,7 @@ class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser):
key_name: str
"""The name of the key to return."""
def parse_result(self, result: List[Generation]) -> Any:
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
res = super().parse_result(result)
return res[self.key_name]
@ -119,7 +150,7 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
)
return values
def parse_result(self, result: List[Generation]) -> Any:
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
_result = super().parse_result(result)
if self.args_only:
pydantic_args = self.pydantic_schema.parse_raw(_result) # type: ignore
@ -136,6 +167,6 @@ class PydanticAttrOutputFunctionsParser(PydanticOutputFunctionsParser):
attr_name: str
"""The name of the attribute to return."""
def parse_result(self, result: List[Generation]) -> Any:
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
result = super().parse_result(result)
return getattr(result, self.attr_name)

@ -34,7 +34,7 @@ class BaseLLMOutputParser(Serializable, Generic[T], ABC):
"""Abstract base class for parsing the outputs of a model."""
@abstractmethod
def parse_result(self, result: List[Generation]) -> T:
def parse_result(self, result: List[Generation], *, partial: bool = False) -> T:
"""Parse a list of candidate model Generations into a specific format.
Args:
@ -45,7 +45,9 @@ class BaseLLMOutputParser(Serializable, Generic[T], ABC):
Structured output.
"""
async def aparse_result(self, result: List[Generation]) -> T:
async def aparse_result(
self, result: List[Generation], *, partial: bool = False
) -> T:
"""Parse a list of candidate model Generations into a specific format.
Args:
@ -205,7 +207,7 @@ class BaseOutputParser(BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]
run_type="parser",
)
def parse_result(self, result: List[Generation]) -> T:
def parse_result(self, result: List[Generation], *, partial: bool = False) -> T:
"""Parse a list of candidate model Generations into a specific format.
The return value is parsed from only the first Generation in the result, which
@ -231,7 +233,9 @@ class BaseOutputParser(BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]
Structured output.
"""
async def aparse_result(self, result: List[Generation]) -> T:
async def aparse_result(
self, result: List[Generation], *, partial: bool = False
) -> T:
"""Parse a list of candidate model Generations into a specific format.
The return value is parsed from only the first Generation in the result, which
@ -365,7 +369,7 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
else:
acc_gen += chunk_gen
parsed = self.parse_result([acc_gen])
parsed = self.parse_result([acc_gen], partial=True)
if parsed is not None and parsed != prev_parsed:
if self.diff:
yield self._diff(prev_parsed, parsed)
@ -393,7 +397,7 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
else:
acc_gen += chunk_gen
parsed = self.parse_result([acc_gen])
parsed = self.parse_result([acc_gen], partial=True)
if parsed is not None and parsed != prev_parsed:
if self.diff:
yield self._diff(prev_parsed, parsed)

Loading…
Cancel
Save