mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
9514bc4d67
core[minor], langchain[patch], openai[minor], anthropic[minor], fireworks[minor], groq[minor], mistralai[minor] ```python class ToolCall(TypedDict): name: str args: Dict[str, Any] id: Optional[str] class InvalidToolCall(TypedDict): name: Optional[str] args: Optional[str] id: Optional[str] error: Optional[str] class ToolCallChunk(TypedDict): name: Optional[str] args: Optional[str] id: Optional[str] index: Optional[int] class AIMessage(BaseMessage): ... tool_calls: List[ToolCall] = [] invalid_tool_calls: List[InvalidToolCall] = [] ... class AIMessageChunk(AIMessage, BaseMessageChunk): ... tool_call_chunks: Optional[List[ToolCallChunk]] = None ... ``` Important considerations: - Parsing logic occurs within different providers; - ~Changing output type is a breaking change for anyone doing explicit type checking;~ - ~Langsmith rendering will need to be updated: https://github.com/langchain-ai/langchainplus/pull/3561~ - ~Langserve will need to be updated~ - Adding chunks: - ~AIMessage + ToolCallsMessage = ToolCallsMessage if either has non-null .tool_calls.~ - Tool call chunks are appended, merging when having equal values of `index`. - additional_kwargs accumulate the normal way. - During streaming: - ~Messages can change types (e.g., from AIMessageChunk to AIToolCallsMessageChunk)~ - Output parsers parse additional_kwargs (during .invoke they read off tool calls). Packages outside of `partners/`: - https://github.com/langchain-ai/langchain-cohere/pull/7 - https://github.com/langchain-ai/langchain-google/pull/123/files --------- Co-authored-by: Chester Curme <chester.curme@gmail.com>
70 lines
2.4 KiB
Python
70 lines
2.4 KiB
Python
from typing import Any, List, Optional, Type
|
|
|
|
from langchain_core.messages import ToolCall
|
|
from langchain_core.output_parsers import BaseGenerationOutputParser
|
|
from langchain_core.outputs import ChatGeneration, Generation
|
|
from langchain_core.pydantic_v1 import BaseModel
|
|
|
|
|
|
class ToolsOutputParser(BaseGenerationOutputParser):
    """Parse ``tool_use`` content blocks out of the first chat generation.

    Depending on configuration, the output is a list of tool-call dicts, a list
    of just their ``args``, or a list of pydantic model instances; with
    ``first_tool_only`` only the first such item (or ``None``) is returned.
    """

    # Return only the first parsed tool call (or None) instead of a list.
    first_tool_only: bool = False
    # Return only each tool call's ``args`` (ignored when pydantic_schemas is set).
    args_only: bool = False
    # Schemas instantiated from tool calls, matched on the class ``__name__``.
    pydantic_schemas: Optional[List[Type[BaseModel]]] = None

    class Config:
        extra = "forbid"

    def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
        """Parse a list of candidate model Generations into a specific format.

        Only the first generation is inspected. Its message content is scanned
        for dict blocks whose ``type`` is ``"tool_use"``; string content, and
        string items inside a content list, hold no tool calls and are skipped.

        Args:
            result: A list of Generations to be parsed. The Generations are assumed
                to be different candidate outputs for a single model input.
            partial: Accepted for interface compatibility; not used here.

        Returns:
            Structured output. With ``first_tool_only`` a single item or ``None``;
            otherwise a (possibly empty) list. Each item is a pydantic model
            instance when ``pydantic_schemas`` is set, the call's ``args`` when
            ``args_only`` is set, or else a dict with ``name``, ``args``, ``id``
            and ``index`` (the block's position in the message content).
        """
        if not result or not isinstance(result[0], ChatGeneration):
            return None if self.first_tool_only else []
        message = result[0].message

        tool_calls: List[Any] = []
        if not isinstance(message.content, str):
            # Content lists may mix plain strings with dict blocks, so guard the
            # type before subscripting. Keep each block's position in the full
            # content list; it is reported as the tool call's ``index``.
            positioned = [
                (position, block)
                for position, block in enumerate(message.content)
                if isinstance(block, dict) and block.get("type") == "tool_use"
            ]
            # extract_tool_calls yields exactly one entry per tool_use block, in
            # order, so zipping re-attaches each call to its position.
            parsed = extract_tool_calls([block for _, block in positioned])
            tool_calls = [
                {**dict(tc), "index": position}
                for (position, _), tc in zip(positioned, parsed)
            ]

        if self.pydantic_schemas:
            tool_calls = [self._pydantic_parse(tc) for tc in tool_calls]
        elif self.args_only:
            tool_calls = [tc["args"] for tc in tool_calls]

        if self.first_tool_only:
            return tool_calls[0] if tool_calls else None
        return tool_calls

    def _pydantic_parse(self, tool_call: dict) -> BaseModel:
        """Instantiate the schema whose class name equals ``tool_call["name"]``.

        Args:
            tool_call: A dict with at least ``name`` and ``args`` keys.

        Returns:
            The matching schema constructed from ``tool_call["args"]``.

        Raises:
            KeyError: If no schema in ``pydantic_schemas`` has that name.
        """
        schemas = {schema.__name__: schema for schema in self.pydantic_schemas or []}
        name = tool_call["name"]
        try:
            cls_ = schemas[name]
        except KeyError:
            raise KeyError(
                f"No pydantic schema named {name!r}; "
                f"expected one of {sorted(schemas)}"
            ) from None
        return cls_(**tool_call["args"])
|
|
|
|
|
|
def extract_tool_calls(content: List[dict]) -> List[ToolCall]:
    """Convert the ``tool_use`` blocks of a message content list into ToolCalls.

    Items that are not dicts (e.g. plain text strings mixed into the content
    list) and dicts whose ``type`` is missing or not ``"tool_use"`` are skipped
    instead of raising.

    Args:
        content: Message content blocks.

    Returns:
        One ``ToolCall`` per ``tool_use`` block, in content order, with ``name``
        and ``id`` copied from the block and ``args`` taken from its ``input``.
    """
    return [
        ToolCall(name=block["name"], args=block["input"], id=block["id"])
        for block in content
        if isinstance(block, dict) and block.get("type") == "tool_use"
    ]
|