langchain/libs/partners/anthropic/langchain_anthropic/output_parsers.py
Bagatur 9514bc4d67
core[minor], ...: add tool calls message (#18947)
core[minor], langchain[patch], openai[minor], anthropic[minor], fireworks[minor], groq[minor], mistralai[minor]

```python
class ToolCall(TypedDict):
    name: str
    args: Dict[str, Any]
    id: Optional[str]

class InvalidToolCall(TypedDict):
    name: Optional[str]
    args: Optional[str]
    id: Optional[str]
    error: Optional[str]

class ToolCallChunk(TypedDict):
    name: Optional[str]
    args: Optional[str]
    id: Optional[str]
    index: Optional[int]


class AIMessage(BaseMessage):
    ...
    tool_calls: List[ToolCall] = []
    invalid_tool_calls: List[InvalidToolCall] = []
    ...


class AIMessageChunk(AIMessage, BaseMessageChunk):
    ...
    tool_call_chunks: Optional[List[ToolCallChunk]] = None
    ...
```
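
For illustration, a minimal sketch (not from the PR itself) of how a tool call now surfaces directly on an `AIMessage`; the `get_weather` name, arguments, and id are made up:

```python
from langchain_core.messages import AIMessage

# A provider-normalized tool call attached directly to the message
# (hypothetical tool name and id).
msg = AIMessage(
    content="",
    tool_calls=[{"name": "get_weather", "args": {"city": "Paris"}, "id": "call_123"}],
)
assert msg.tool_calls[0]["args"] == {"city": "Paris"}
```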
Important considerations:
- Parsing logic occurs within different providers;
- ~Changing output type is a breaking change for anyone doing explicit type checking;~
- ~Langsmith rendering will need to be updated: https://github.com/langchain-ai/langchainplus/pull/3561~
- ~Langserve will need to be updated~
- Adding chunks:
  - ~AIMessage + ToolCallsMessage = ToolCallsMessage if either has non-null .tool_calls.~
  - Tool call chunks are appended, and merged when they have equal values of `index` (see the sketch after this list).
  - `additional_kwargs` accumulate the normal way.
- During streaming:
  - ~Messages can change types (e.g., from AIMessageChunk to AIToolCallsMessageChunk)~
  - Output parsers parse `additional_kwargs` (during `.invoke` they read off tool calls).
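
A minimal sketch (not part of the PR) of the chunk-merging behavior described above; the tool name, partial JSON args, and id are hypothetical:

```python
from langchain_core.messages import AIMessageChunk

chunk1 = AIMessageChunk(
    content="",
    tool_call_chunks=[
        {"name": "get_weather", "args": '{"city": ', "id": "call_123", "index": 0}
    ],
)
chunk2 = AIMessageChunk(
    content="",
    tool_call_chunks=[{"name": None, "args": '"Paris"}', "id": None, "index": 0}],
)

# Chunks with equal `index` merge into one tool call chunk; string
# fields like `args` are concatenated.
merged = chunk1 + chunk2
# merged.tool_call_chunks ->
# [{"name": "get_weather", "args": '{"city": "Paris"}', "id": "call_123", "index": 0}]
```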

Packages outside of `partners/`:
- https://github.com/langchain-ai/langchain-cohere/pull/7
- https://github.com/langchain-ai/langchain-google/pull/123/files

---------

Co-authored-by: Chester Curme <chester.curme@gmail.com>
2024-04-09 18:41:42 -05:00


from typing import Any, List, Optional, Type

from langchain_core.messages import ToolCall
from langchain_core.output_parsers import BaseGenerationOutputParser
from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.pydantic_v1 import BaseModel


class ToolsOutputParser(BaseGenerationOutputParser):
    first_tool_only: bool = False
    args_only: bool = False
    pydantic_schemas: Optional[List[Type[BaseModel]]] = None

    class Config:
        extra = "forbid"

    def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
        """Parse a list of candidate model Generations into a specific format.

        Args:
            result: A list of Generations to be parsed. The Generations are assumed
                to be different candidate outputs for a single model input.

        Returns:
            Structured output.
        """
        if not result or not isinstance(result[0], ChatGeneration):
            return None if self.first_tool_only else []
        message = result[0].message
        if isinstance(message.content, str):
            tool_calls: List = []
        else:
            content: List = message.content
            _tool_calls = [dict(tc) for tc in extract_tool_calls(content)]
            # Map tool call id to index
            id_to_index = {
                block["id"]: i
                for i, block in enumerate(content)
                if block["type"] == "tool_use"
            }
            tool_calls = [{**tc, "index": id_to_index[tc["id"]]} for tc in _tool_calls]
        if self.pydantic_schemas:
            tool_calls = [self._pydantic_parse(tc) for tc in tool_calls]
        elif self.args_only:
            tool_calls = [tc["args"] for tc in tool_calls]
        else:
            pass

        if self.first_tool_only:
            return tool_calls[0] if tool_calls else None
        else:
            return [tool_call for tool_call in tool_calls]

    def _pydantic_parse(self, tool_call: dict) -> BaseModel:
        cls_ = {schema.__name__: schema for schema in self.pydantic_schemas or []}[
            tool_call["name"]
        ]
        return cls_(**tool_call["args"])


def extract_tool_calls(content: List[dict]) -> List[ToolCall]:
    tool_calls = []
    for block in content:
        if block["type"] != "tool_use":
            continue
        tool_calls.append(
            ToolCall(name=block["name"], args=block["input"], id=block["id"])
        )
    return tool_calls
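
For illustration, a sketch of how `ToolsOutputParser` might be invoked on a `ChatGeneration` whose message carries Anthropic-style `tool_use` content blocks; the block values (`get_weather`, `toolu_01`, the input dict) are made up:

```python
from langchain_core.messages import AIMessage
from langchain_core.outputs import ChatGeneration

# Hypothetical Anthropic-style content: one text block, one tool_use block.
message = AIMessage(
    content=[
        {"type": "text", "text": "Let me check."},
        {
            "type": "tool_use",
            "id": "toolu_01",
            "name": "get_weather",
            "input": {"city": "Paris"},
        },
    ]
)

parser = ToolsOutputParser(first_tool_only=True)
result = parser.parse_result([ChatGeneration(message=message)])
# -> {"name": "get_weather", "args": {"city": "Paris"}, "id": "toolu_01", "index": 1}
```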