|
|
@ -1,4 +1,5 @@
|
|
|
|
import copy
|
|
|
|
import copy
|
|
|
|
|
|
|
|
import json
|
|
|
|
from typing import Any, Dict, List, Optional, Type, Union
|
|
|
|
from typing import Any, Dict, List, Optional, Type, Union
|
|
|
|
|
|
|
|
|
|
|
|
import jsonpatch
|
|
|
|
import jsonpatch
|
|
|
@ -22,7 +23,7 @@ class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
|
|
|
|
args_only: bool = True
|
|
|
|
args_only: bool = True
|
|
|
|
"""Whether to only return the arguments to the function call."""
|
|
|
|
"""Whether to only return the arguments to the function call."""
|
|
|
|
|
|
|
|
|
|
|
|
def parse_result(self, result: List[Generation]) -> Any:
|
|
|
|
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
|
|
|
|
generation = result[0]
|
|
|
|
generation = result[0]
|
|
|
|
if not isinstance(generation, ChatGeneration):
|
|
|
|
if not isinstance(generation, ChatGeneration):
|
|
|
|
raise OutputParserException(
|
|
|
|
raise OutputParserException(
|
|
|
@ -56,7 +57,7 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
|
|
|
|
def _diff(self, prev: Optional[Any], next: Any) -> Any:
|
|
|
|
def _diff(self, prev: Optional[Any], next: Any) -> Any:
|
|
|
|
return jsonpatch.make_patch(prev, next).patch
|
|
|
|
return jsonpatch.make_patch(prev, next).patch
|
|
|
|
|
|
|
|
|
|
|
|
def parse_result(self, result: List[Generation]) -> Any:
|
|
|
|
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
|
|
|
|
if len(result) != 1:
|
|
|
|
if len(result) != 1:
|
|
|
|
raise OutputParserException(
|
|
|
|
raise OutputParserException(
|
|
|
|
f"Expected exactly one result, but got {len(result)}"
|
|
|
|
f"Expected exactly one result, but got {len(result)}"
|
|
|
@ -69,16 +70,46 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
|
|
|
|
message = generation.message
|
|
|
|
message = generation.message
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
function_call = message.additional_kwargs["function_call"]
|
|
|
|
function_call = message.additional_kwargs["function_call"]
|
|
|
|
except KeyError:
|
|
|
|
except KeyError as exc:
|
|
|
|
return None
|
|
|
|
if partial:
|
|
|
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
raise OutputParserException(f"Could not parse function call: {exc}")
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
if self.args_only:
|
|
|
|
if partial:
|
|
|
|
return parse_partial_json(function_call["arguments"])
|
|
|
|
if self.args_only:
|
|
|
|
|
|
|
|
return parse_partial_json(
|
|
|
|
|
|
|
|
function_call["arguments"], strict=self.strict
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
return {
|
|
|
|
|
|
|
|
**function_call,
|
|
|
|
|
|
|
|
"arguments": parse_partial_json(
|
|
|
|
|
|
|
|
function_call["arguments"], strict=self.strict
|
|
|
|
|
|
|
|
),
|
|
|
|
|
|
|
|
}
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
return {
|
|
|
|
if self.args_only:
|
|
|
|
**function_call,
|
|
|
|
try:
|
|
|
|
"arguments": parse_partial_json(function_call["arguments"]),
|
|
|
|
return json.loads(
|
|
|
|
}
|
|
|
|
function_call["arguments"], strict=self.strict
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
except (json.JSONDecodeError, TypeError) as exc:
|
|
|
|
|
|
|
|
raise OutputParserException(
|
|
|
|
|
|
|
|
f"Could not parse function call data: {exc}"
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
|
|
|
return {
|
|
|
|
|
|
|
|
**function_call,
|
|
|
|
|
|
|
|
"arguments": json.loads(
|
|
|
|
|
|
|
|
function_call["arguments"], strict=self.strict
|
|
|
|
|
|
|
|
),
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
except (json.JSONDecodeError, TypeError) as exc:
|
|
|
|
|
|
|
|
raise OutputParserException(
|
|
|
|
|
|
|
|
f"Could not parse function call data: {exc}"
|
|
|
|
|
|
|
|
)
|
|
|
|
except KeyError:
|
|
|
|
except KeyError:
|
|
|
|
return None
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
@ -94,7 +125,7 @@ class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser):
|
|
|
|
key_name: str
|
|
|
|
key_name: str
|
|
|
|
"""The name of the key to return."""
|
|
|
|
"""The name of the key to return."""
|
|
|
|
|
|
|
|
|
|
|
|
def parse_result(self, result: List[Generation]) -> Any:
|
|
|
|
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
|
|
|
|
res = super().parse_result(result)
|
|
|
|
res = super().parse_result(result)
|
|
|
|
return res[self.key_name]
|
|
|
|
return res[self.key_name]
|
|
|
|
|
|
|
|
|
|
|
@ -119,7 +150,7 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
|
|
|
|
)
|
|
|
|
)
|
|
|
|
return values
|
|
|
|
return values
|
|
|
|
|
|
|
|
|
|
|
|
def parse_result(self, result: List[Generation]) -> Any:
|
|
|
|
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
|
|
|
|
_result = super().parse_result(result)
|
|
|
|
_result = super().parse_result(result)
|
|
|
|
if self.args_only:
|
|
|
|
if self.args_only:
|
|
|
|
pydantic_args = self.pydantic_schema.parse_raw(_result) # type: ignore
|
|
|
|
pydantic_args = self.pydantic_schema.parse_raw(_result) # type: ignore
|
|
|
@ -136,6 +167,6 @@ class PydanticAttrOutputFunctionsParser(PydanticOutputFunctionsParser):
|
|
|
|
attr_name: str
|
|
|
|
attr_name: str
|
|
|
|
"""The name of the attribute to return."""
|
|
|
|
"""The name of the attribute to return."""
|
|
|
|
|
|
|
|
|
|
|
|
def parse_result(self, result: List[Generation]) -> Any:
|
|
|
|
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
|
|
|
|
result = super().parse_result(result)
|
|
|
|
result = super().parse_result(result)
|
|
|
|
return getattr(result, self.attr_name)
|
|
|
|
return getattr(result, self.attr_name)
|
|
|
|