core[patch]: docstrings output_parsers (#23825)

Added missed docstrings. Formatted docstrings to the consistent form.
This commit is contained in:
Leonid Ganeline 2024-07-03 11:27:40 -07:00 committed by GitHub
parent 26cee2e878
commit 55f6f91f17
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 387 additions and 20 deletions

View File

@ -38,6 +38,8 @@ class BaseLLMOutputParser(Generic[T], ABC):
Args:
result: A list of Generations to be parsed. The Generations are assumed
to be different candidate outputs for a single model input.
partial: Whether to parse the output as a partial result. This is useful
for parsers that can parse partial results. Default is False.
Returns:
Structured output.
@ -46,11 +48,13 @@ class BaseLLMOutputParser(Generic[T], ABC):
async def aparse_result(
self, result: List[Generation], *, partial: bool = False
) -> T:
"""Parse a list of candidate model Generations into a specific format.
"""Async parse a list of candidate model Generations into a specific format.
Args:
result: A list of Generations to be parsed. The Generations are assumed
to be different candidate outputs for a single model input.
partial: Whether to parse the output as a partial result. This is useful
for parsers that can parse partial results. Default is False.
Returns:
Structured output.
@ -65,10 +69,12 @@ class BaseGenerationOutputParser(
@property
def InputType(self) -> Any:
"""Return the input type for the parser."""
return Union[str, AnyMessage]
@property
def OutputType(self) -> Type[T]:
"""Return the output type for the parser."""
# even though mypy complains this isn't valid,
# it is good enough for pydantic to build the schema from
return T # type: ignore[misc]
@ -148,10 +154,18 @@ class BaseOutputParser(
@property
def InputType(self) -> Any:
"""Return the input type for the parser."""
return Union[str, AnyMessage]
@property
def OutputType(self) -> Type[T]:
"""Return the output type for the parser.
This property is inferred from the first type argument of the class.
Raises:
TypeError: If the class doesn't have an inferable OutputType.
"""
for cls in self.__class__.__orig_bases__: # type: ignore[attr-defined]
type_args = get_args(cls)
if type_args and len(type_args) == 1:
@ -214,6 +228,8 @@ class BaseOutputParser(
Args:
result: A list of Generations to be parsed. The Generations are assumed
to be different candidate outputs for a single model input.
partial: Whether to parse the output as a partial result. This is useful
for parsers that can parse partial results. Default is False.
Returns:
Structured output.
@ -234,7 +250,7 @@ class BaseOutputParser(
async def aparse_result(
self, result: List[Generation], *, partial: bool = False
) -> T:
"""Parse a list of candidate model Generations into a specific format.
"""Async parse a list of candidate model Generations into a specific format.
The return value is parsed from only the first Generation in the result, which
is assumed to be the highest-likelihood Generation.
@ -242,6 +258,8 @@ class BaseOutputParser(
Args:
result: A list of Generations to be parsed. The Generations are assumed
to be different candidate outputs for a single model input.
partial: Whether to parse the output as a partial result. This is useful
for parsers that can parse partial results. Default is False.
Returns:
Structured output.
@ -249,7 +267,7 @@ class BaseOutputParser(
return await run_in_executor(None, self.parse_result, result, partial=partial)
async def aparse(self, text: str) -> T:
"""Parse a single string model output into some structure.
"""Async parse a single string model output into some structure.
Args:
text: String output of a language model.
@ -272,7 +290,7 @@ class BaseOutputParser(
prompt: Input PromptValue.
Returns:
Structured output
Structured output.
"""
return self.parse(completion)

View File

@ -41,6 +41,8 @@ class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
"""
pydantic_object: Optional[Type[TBaseModel]] = None # type: ignore
"""The Pydantic object to use for validation.
If None, no validation is performed."""
def _diff(self, prev: Optional[Any], next: Any) -> Any:
return jsonpatch.make_patch(prev, next).patch
@ -54,6 +56,22 @@ class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
return pydantic_object.schema()
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
"""Parse the result of an LLM call to a JSON object.
Args:
result: The result of the LLM call.
partial: Whether to parse partial JSON objects.
If True, the output will be a JSON object containing
all the keys that have been returned so far.
If False, the output will be the full JSON object.
Default is False.
Returns:
The parsed JSON object.
Raises:
OutputParserException: If the output is not valid JSON.
"""
text = result[0].text
text = text.strip()
if partial:
@ -69,9 +87,22 @@ class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
raise OutputParserException(msg, llm_output=text) from e
def parse(self, text: str) -> Any:
"""Parse the output of an LLM call to a JSON object.
Args:
text: The output of the LLM call.
Returns:
The parsed JSON object.
"""
return self.parse_result([Generation(text=text)])
def get_format_instructions(self) -> str:
"""Return the format instructions for the JSON output.
Returns:
The format instructions for the JSON output.
"""
if self.pydantic_object is None:
return "Return a JSON object."
else:

View File

@ -12,7 +12,15 @@ T = TypeVar("T")
def droplastn(iter: Iterator[T], n: int) -> Iterator[T]:
"""Drop the last n elements of an iterator."""
"""Drop the last n elements of an iterator.
Args:
iter: The iterator to drop elements from.
n: The number of elements to drop.
Yields:
The elements of the iterator, except the last n elements.
"""
buffer: Deque[T] = deque()
for item in iter:
buffer.append(item)
@ -29,10 +37,24 @@ class ListOutputParser(BaseTransformOutputParser[List[str]]):
@abstractmethod
def parse(self, text: str) -> List[str]:
"""Parse the output of an LLM call."""
"""Parse the output of an LLM call.
Args:
text: The output of an LLM call.
Returns:
A list of strings.
"""
def parse_iter(self, text: str) -> Iterator[re.Match]:
"""Parse the output of an LLM call."""
"""Parse the output of an LLM call.
Args:
text: The output of an LLM call.
Yields:
A match object for each part of the output.
"""
raise NotImplementedError
def _transform(
@ -105,21 +127,36 @@ class CommaSeparatedListOutputParser(ListOutputParser):
@classmethod
def is_lc_serializable(cls) -> bool:
"""Check if the langchain object is serializable.
Returns True."""
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
"""Get the namespace of the langchain object.
Returns:
A list of strings.
Default is ["langchain", "output_parsers", "list"].
"""
return ["langchain", "output_parsers", "list"]
def get_format_instructions(self) -> str:
"""Return the format instructions for the comma-separated list output."""
return (
"Your response should be a list of comma separated values, "
"eg: `foo, bar, baz` or `foo,bar,baz`"
)
def parse(self, text: str) -> List[str]:
"""Parse the output of an LLM call."""
"""Parse the output of an LLM call.
Args:
text: The output of an LLM call.
Returns:
A list of strings.
"""
return [part.strip() for part in text.split(",")]
@property
@ -131,6 +168,7 @@ class NumberedListOutputParser(ListOutputParser):
"""Parse a numbered list."""
pattern: str = r"\d+\.\s([^\n]+)"
"""The pattern to match a numbered list item."""
def get_format_instructions(self) -> str:
return (
@ -139,11 +177,25 @@ class NumberedListOutputParser(ListOutputParser):
)
def parse(self, text: str) -> List[str]:
"""Parse the output of an LLM call."""
"""Parse the output of an LLM call.
Args:
text: The output of an LLM call.
Returns:
A list of strings.
"""
return re.findall(self.pattern, text)
def parse_iter(self, text: str) -> Iterator[re.Match]:
"""Parse the output of an LLM call."""
"""Parse the output of an LLM call.
Args:
text: The output of an LLM call.
Yields:
A match object for each part of the output.
"""
return re.finditer(self.pattern, text)
@property
@ -152,19 +204,35 @@ class NumberedListOutputParser(ListOutputParser):
class MarkdownListOutputParser(ListOutputParser):
"""Parse a markdown list."""
"""Parse a Markdown list."""
pattern: str = r"^\s*[-*]\s([^\n]+)$"
"""The pattern to match a Markdown list item."""
def get_format_instructions(self) -> str:
"""Return the format instructions for the Markdown list output."""
return "Your response should be a markdown list, " "eg: `- foo\n- bar\n- baz`"
def parse(self, text: str) -> List[str]:
"""Parse the output of an LLM call."""
"""Parse the output of an LLM call.
Args:
text: The output of an LLM call.
Returns:
A list of strings.
"""
return re.findall(self.pattern, text, re.MULTILINE)
def parse_iter(self, text: str) -> Iterator[re.Match]:
"""Parse the output of an LLM call."""
"""Parse the output of an LLM call.
Args:
text: The output of an LLM call.
Yields:
A match object for each part of the output.
"""
return re.finditer(self.pattern, text, re.MULTILINE)
@property

View File

@ -21,6 +21,18 @@ class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
"""Whether to only return the arguments to the function call."""
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
"""Parse the result of an LLM call to a JSON object.
Args:
result: The result of the LLM call.
partial: Whether to parse partial JSON objects. Default is False.
Returns:
The parsed JSON object.
Raises:
OutputParserException: If the output is not valid JSON.
"""
generation = result[0]
if not isinstance(generation, ChatGeneration):
raise OutputParserException(
@ -59,6 +71,19 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
return jsonpatch.make_patch(prev, next).patch
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
"""Parse the result of an LLM call to a JSON object.
Args:
result: The result of the LLM call.
partial: Whether to parse partial JSON objects. Default is False.
Returns:
The parsed JSON object.
Raises:
OutputParserException: If the output is not valid JSON.
"""
if len(result) != 1:
raise OutputParserException(
f"Expected exactly one result, but got {len(result)}"
@ -120,6 +145,14 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
# This method would be called by the default implementation of `parse_result`
# but we're overriding that method so it's not needed.
def parse(self, text: str) -> Any:
"""Parse the output of an LLM call to a JSON object.
Args:
text: The output of the LLM call.
Returns:
The parsed JSON object.
"""
raise NotImplementedError()
@ -130,6 +163,15 @@ class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser):
"""The name of the key to return."""
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
"""Parse the result of an LLM call to a JSON object.
Args:
result: The result of the LLM call.
partial: Whether to parse partial JSON objects. Default is False.
Returns:
The parsed JSON object.
"""
res = super().parse_result(result, partial=partial)
if partial and res is None:
return None
@ -186,6 +228,17 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
@root_validator(pre=True)
def validate_schema(cls, values: Dict) -> Dict:
"""Validate the pydantic schema.
Args:
values: The values to validate.
Returns:
The validated values.
Raises:
ValueError: If the schema is not a pydantic schema.
"""
schema = values["pydantic_schema"]
if "args_only" not in values:
values["args_only"] = isinstance(schema, type) and issubclass(
@ -199,6 +252,15 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
return values
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
"""Parse the result of an LLM call to a JSON object.
Args:
result: The result of the LLM call.
partial: Whether to parse partial JSON objects. Default is False.
Returns:
The parsed JSON object.
"""
_result = super().parse_result(result)
if self.args_only:
pydantic_args = self.pydantic_schema.parse_raw(_result) # type: ignore
@ -216,5 +278,14 @@ class PydanticAttrOutputFunctionsParser(PydanticOutputFunctionsParser):
"""The name of the attribute to return."""
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
"""Parse the result of an LLM call to a JSON object.
Args:
result: The result of the LLM call.
partial: Whether to parse partial JSON objects. Default is False.
Returns:
The parsed JSON object.
"""
result = super().parse_result(result)
return getattr(result, self.attr_name)

View File

@ -18,7 +18,21 @@ def parse_tool_call(
strict: bool = False,
return_id: bool = True,
) -> Optional[Dict[str, Any]]:
"""Parse a single tool call."""
"""Parse a single tool call.
Args:
raw_tool_call: The raw tool call to parse.
partial: Whether to parse partial JSON. Default is False.
strict: Whether to allow non-JSON-compliant strings.
Default is False.
return_id: Whether to return the tool call id. Default is True.
Returns:
The parsed tool call.
Raises:
OutputParserException: If the tool call is not valid JSON.
"""
if "function" not in raw_tool_call:
return None
if partial:
@ -52,7 +66,15 @@ def make_invalid_tool_call(
raw_tool_call: Dict[str, Any],
error_msg: Optional[str],
) -> InvalidToolCall:
"""Create an InvalidToolCall from a raw tool call."""
"""Create an InvalidToolCall from a raw tool call.
Args:
raw_tool_call: The raw tool call.
error_msg: The error message.
Returns:
An InvalidToolCall instance with the error message.
"""
return InvalidToolCall(
name=raw_tool_call["function"]["name"],
args=raw_tool_call["function"]["arguments"],
@ -68,7 +90,21 @@ def parse_tool_calls(
strict: bool = False,
return_id: bool = True,
) -> List[Dict[str, Any]]:
"""Parse a list of tool calls."""
"""Parse a list of tool calls.
Args:
raw_tool_calls: The raw tool calls to parse.
partial: Whether to parse partial JSON. Default is False.
strict: Whether to allow non-JSON-compliant strings.
Default is False.
return_id: Whether to return the tool call id. Default is True.
Returns:
The parsed tool calls.
Raises:
OutputParserException: If any of the tool calls are not valid JSON.
"""
final_tools: List[Dict[str, Any]] = []
exceptions = []
for tool_call in raw_tool_calls:
@ -110,6 +146,23 @@ class JsonOutputToolsParser(BaseCumulativeTransformOutputParser[Any]):
"""
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
"""Parse the result of an LLM call to a list of tool calls.
Args:
result: The result of the LLM call.
partial: Whether to parse partial JSON.
If True, the output will be a JSON object containing
all the keys that have been returned so far.
If False, the output will be the full JSON object.
Default is False.
Returns:
The parsed tool calls.
Raises:
OutputParserException: If the output is not valid JSON.
"""
generation = result[0]
if not isinstance(generation, ChatGeneration):
raise OutputParserException(
@ -141,6 +194,14 @@ class JsonOutputToolsParser(BaseCumulativeTransformOutputParser[Any]):
return tool_calls
def parse(self, text: str) -> Any:
"""Parse the output of an LLM call to a list of tool calls.
Args:
text: The output of the LLM call.
Returns:
The parsed tool calls.
"""
raise NotImplementedError()
@ -151,6 +212,19 @@ class JsonOutputKeyToolsParser(JsonOutputToolsParser):
"""The type of tools to return."""
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
"""Parse the result of an LLM call to a list of tool calls.
Args:
result: The result of the LLM call.
partial: Whether to parse partial JSON.
If True, the output will be a JSON object containing
all the keys that have been returned so far.
If False, the output will be the full JSON object.
Default is False.
Returns:
The parsed tool calls.
"""
parsed_result = super().parse_result(result, partial=partial)
if self.first_tool_only:
@ -175,10 +249,27 @@ class PydanticToolsParser(JsonOutputToolsParser):
"""Parse tools from OpenAI response."""
tools: List[Type[BaseModel]]
"""The tools to parse."""
# TODO: Support more granular streaming of objects. Currently only streams once all
# Pydantic object fields are present.
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
"""Parse the result of an LLM call to a list of Pydantic objects.
Args:
result: The result of the LLM call.
partial: Whether to parse partial JSON.
If True, the output will be a JSON object containing
all the keys that have been returned so far.
If False, the output will be the full JSON object.
Default is False.
Returns:
The parsed Pydantic objects.
Raises:
OutputParserException: If the output is not valid JSON.
"""
json_results = super().parse_result(result, partial=partial)
if not json_results:
return None if self.first_tool_only else []

View File

@ -57,13 +57,38 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
def parse_result(
self, result: List[Generation], *, partial: bool = False
) -> TBaseModel:
"""Parse the result of an LLM call to a pydantic object.
Args:
result: The result of the LLM call.
partial: Whether to parse partial JSON objects.
If True, the output will be a JSON object containing
all the keys that have been returned so far.
Defaults to False.
Returns:
The parsed pydantic object.
"""
json_object = super().parse_result(result)
return self._parse_obj(json_object)
def parse(self, text: str) -> TBaseModel:
"""Parse the output of an LLM call to a pydantic object.
Args:
text: The output of the LLM call.
Returns:
The parsed pydantic object.
"""
return super().parse(text)
def get_format_instructions(self) -> str:
"""Return the format instructions for the JSON output.
Returns:
The format instructions for the JSON output.
"""
# Copy schema to avoid altering original Pydantic schema.
schema = {k: v for k, v in self.pydantic_object.schema().items()}

View File

@ -47,6 +47,16 @@ class BaseTransformOutputParser(BaseOutputParser[T]):
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Iterator[T]:
"""Transform the input into the output format.
Args:
input: The input to transform.
config: The configuration to use for the transformation.
kwargs: Additional keyword arguments.
Yields:
The transformed output.
"""
yield from self._transform_stream_with_config(
input, self._transform, config, run_type="parser"
)
@ -57,6 +67,16 @@ class BaseTransformOutputParser(BaseOutputParser[T]):
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> AsyncIterator[T]:
"""Async transform the input into the output format.
Args:
input: The input to transform.
config: The configuration to use for the transformation.
kwargs: Additional keyword arguments.
Yields:
The transformed output.
"""
async for chunk in self._atransform_stream_with_config(
input, self._atransform, config, run_type="parser"
):
@ -73,7 +93,15 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
def _diff(self, prev: Optional[T], next: T) -> T:
"""Convert parsed outputs into a diff format. The semantics of this are
up to the output parser."""
up to the output parser.
Args:
prev: The previous parsed output.
next: The current parsed output.
Returns:
The diff between the previous and current parsed output.
"""
raise NotImplementedError()
def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[Any]:

View File

@ -38,6 +38,10 @@ class _StreamingParser:
Args:
parser: Parser to use for XML parsing. Can be either 'defusedxml' or 'xml'.
See documentation in XMLOutputParser for more information.
Raises:
ImportError: If defusedxml is not installed and the defusedxml
parser is requested.
"""
if parser == "defusedxml":
try:
@ -66,6 +70,9 @@ class _StreamingParser:
Yields:
AddableDict: A dictionary representing the parsed XML element.
Raises:
xml.etree.ElementTree.ParseError: If the XML is not well-formed.
"""
if isinstance(chunk, BaseMessage):
# extract text
@ -116,7 +123,13 @@ class _StreamingParser:
raise
def close(self) -> None:
"""Close the parser."""
"""Close the parser.
This should be called after all chunks have been parsed.
Raises:
xml.etree.ElementTree.ParseError: If the XML is not well-formed.
"""
try:
self.pull_parser.close()
except xml.etree.ElementTree.ParseError:
@ -153,9 +166,23 @@ class XMLOutputParser(BaseTransformOutputParser):
"""
def get_format_instructions(self) -> str:
"""Return the format instructions for the XML output."""
return XML_FORMAT_INSTRUCTIONS.format(tags=self.tags)
def parse(self, text: str) -> Dict[str, Union[str, List[Any]]]:
"""Parse the output of an LLM call.
Args:
text: The output of an LLM call.
Returns:
A dictionary representing the parsed XML.
Raises:
OutputParserException: If the XML is not well-formed.
ImportError: If defusedxml is not installed and the defusedxml
parser is requested.
"""
# Try to find XML string within triple backticks
# Imports are temporarily placed here to avoid issue with caching on CI
# likely if you're reading this you can move them to the top of the file
@ -227,7 +254,15 @@ class XMLOutputParser(BaseTransformOutputParser):
def nested_element(path: List[str], elem: ET.Element) -> Any:
"""Get nested element from path."""
"""Get nested element from path.
Args:
path: The path to the element.
elem: The element to extract.
Returns:
The nested element.
"""
if len(path) == 0:
return AddableDict({elem.tag: elem.text})
else: