langchain/libs/partners/anthropic/langchain_anthropic/output_parsers.py

87 lines
3.1 KiB
Python

from typing import Any, List, Optional, Type, Union, cast
from langchain_core.messages import AIMessage, ToolCall
from langchain_core.output_parsers import BaseGenerationOutputParser
from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.pydantic_v1 import BaseModel
class ToolsOutputParser(BaseGenerationOutputParser):
    """Output parser for tool calls."""

    first_tool_only: bool = False
    """Whether to return only the first tool call."""
    args_only: bool = False
    """Whether to return only the arguments of the tool calls."""
    pydantic_schemas: Optional[List[Type[BaseModel]]] = None
    """Pydantic schemas to parse tool calls into."""

    class Config:
        extra = "forbid"

    def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
        """Parse a list of candidate model Generations into tool calls.

        Only the first generation in ``result`` is inspected.

        Args:
            result: A list of Generations to be parsed. The Generations are
                assumed to be different candidate outputs for a single model
                input.
            partial: Accepted for interface compatibility; this parser does
                not read it.

        Returns:
            With ``first_tool_only`` set, the first parsed tool call or
            ``None`` when there is none; otherwise a list of all parsed tool
            calls (possibly empty). Each tool call is a pydantic instance when
            ``pydantic_schemas`` is set, else its ``"args"`` value when
            ``args_only`` is set, else a dict. ``pydantic_schemas`` takes
            precedence over ``args_only``.

        Raises:
            KeyError: If ``pydantic_schemas`` is set and a tool call names a
                tool that matches none of the schemas.
        """
        if not result or not isinstance(result[0], ChatGeneration):
            return None if self.first_tool_only else []

        message = cast(AIMessage, result[0].message)
        tool_calls: List = [
            dict(tc) for tc in _extract_tool_calls_from_message(message)
        ]

        if isinstance(message.content, list):
            # Map each tool_use block's id to its position in the content list.
            # ``.get`` keeps dict blocks that lack "type" from raising KeyError.
            id_to_index = {
                block["id"]: i
                for i, block in enumerate(message.content)
                if isinstance(block, dict)
                and block.get("type") == "tool_use"
                and "id" in block
            }
            # Tool calls taken from ``message.tool_calls`` may carry an id
            # (or None) with no matching content block; leave those without an
            # "index" rather than failing the whole parse.
            tool_calls = [
                {**tc, "index": id_to_index[tc["id"]]}
                if tc.get("id") in id_to_index
                else tc
                for tc in tool_calls
            ]

        if self.pydantic_schemas:
            tool_calls = [self._pydantic_parse(tc) for tc in tool_calls]
        elif self.args_only:
            tool_calls = [tc["args"] for tc in tool_calls]

        if self.first_tool_only:
            return tool_calls[0] if tool_calls else None
        return tool_calls

    def _pydantic_parse(self, tool_call: dict) -> BaseModel:
        """Instantiate the schema whose class name equals the tool call's name.

        Args:
            tool_call: A tool call dict with at least ``"name"`` and ``"args"``.

        Returns:
            An instance of the matching schema built from ``tool_call["args"]``.

        Raises:
            KeyError: If no schema in ``pydantic_schemas`` has that class name.
        """
        schemas_by_name = {
            schema.__name__: schema for schema in self.pydantic_schemas or []
        }
        cls_ = schemas_by_name[tool_call["name"]]
        return cls_(**tool_call["args"])
def _extract_tool_calls_from_message(message: AIMessage) -> List[ToolCall]:
    """Return the tool calls carried by an AI message.

    Tool calls already present on ``message.tool_calls`` are returned as-is;
    when that attribute is empty, they are extracted from ``message.content``
    instead.
    """
    return message.tool_calls or extract_tool_calls(message.content)
def extract_tool_calls(content: Union[str, List[Union[str, dict]]]) -> List[ToolCall]:
    """Extract tool calls from message content.

    Args:
        content: Either a plain string or a list of content blocks, where
            each block is a string or a dict.

    Returns:
        One ``ToolCall`` per dict block whose ``"type"`` is ``"tool_use"``,
        in the order the blocks appear. Content that is not a list yields an
        empty list.

    Raises:
        KeyError: If a ``tool_use`` block lacks ``"name"``, ``"input"`` or
            ``"id"``.
    """
    if not isinstance(content, list):
        return []
    return [
        ToolCall(name=block["name"], args=block["input"], id=block["id"])
        for block in content
        # Skip string blocks, and dict blocks without a "type" key, instead
        # of raising on them.
        if isinstance(block, dict) and block.get("type") == "tool_use"
    ]