mirror of https://github.com/hwchase17/langchain
anthropic[minor]: tool use (#20016)
parent
3aacd11846
commit
209de0a561
@ -0,0 +1,66 @@
|
||||
from typing import Any, List, Optional, Type, TypedDict, cast
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.output_parsers import BaseGenerationOutputParser
|
||||
from langchain_core.outputs import ChatGeneration, Generation
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
|
||||
|
||||
class _ToolCall(TypedDict):
|
||||
name: str
|
||||
args: dict
|
||||
id: str
|
||||
index: int
|
||||
|
||||
|
||||
class ToolsOutputParser(BaseGenerationOutputParser):
|
||||
first_tool_only: bool = False
|
||||
args_only: bool = False
|
||||
pydantic_schemas: Optional[List[Type[BaseModel]]] = None
|
||||
|
||||
class Config:
|
||||
extra = "forbid"
|
||||
|
||||
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
|
||||
"""Parse a list of candidate model Generations into a specific format.
|
||||
|
||||
Args:
|
||||
result: A list of Generations to be parsed. The Generations are assumed
|
||||
to be different candidate outputs for a single model input.
|
||||
|
||||
Returns:
|
||||
Structured output.
|
||||
"""
|
||||
if not result or not isinstance(result[0], ChatGeneration):
|
||||
return None if self.first_tool_only else []
|
||||
tool_calls: List = _extract_tool_calls(result[0].message)
|
||||
if self.pydantic_schemas:
|
||||
tool_calls = [self._pydantic_parse(tc) for tc in tool_calls]
|
||||
elif self.args_only:
|
||||
tool_calls = [tc["args"] for tc in tool_calls]
|
||||
else:
|
||||
pass
|
||||
|
||||
if self.first_tool_only:
|
||||
return tool_calls[0] if tool_calls else None
|
||||
else:
|
||||
return tool_calls
|
||||
|
||||
def _pydantic_parse(self, tool_call: _ToolCall) -> BaseModel:
|
||||
cls_ = {schema.__name__: schema for schema in self.pydantic_schemas or []}[
|
||||
tool_call["name"]
|
||||
]
|
||||
return cls_(**tool_call["args"])
|
||||
|
||||
|
||||
def _extract_tool_calls(msg: BaseMessage) -> List[_ToolCall]:
|
||||
if isinstance(msg.content, str):
|
||||
return []
|
||||
tool_calls = []
|
||||
for i, block in enumerate(cast(List[dict], msg.content)):
|
||||
if block["type"] != "tool_use":
|
||||
continue
|
||||
tool_calls.append(
|
||||
_ToolCall(name=block["name"], args=block["input"], id=block["id"], index=i)
|
||||
)
|
||||
return tool_calls
|
@ -0,0 +1,72 @@
|
||||
from typing import Any, List, Literal
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.outputs import ChatGeneration
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
|
||||
from langchain_anthropic.output_parsers import ToolsOutputParser
|
||||
|
||||
_CONTENT: List = [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "thought",
|
||||
},
|
||||
{"type": "tool_use", "input": {"bar": 0}, "id": "1", "name": "_Foo1"},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "thought",
|
||||
},
|
||||
{"type": "tool_use", "input": {"baz": "a"}, "id": "2", "name": "_Foo2"},
|
||||
]
|
||||
|
||||
_RESULT: List = [ChatGeneration(message=AIMessage(_CONTENT))]
|
||||
|
||||
|
||||
class _Foo1(BaseModel):
|
||||
bar: int
|
||||
|
||||
|
||||
class _Foo2(BaseModel):
|
||||
baz: Literal["a", "b"]
|
||||
|
||||
|
||||
def test_tools_output_parser() -> None:
|
||||
output_parser = ToolsOutputParser()
|
||||
expected = [
|
||||
{"name": "_Foo1", "args": {"bar": 0}, "id": "1", "index": 1},
|
||||
{"name": "_Foo2", "args": {"baz": "a"}, "id": "2", "index": 3},
|
||||
]
|
||||
actual = output_parser.parse_result(_RESULT)
|
||||
assert expected == actual
|
||||
|
||||
|
||||
def test_tools_output_parser_args_only() -> None:
|
||||
output_parser = ToolsOutputParser(args_only=True)
|
||||
expected = [
|
||||
{"bar": 0},
|
||||
{"baz": "a"},
|
||||
]
|
||||
actual = output_parser.parse_result(_RESULT)
|
||||
assert expected == actual
|
||||
|
||||
expected = []
|
||||
actual = output_parser.parse_result([ChatGeneration(message=AIMessage(""))])
|
||||
assert expected == actual
|
||||
|
||||
|
||||
def test_tools_output_parser_first_tool_only() -> None:
|
||||
output_parser = ToolsOutputParser(first_tool_only=True)
|
||||
expected: Any = {"name": "_Foo1", "args": {"bar": 0}, "id": "1", "index": 1}
|
||||
actual = output_parser.parse_result(_RESULT)
|
||||
assert expected == actual
|
||||
|
||||
expected = None
|
||||
actual = output_parser.parse_result([ChatGeneration(message=AIMessage(""))])
|
||||
assert expected == actual
|
||||
|
||||
|
||||
def test_tools_output_parser_pydantic() -> None:
|
||||
output_parser = ToolsOutputParser(pydantic_schemas=[_Foo1, _Foo2])
|
||||
expected = [_Foo1(bar=0), _Foo2(baz="a")]
|
||||
actual = output_parser.parse_result(_RESULT)
|
||||
assert expected == actual
|
Loading…
Reference in New Issue