mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
core[minor], integrations...[patch]: Support ToolCall as Tool input and ToolMessage as Tool output (#24038)
Changes: - ToolCall, InvalidToolCall and ToolCallChunk can all accept a "type" parameter now - LLM integration packages add "type" to all the above - Tool supports ToolCall inputs that have "type" specified - Tool outputs ToolMessage when a ToolCall is passed as input - Tools can separately specify ToolMessage.content and ToolMessage.raw_output - Tools emit events for validation errors (using on_tool_error and on_tool_end) Example: ```python @tool("structured_api", response_format="content_and_raw_output") def _mock_structured_tool_with_raw_output( arg1: int, arg2: bool, arg3: Optional[dict] = None ) -> Tuple[str, dict]: """A Structured Tool""" return f"{arg1} {arg2}", {"arg1": arg1, "arg2": arg2, "arg3": arg3} def test_tool_call_input_tool_message_with_raw_output() -> None: tool_call: Dict = { "name": "structured_api", "args": {"arg1": 1, "arg2": True, "arg3": {"img": "base64string..."}}, "id": "123", "type": "tool_call", } expected = ToolMessage("1 True", raw_output=tool_call["args"], tool_call_id="123") tool = _mock_structured_tool_with_raw_output actual = tool.invoke(tool_call) assert actual == expected tool_call.pop("type") with pytest.raises(ValidationError): tool.invoke(tool_call) actual_content = tool.invoke(tool_call["args"]) assert actual_content == expected.content ``` --------- Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
parent
eeb996034b
commit
5fd1e67808
@ -39,7 +39,7 @@ def get_non_abstract_subclasses(cls: Type[BaseTool]) -> List[Type[BaseTool]]:
|
||||
def test_all_subclasses_accept_run_manager(cls: Type[BaseTool]) -> None:
|
||||
"""Test that tools defined in this repo accept a run manager argument."""
|
||||
# This wouldn't be necessary if the BaseTool had a strict API.
|
||||
if cls._run is not BaseTool._arun:
|
||||
if cls._run is not BaseTool._run:
|
||||
run_func = cls._run
|
||||
params = inspect.signature(run_func).parameters
|
||||
assert "run_manager" in params
|
||||
|
@ -1,7 +1,7 @@
|
||||
import json
|
||||
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
from typing_extensions import NotRequired, TypedDict
|
||||
|
||||
from langchain_core.messages.base import BaseMessage, BaseMessageChunk, merge_content
|
||||
from langchain_core.utils._merge import merge_dicts, merge_obj
|
||||
@ -146,6 +146,11 @@ class ToolCall(TypedDict):
|
||||
An identifier is needed to associate a tool call request with a tool
|
||||
call result in events when multiple concurrent tool calls are made.
|
||||
"""
|
||||
type: NotRequired[Literal["tool_call"]]
|
||||
|
||||
|
||||
def tool_call(*, name: str, args: Dict[str, Any], id: Optional[str]) -> ToolCall:
|
||||
return ToolCall(name=name, args=args, id=id, type="tool_call")
|
||||
|
||||
|
||||
class ToolCallChunk(TypedDict):
|
||||
@ -176,6 +181,19 @@ class ToolCallChunk(TypedDict):
|
||||
"""An identifier associated with the tool call."""
|
||||
index: Optional[int]
|
||||
"""The index of the tool call in a sequence."""
|
||||
type: NotRequired[Literal["tool_call_chunk"]]
|
||||
|
||||
|
||||
def tool_call_chunk(
|
||||
*,
|
||||
name: Optional[str] = None,
|
||||
args: Optional[str] = None,
|
||||
id: Optional[str] = None,
|
||||
index: Optional[int] = None,
|
||||
) -> ToolCallChunk:
|
||||
return ToolCallChunk(
|
||||
name=name, args=args, id=id, index=index, type="tool_call_chunk"
|
||||
)
|
||||
|
||||
|
||||
class InvalidToolCall(TypedDict):
|
||||
@ -193,6 +211,19 @@ class InvalidToolCall(TypedDict):
|
||||
"""An identifier associated with the tool call."""
|
||||
error: Optional[str]
|
||||
"""An error message associated with the tool call."""
|
||||
type: NotRequired[Literal["invalid_tool_call"]]
|
||||
|
||||
|
||||
def invalid_tool_call(
|
||||
*,
|
||||
name: Optional[str] = None,
|
||||
args: Optional[str] = None,
|
||||
id: Optional[str] = None,
|
||||
error: Optional[str] = None,
|
||||
) -> InvalidToolCall:
|
||||
return InvalidToolCall(
|
||||
name=name, args=args, id=id, error=error, type="invalid_tool_call"
|
||||
)
|
||||
|
||||
|
||||
def default_tool_parser(
|
||||
|
@ -5,6 +5,12 @@ from typing import Any, Dict, List, Optional, Type
|
||||
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.messages import AIMessage, InvalidToolCall
|
||||
from langchain_core.messages.tool import (
|
||||
invalid_tool_call,
|
||||
)
|
||||
from langchain_core.messages.tool import (
|
||||
tool_call as create_tool_call,
|
||||
)
|
||||
from langchain_core.output_parsers.transform import BaseCumulativeTransformOutputParser
|
||||
from langchain_core.outputs import ChatGeneration, Generation
|
||||
from langchain_core.pydantic_v1 import BaseModel, ValidationError
|
||||
@ -59,6 +65,7 @@ def parse_tool_call(
|
||||
}
|
||||
if return_id:
|
||||
parsed["id"] = raw_tool_call.get("id")
|
||||
parsed = create_tool_call(**parsed) # type: ignore
|
||||
return parsed
|
||||
|
||||
|
||||
@ -75,7 +82,7 @@ def make_invalid_tool_call(
|
||||
Returns:
|
||||
An InvalidToolCall instance with the error message.
|
||||
"""
|
||||
return InvalidToolCall(
|
||||
return invalid_tool_call(
|
||||
name=raw_tool_call["function"]["name"],
|
||||
args=raw_tool_call["function"]["arguments"],
|
||||
id=raw_tool_call.get("id"),
|
||||
|
@ -21,6 +21,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
import textwrap
|
||||
import uuid
|
||||
import warnings
|
||||
@ -34,6 +35,7 @@ from typing import (
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
@ -42,7 +44,7 @@ from typing import (
|
||||
get_type_hints,
|
||||
)
|
||||
|
||||
from typing_extensions import Annotated, get_args, get_origin
|
||||
from typing_extensions import Annotated, cast, get_args, get_origin
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.callbacks import (
|
||||
@ -56,6 +58,7 @@ from langchain_core.callbacks.manager import (
|
||||
Callbacks,
|
||||
)
|
||||
from langchain_core.load.serializable import Serializable
|
||||
from langchain_core.messages.tool import ToolCall, ToolMessage
|
||||
from langchain_core.prompts import (
|
||||
BasePromptTemplate,
|
||||
PromptTemplate,
|
||||
@ -306,7 +309,7 @@ class ToolException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class BaseTool(RunnableSerializable[Union[str, Dict], Any]):
|
||||
class BaseTool(RunnableSerializable[Union[str, Dict, ToolCall], Any]):
|
||||
"""Interface LangChain tools must implement."""
|
||||
|
||||
def __init_subclass__(cls, **kwargs: Any) -> None:
|
||||
@ -378,6 +381,14 @@ class ChildTool(BaseTool):
|
||||
] = False
|
||||
"""Handle the content of the ValidationError thrown."""
|
||||
|
||||
response_format: Literal["content", "content_and_raw_output"] = "content"
|
||||
"""The tool response format.
|
||||
|
||||
If "content" then the output of the tool is interpreted as the contents of a
|
||||
ToolMessage. If "content_and_raw_output" then the output is expected to be a
|
||||
two-tuple corresponding to the (content, raw_output) of a ToolMessage.
|
||||
"""
|
||||
|
||||
class Config(Serializable.Config):
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
@ -410,46 +421,25 @@ class ChildTool(BaseTool):
|
||||
|
||||
def invoke(
|
||||
self,
|
||||
input: Union[str, Dict],
|
||||
input: Union[str, Dict, ToolCall],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
config = ensure_config(config)
|
||||
return self.run(
|
||||
input,
|
||||
callbacks=config.get("callbacks"),
|
||||
tags=config.get("tags"),
|
||||
metadata=config.get("metadata"),
|
||||
run_name=config.get("run_name"),
|
||||
run_id=config.pop("run_id", None),
|
||||
config=config,
|
||||
**kwargs,
|
||||
)
|
||||
tool_input, kwargs = _prep_run_args(input, config, **kwargs)
|
||||
return self.run(tool_input, **kwargs)
|
||||
|
||||
async def ainvoke(
|
||||
self,
|
||||
input: Union[str, Dict],
|
||||
input: Union[str, Dict, ToolCall],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
config = ensure_config(config)
|
||||
return await self.arun(
|
||||
input,
|
||||
callbacks=config.get("callbacks"),
|
||||
tags=config.get("tags"),
|
||||
metadata=config.get("metadata"),
|
||||
run_name=config.get("run_name"),
|
||||
run_id=config.pop("run_id", None),
|
||||
config=config,
|
||||
**kwargs,
|
||||
)
|
||||
tool_input, kwargs = _prep_run_args(input, config, **kwargs)
|
||||
return await self.arun(tool_input, **kwargs)
|
||||
|
||||
# --- Tool ---
|
||||
|
||||
def _parse_input(
|
||||
self,
|
||||
tool_input: Union[str, Dict],
|
||||
) -> Union[str, Dict[str, Any]]:
|
||||
def _parse_input(self, tool_input: Union[str, Dict]) -> Union[str, Dict[str, Any]]:
|
||||
"""Convert tool input to pydantic model."""
|
||||
input_args = self.args_schema
|
||||
if isinstance(tool_input, str):
|
||||
@ -465,7 +455,7 @@ class ChildTool(BaseTool):
|
||||
for k, v in result.dict().items()
|
||||
if k in tool_input
|
||||
}
|
||||
return tool_input
|
||||
return tool_input
|
||||
|
||||
@root_validator(pre=True)
|
||||
def raise_deprecation(cls, values: Dict) -> Dict:
|
||||
@ -479,30 +469,27 @@ class ChildTool(BaseTool):
|
||||
return values
|
||||
|
||||
@abstractmethod
|
||||
def _run(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
def _run(self, *args: Any, **kwargs: Any) -> Any:
|
||||
"""Use the tool.
|
||||
|
||||
Add run_manager: Optional[CallbackManagerForToolRun] = None
|
||||
to child implementations to enable tracing,
|
||||
to child implementations to enable tracing.
|
||||
"""
|
||||
|
||||
async def _arun(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
async def _arun(self, *args: Any, **kwargs: Any) -> Any:
|
||||
"""Use the tool asynchronously.
|
||||
|
||||
Add run_manager: Optional[AsyncCallbackManagerForToolRun] = None
|
||||
to child implementations to enable tracing,
|
||||
to child implementations to enable tracing.
|
||||
"""
|
||||
if kwargs.get("run_manager") and signature(self._run).parameters.get(
|
||||
"run_manager"
|
||||
):
|
||||
kwargs["run_manager"] = kwargs["run_manager"].get_sync()
|
||||
return await run_in_executor(None, self._run, *args, **kwargs)
|
||||
|
||||
def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]:
|
||||
tool_input = self._parse_input(tool_input)
|
||||
# For backwards compatibility, if run_input is a string,
|
||||
# pass as a positional argument.
|
||||
if isinstance(tool_input, str):
|
||||
@ -523,24 +510,20 @@ class ChildTool(BaseTool):
|
||||
run_name: Optional[str] = None,
|
||||
run_id: Optional[uuid.UUID] = None,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
tool_call_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run the tool."""
|
||||
if not self.verbose and verbose is not None:
|
||||
verbose_ = verbose
|
||||
else:
|
||||
verbose_ = self.verbose
|
||||
callback_manager = CallbackManager.configure(
|
||||
callbacks,
|
||||
self.callbacks,
|
||||
verbose_,
|
||||
self.verbose or bool(verbose),
|
||||
tags,
|
||||
self.tags,
|
||||
metadata,
|
||||
self.metadata,
|
||||
)
|
||||
# TODO: maybe also pass through run_manager is _run supports kwargs
|
||||
new_arg_supported = signature(self._run).parameters.get("run_manager")
|
||||
|
||||
run_manager = callback_manager.on_tool_start(
|
||||
{"name": self.name, "description": self.description},
|
||||
tool_input if isinstance(tool_input, str) else str(tool_input),
|
||||
@ -550,67 +533,52 @@ class ChildTool(BaseTool):
|
||||
# Inputs by definition should always be dicts.
|
||||
# For now, it's unclear whether this assumption is ever violated,
|
||||
# but if it is we will send a `None` value to the callback instead
|
||||
# And will need to address issue via a patch.
|
||||
inputs=None if isinstance(tool_input, str) else tool_input,
|
||||
# TODO: will need to address issue via a patch.
|
||||
inputs=tool_input if isinstance(tool_input, dict) else None,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
content = None
|
||||
raw_output = None
|
||||
error_to_raise: Union[Exception, KeyboardInterrupt, None] = None
|
||||
try:
|
||||
child_config = patch_config(
|
||||
config,
|
||||
callbacks=run_manager.get_child(),
|
||||
)
|
||||
child_config = patch_config(config, callbacks=run_manager.get_child())
|
||||
context = copy_context()
|
||||
context.run(_set_config_context, child_config)
|
||||
parsed_input = self._parse_input(tool_input)
|
||||
tool_args, tool_kwargs = self._to_args_and_kwargs(parsed_input)
|
||||
observation = (
|
||||
context.run(
|
||||
self._run, *tool_args, run_manager=run_manager, **tool_kwargs
|
||||
)
|
||||
if new_arg_supported
|
||||
else context.run(self._run, *tool_args, **tool_kwargs)
|
||||
)
|
||||
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input)
|
||||
if signature(self._run).parameters.get("run_manager"):
|
||||
tool_kwargs["run_manager"] = run_manager
|
||||
response = context.run(self._run, *tool_args, **tool_kwargs)
|
||||
if self.response_format == "content_and_raw_output":
|
||||
if not isinstance(response, tuple) or len(response) != 2:
|
||||
raise ValueError(
|
||||
"Since response_format='content_and_raw_output' "
|
||||
"a two-tuple of the message content and raw tool output is "
|
||||
f"expected. Instead generated response of type: "
|
||||
f"{type(response)}."
|
||||
)
|
||||
content, raw_output = response
|
||||
else:
|
||||
content = response
|
||||
except ValidationError as e:
|
||||
if not self.handle_validation_error:
|
||||
raise e
|
||||
elif isinstance(self.handle_validation_error, bool):
|
||||
observation = "Tool input validation error"
|
||||
elif isinstance(self.handle_validation_error, str):
|
||||
observation = self.handle_validation_error
|
||||
elif callable(self.handle_validation_error):
|
||||
observation = self.handle_validation_error(e)
|
||||
error_to_raise = e
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Got unexpected type of `handle_validation_error`. Expected bool, "
|
||||
f"str or callable. Received: {self.handle_validation_error}"
|
||||
)
|
||||
return observation
|
||||
content = _handle_validation_error(e, flag=self.handle_validation_error)
|
||||
except ToolException as e:
|
||||
if not self.handle_tool_error:
|
||||
run_manager.on_tool_error(e)
|
||||
raise e
|
||||
elif isinstance(self.handle_tool_error, bool):
|
||||
if e.args:
|
||||
observation = e.args[0]
|
||||
else:
|
||||
observation = "Tool execution error"
|
||||
elif isinstance(self.handle_tool_error, str):
|
||||
observation = self.handle_tool_error
|
||||
elif callable(self.handle_tool_error):
|
||||
observation = self.handle_tool_error(e)
|
||||
error_to_raise = e
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Got unexpected type of `handle_tool_error`. Expected bool, str "
|
||||
f"or callable. Received: {self.handle_tool_error}"
|
||||
)
|
||||
run_manager.on_tool_end(observation, color="red", name=self.name, **kwargs)
|
||||
return observation
|
||||
content = _handle_tool_error(e, flag=self.handle_tool_error)
|
||||
except (Exception, KeyboardInterrupt) as e:
|
||||
run_manager.on_tool_error(e)
|
||||
raise e
|
||||
else:
|
||||
run_manager.on_tool_end(observation, color=color, name=self.name, **kwargs)
|
||||
return observation
|
||||
error_to_raise = e
|
||||
|
||||
if error_to_raise:
|
||||
run_manager.on_tool_error(error_to_raise)
|
||||
raise error_to_raise
|
||||
output = _format_output(content, raw_output, tool_call_id)
|
||||
run_manager.on_tool_end(output, color=color, name=self.name, **kwargs)
|
||||
return output
|
||||
|
||||
async def arun(
|
||||
self,
|
||||
@ -625,99 +593,80 @@ class ChildTool(BaseTool):
|
||||
run_name: Optional[str] = None,
|
||||
run_id: Optional[uuid.UUID] = None,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
tool_call_id: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run the tool asynchronously."""
|
||||
if not self.verbose and verbose is not None:
|
||||
verbose_ = verbose
|
||||
else:
|
||||
verbose_ = self.verbose
|
||||
callback_manager = AsyncCallbackManager.configure(
|
||||
callbacks,
|
||||
self.callbacks,
|
||||
verbose_,
|
||||
self.verbose or bool(verbose),
|
||||
tags,
|
||||
self.tags,
|
||||
metadata,
|
||||
self.metadata,
|
||||
)
|
||||
new_arg_supported = signature(self._arun).parameters.get("run_manager")
|
||||
run_manager = await callback_manager.on_tool_start(
|
||||
{"name": self.name, "description": self.description},
|
||||
tool_input if isinstance(tool_input, str) else str(tool_input),
|
||||
color=start_color,
|
||||
name=run_name,
|
||||
inputs=tool_input,
|
||||
run_id=run_id,
|
||||
# Inputs by definition should always be dicts.
|
||||
# For now, it's unclear whether this assumption is ever violated,
|
||||
# but if it is we will send a `None` value to the callback instead
|
||||
# TODO: will need to address issue via a patch.
|
||||
inputs=tool_input if isinstance(tool_input, dict) else None,
|
||||
**kwargs,
|
||||
)
|
||||
content = None
|
||||
raw_output = None
|
||||
error_to_raise: Optional[Union[Exception, KeyboardInterrupt]] = None
|
||||
try:
|
||||
parsed_input = self._parse_input(tool_input)
|
||||
# We then call the tool on the tool input to get an observation
|
||||
tool_args, tool_kwargs = self._to_args_and_kwargs(parsed_input)
|
||||
child_config = patch_config(
|
||||
config,
|
||||
callbacks=run_manager.get_child(),
|
||||
)
|
||||
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input)
|
||||
child_config = patch_config(config, callbacks=run_manager.get_child())
|
||||
context = copy_context()
|
||||
context.run(_set_config_context, child_config)
|
||||
coro = (
|
||||
context.run(
|
||||
self._arun, *tool_args, run_manager=run_manager, **tool_kwargs
|
||||
)
|
||||
if new_arg_supported
|
||||
else context.run(self._arun, *tool_args, **tool_kwargs)
|
||||
)
|
||||
if self.__class__._arun is BaseTool._arun or signature(
|
||||
self._arun
|
||||
).parameters.get("run_manager"):
|
||||
tool_kwargs["run_manager"] = run_manager
|
||||
coro = context.run(self._arun, *tool_args, **tool_kwargs)
|
||||
if accepts_context(asyncio.create_task):
|
||||
observation = await asyncio.create_task(coro, context=context) # type: ignore
|
||||
response = await asyncio.create_task(coro, context=context) # type: ignore
|
||||
else:
|
||||
observation = await coro
|
||||
|
||||
response = await coro
|
||||
if self.response_format == "content_and_raw_output":
|
||||
if not isinstance(response, tuple) or len(response) != 2:
|
||||
raise ValueError(
|
||||
"Since response_format='content_and_raw_output' "
|
||||
"a two-tuple of the message content and raw tool output is "
|
||||
f"expected. Instead generated response of type: "
|
||||
f"{type(response)}."
|
||||
)
|
||||
content, raw_output = response
|
||||
else:
|
||||
content = response
|
||||
except ValidationError as e:
|
||||
if not self.handle_validation_error:
|
||||
raise e
|
||||
elif isinstance(self.handle_validation_error, bool):
|
||||
observation = "Tool input validation error"
|
||||
elif isinstance(self.handle_validation_error, str):
|
||||
observation = self.handle_validation_error
|
||||
elif callable(self.handle_validation_error):
|
||||
observation = self.handle_validation_error(e)
|
||||
error_to_raise = e
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Got unexpected type of `handle_validation_error`. Expected bool, "
|
||||
f"str or callable. Received: {self.handle_validation_error}"
|
||||
)
|
||||
return observation
|
||||
content = _handle_validation_error(e, flag=self.handle_validation_error)
|
||||
except ToolException as e:
|
||||
if not self.handle_tool_error:
|
||||
await run_manager.on_tool_error(e)
|
||||
raise e
|
||||
elif isinstance(self.handle_tool_error, bool):
|
||||
if e.args:
|
||||
observation = e.args[0]
|
||||
else:
|
||||
observation = "Tool execution error"
|
||||
elif isinstance(self.handle_tool_error, str):
|
||||
observation = self.handle_tool_error
|
||||
elif callable(self.handle_tool_error):
|
||||
observation = self.handle_tool_error(e)
|
||||
error_to_raise = e
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Got unexpected type of `handle_tool_error`. Expected bool, str "
|
||||
f"or callable. Received: {self.handle_tool_error}"
|
||||
)
|
||||
await run_manager.on_tool_end(
|
||||
observation, color="red", name=self.name, **kwargs
|
||||
)
|
||||
return observation
|
||||
content = _handle_tool_error(e, flag=self.handle_tool_error)
|
||||
except (Exception, KeyboardInterrupt) as e:
|
||||
await run_manager.on_tool_error(e)
|
||||
raise e
|
||||
else:
|
||||
await run_manager.on_tool_end(
|
||||
observation, color=color, name=self.name, **kwargs
|
||||
)
|
||||
return observation
|
||||
error_to_raise = e
|
||||
|
||||
if error_to_raise:
|
||||
await run_manager.on_tool_error(error_to_raise)
|
||||
raise error_to_raise
|
||||
|
||||
output = _format_output(content, raw_output, tool_call_id)
|
||||
await run_manager.on_tool_end(output, color=color, name=self.name, **kwargs)
|
||||
return output
|
||||
|
||||
@deprecated("0.1.47", alternative="invoke", removal="0.3.0")
|
||||
def __call__(self, tool_input: str, callbacks: Callbacks = None) -> str:
|
||||
@ -738,7 +687,7 @@ class Tool(BaseTool):
|
||||
|
||||
async def ainvoke(
|
||||
self,
|
||||
input: Union[str, Dict],
|
||||
input: Union[str, Dict, ToolCall],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
@ -780,17 +729,10 @@ class Tool(BaseTool):
|
||||
) -> Any:
|
||||
"""Use the tool."""
|
||||
if self.func:
|
||||
new_argument_supported = signature(self.func).parameters.get("callbacks")
|
||||
return (
|
||||
self.func(
|
||||
*args,
|
||||
callbacks=run_manager.get_child() if run_manager else None,
|
||||
**kwargs,
|
||||
)
|
||||
if new_argument_supported
|
||||
else self.func(*args, **kwargs)
|
||||
)
|
||||
raise NotImplementedError("Tool does not support sync")
|
||||
if run_manager and signature(self.func).parameters.get("callbacks"):
|
||||
kwargs["callbacks"] = run_manager.get_child()
|
||||
return self.func(*args, **kwargs)
|
||||
raise NotImplementedError("Tool does not support sync invocation.")
|
||||
|
||||
async def _arun(
|
||||
self,
|
||||
@ -800,26 +742,13 @@ class Tool(BaseTool):
|
||||
) -> Any:
|
||||
"""Use the tool asynchronously."""
|
||||
if self.coroutine:
|
||||
new_argument_supported = signature(self.coroutine).parameters.get(
|
||||
"callbacks"
|
||||
)
|
||||
return (
|
||||
await self.coroutine(
|
||||
*args,
|
||||
callbacks=run_manager.get_child() if run_manager else None,
|
||||
**kwargs,
|
||||
)
|
||||
if new_argument_supported
|
||||
else await self.coroutine(*args, **kwargs)
|
||||
)
|
||||
else:
|
||||
return await run_in_executor(
|
||||
None,
|
||||
self._run,
|
||||
run_manager=run_manager.get_sync() if run_manager else None,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
if run_manager and signature(self.coroutine).parameters.get("callbacks"):
|
||||
kwargs["callbacks"] = run_manager.get_child()
|
||||
return await self.coroutine(*args, **kwargs)
|
||||
|
||||
# NOTE: this code is unreachable since _arun is only called if coroutine is not
|
||||
# None.
|
||||
return await super()._arun(*args, run_manager=run_manager, **kwargs)
|
||||
|
||||
# TODO: this is for backwards compatibility, remove in future
|
||||
def __init__(
|
||||
@ -870,9 +799,10 @@ class StructuredTool(BaseTool):
|
||||
|
||||
# --- Runnable ---
|
||||
|
||||
# TODO: Is this needed?
|
||||
async def ainvoke(
|
||||
self,
|
||||
input: Union[str, Dict],
|
||||
input: Union[str, Dict, ToolCall],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
@ -897,45 +827,26 @@ class StructuredTool(BaseTool):
|
||||
) -> Any:
|
||||
"""Use the tool."""
|
||||
if self.func:
|
||||
new_argument_supported = signature(self.func).parameters.get("callbacks")
|
||||
return (
|
||||
self.func(
|
||||
*args,
|
||||
callbacks=run_manager.get_child() if run_manager else None,
|
||||
**kwargs,
|
||||
)
|
||||
if new_argument_supported
|
||||
else self.func(*args, **kwargs)
|
||||
)
|
||||
raise NotImplementedError("Tool does not support sync")
|
||||
if run_manager and signature(self.func).parameters.get("callbacks"):
|
||||
kwargs["callbacks"] = run_manager.get_child()
|
||||
return self.func(*args, **kwargs)
|
||||
raise NotImplementedError("StructuredTool does not support sync invocation.")
|
||||
|
||||
async def _arun(
|
||||
self,
|
||||
*args: Any,
|
||||
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
) -> Any:
|
||||
"""Use the tool asynchronously."""
|
||||
if self.coroutine:
|
||||
new_argument_supported = signature(self.coroutine).parameters.get(
|
||||
"callbacks"
|
||||
)
|
||||
return (
|
||||
await self.coroutine(
|
||||
*args,
|
||||
callbacks=run_manager.get_child() if run_manager else None,
|
||||
**kwargs,
|
||||
)
|
||||
if new_argument_supported
|
||||
else await self.coroutine(*args, **kwargs)
|
||||
)
|
||||
return await run_in_executor(
|
||||
None,
|
||||
self._run,
|
||||
run_manager=run_manager.get_sync() if run_manager else None,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
if run_manager and signature(self.coroutine).parameters.get("callbacks"):
|
||||
kwargs["callbacks"] = run_manager.get_child()
|
||||
return await self.coroutine(*args, **kwargs)
|
||||
|
||||
# NOTE: this code is unreachable since _arun is only called if coroutine is not
|
||||
# None.
|
||||
return await super()._arun(*args, run_manager=run_manager, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_function(
|
||||
@ -947,6 +858,8 @@ class StructuredTool(BaseTool):
|
||||
return_direct: bool = False,
|
||||
args_schema: Optional[Type[BaseModel]] = None,
|
||||
infer_schema: bool = True,
|
||||
*,
|
||||
response_format: Literal["content", "content_and_raw_output"] = "content",
|
||||
parse_docstring: bool = False,
|
||||
error_on_invalid_docstring: bool = False,
|
||||
**kwargs: Any,
|
||||
@ -963,6 +876,10 @@ class StructuredTool(BaseTool):
|
||||
return_direct: Whether to return the result directly or as a callback
|
||||
args_schema: The schema of the tool's input arguments
|
||||
infer_schema: Whether to infer the schema from the function's signature
|
||||
response_format: The tool response format. If "content" then the output of
|
||||
the tool is interpreted as the contents of a ToolMessage. If
|
||||
"content_and_raw_output" then the output is expected to be a two-tuple
|
||||
corresponding to the (content, raw_output) of a ToolMessage.
|
||||
parse_docstring: if ``infer_schema`` and ``parse_docstring``, will attempt
|
||||
to parse parameter descriptions from Google Style function docstrings.
|
||||
error_on_invalid_docstring: if ``parse_docstring`` is provided, configures
|
||||
@ -1020,6 +937,7 @@ class StructuredTool(BaseTool):
|
||||
args_schema=_args_schema, # type: ignore[arg-type]
|
||||
description=description_,
|
||||
return_direct=return_direct,
|
||||
response_format=response_format,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@ -1029,6 +947,7 @@ def tool(
|
||||
return_direct: bool = False,
|
||||
args_schema: Optional[Type[BaseModel]] = None,
|
||||
infer_schema: bool = True,
|
||||
response_format: Literal["content", "content_and_raw_output"] = "content",
|
||||
parse_docstring: bool = False,
|
||||
error_on_invalid_docstring: bool = True,
|
||||
) -> Callable:
|
||||
@ -1042,6 +961,10 @@ def tool(
|
||||
infer_schema: Whether to infer the schema of the arguments from
|
||||
the function's signature. This also makes the resultant tool
|
||||
accept a dictionary input to its `run()` function.
|
||||
response_format: The tool response format. If "content" then the output of
|
||||
the tool is interpreted as the contents of a ToolMessage. If
|
||||
"content_and_raw_output" then the output is expected to be a two-tuple
|
||||
corresponding to the (content, raw_output) of a ToolMessage.
|
||||
parse_docstring: if ``infer_schema`` and ``parse_docstring``, will attempt to
|
||||
parse parameter descriptions from Google Style function docstrings.
|
||||
error_on_invalid_docstring: if ``parse_docstring`` is provided, configures
|
||||
@ -1064,8 +987,12 @@ def tool(
|
||||
# Searches the API for the query.
|
||||
return
|
||||
|
||||
.. versionadded:: 0.2.14
|
||||
Parse Google-style docstrings:
|
||||
@tool(response_format="content_and_raw_output")
|
||||
def search_api(query: str) -> Tuple[str, dict]:
|
||||
return "partial json of results", {"full": "object of results"}
|
||||
|
||||
.. versionadded:: 0.2.14
|
||||
Parse Google-style docstrings:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@ -1179,6 +1106,7 @@ def tool(
|
||||
return_direct=return_direct,
|
||||
args_schema=schema,
|
||||
infer_schema=infer_schema,
|
||||
response_format=response_format,
|
||||
parse_docstring=parse_docstring,
|
||||
error_on_invalid_docstring=error_on_invalid_docstring,
|
||||
)
|
||||
@ -1195,6 +1123,7 @@ def tool(
|
||||
description=f"{tool_name} tool",
|
||||
return_direct=return_direct,
|
||||
coroutine=coroutine,
|
||||
response_format=response_format,
|
||||
)
|
||||
|
||||
return _make_tool
|
||||
@ -1350,6 +1279,103 @@ class BaseToolkit(BaseModel, ABC):
|
||||
"""Get the tools in the toolkit."""
|
||||
|
||||
|
||||
def _is_tool_call(x: Any) -> bool:
|
||||
return isinstance(x, dict) and x.get("type") == "tool_call"
|
||||
|
||||
|
||||
def _handle_validation_error(
|
||||
e: ValidationError,
|
||||
*,
|
||||
flag: Union[Literal[True], str, Callable[[ValidationError], str]],
|
||||
) -> str:
|
||||
if isinstance(flag, bool):
|
||||
content = "Tool input validation error"
|
||||
elif isinstance(flag, str):
|
||||
content = flag
|
||||
elif callable(flag):
|
||||
content = flag(e)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Got unexpected type of `handle_validation_error`. Expected bool, "
|
||||
f"str or callable. Received: {flag}"
|
||||
)
|
||||
return content
|
||||
|
||||
|
||||
def _handle_tool_error(
|
||||
e: ToolException,
|
||||
*,
|
||||
flag: Optional[Union[Literal[True], str, Callable[[ToolException], str]]],
|
||||
) -> str:
|
||||
if isinstance(flag, bool):
|
||||
if e.args:
|
||||
content = e.args[0]
|
||||
else:
|
||||
content = "Tool execution error"
|
||||
elif isinstance(flag, str):
|
||||
content = flag
|
||||
elif callable(flag):
|
||||
content = flag(e)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Got unexpected type of `handle_tool_error`. Expected bool, str "
|
||||
f"or callable. Received: {flag}"
|
||||
)
|
||||
return content
|
||||
|
||||
|
||||
def _prep_run_args(
|
||||
input: Union[str, dict, ToolCall],
|
||||
config: Optional[RunnableConfig],
|
||||
**kwargs: Any,
|
||||
) -> Tuple[Union[str, Dict], Dict]:
|
||||
config = ensure_config(config)
|
||||
if _is_tool_call(input):
|
||||
tool_call_id: Optional[str] = cast(ToolCall, input)["id"]
|
||||
tool_input: Union[str, dict] = cast(ToolCall, input)["args"]
|
||||
else:
|
||||
tool_call_id = None
|
||||
tool_input = cast(Union[str, dict], input)
|
||||
return (
|
||||
tool_input,
|
||||
dict(
|
||||
callbacks=config.get("callbacks"),
|
||||
tags=config.get("tags"),
|
||||
metadata=config.get("metadata"),
|
||||
run_name=config.get("run_name"),
|
||||
run_id=config.pop("run_id", None),
|
||||
config=config,
|
||||
tool_call_id=tool_call_id,
|
||||
**kwargs,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _format_output(
    content: Any, raw_output: Any, tool_call_id: Optional[str]
) -> Union[ToolMessage, Any]:
    """Wrap tool output in a ToolMessage when a tool_call_id is present.

    Without a tool_call_id the content is returned unchanged. ToolMessage
    content must be a str or a list of content blocks, so anything else is
    stringified first.
    """
    if not tool_call_id:
        return content
    # NOTE: This will fail to stringify lists which aren't actually content blocks
    # but whose first element happens to be a string or dict. Tools should avoid
    # returning such contents.
    looks_like_content_blocks = (
        isinstance(content, list)
        and bool(content)
        and isinstance(content[0], (str, dict))
    )
    if not (isinstance(content, str) or looks_like_content_blocks):
        content = _stringify(content)
    return ToolMessage(content, raw_output=raw_output, tool_call_id=tool_call_id)
|
||||
|
||||
|
||||
def _stringify(content: Any) -> str:
|
||||
try:
|
||||
return json.dumps(content)
|
||||
except Exception:
|
||||
return str(content)
|
||||
|
||||
|
||||
def _get_description_from_runnable(runnable: Runnable) -> str:
|
||||
"""Generate a placeholder description of a runnable."""
|
||||
input_schema = runnable.input_schema.schema()
|
||||
|
@ -317,6 +317,13 @@
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
}),
|
||||
'type': dict({
|
||||
'enum': list([
|
||||
'invalid_tool_call',
|
||||
]),
|
||||
'title': 'Type',
|
||||
'type': 'string',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'name',
|
||||
@ -419,6 +426,13 @@
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
}),
|
||||
'type': dict({
|
||||
'enum': list([
|
||||
'tool_call',
|
||||
]),
|
||||
'title': 'Type',
|
||||
'type': 'string',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'name',
|
||||
@ -908,6 +922,13 @@
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
}),
|
||||
'type': dict({
|
||||
'enum': list([
|
||||
'invalid_tool_call',
|
||||
]),
|
||||
'title': 'Type',
|
||||
'type': 'string',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'name',
|
||||
@ -1010,6 +1031,13 @@
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
}),
|
||||
'type': dict({
|
||||
'enum': list([
|
||||
'tool_call',
|
||||
]),
|
||||
'title': 'Type',
|
||||
'type': 'string',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'name',
|
||||
|
@ -674,6 +674,13 @@
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
}),
|
||||
'type': dict({
|
||||
'enum': list([
|
||||
'invalid_tool_call',
|
||||
]),
|
||||
'title': 'Type',
|
||||
'type': 'string',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'name',
|
||||
@ -776,6 +783,13 @@
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
}),
|
||||
'type': dict({
|
||||
'enum': list([
|
||||
'tool_call',
|
||||
]),
|
||||
'title': 'Type',
|
||||
'type': 'string',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'name',
|
||||
|
@ -5577,6 +5577,13 @@
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
}),
|
||||
'type': dict({
|
||||
'enum': list([
|
||||
'invalid_tool_call',
|
||||
]),
|
||||
'title': 'Type',
|
||||
'type': 'string',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'name',
|
||||
@ -5701,6 +5708,13 @@
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
}),
|
||||
'type': dict({
|
||||
'enum': list([
|
||||
'tool_call',
|
||||
]),
|
||||
'title': 'Type',
|
||||
'type': 'string',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'name',
|
||||
@ -6237,6 +6251,13 @@
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
}),
|
||||
'type': dict({
|
||||
'enum': list([
|
||||
'invalid_tool_call',
|
||||
]),
|
||||
'title': 'Type',
|
||||
'type': 'string',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'name',
|
||||
@ -6361,6 +6382,13 @@
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
}),
|
||||
'type': dict({
|
||||
'enum': list([
|
||||
'tool_call',
|
||||
]),
|
||||
'title': 'Type',
|
||||
'type': 'string',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'name',
|
||||
@ -6834,6 +6862,13 @@
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
}),
|
||||
'type': dict({
|
||||
'enum': list([
|
||||
'invalid_tool_call',
|
||||
]),
|
||||
'title': 'Type',
|
||||
'type': 'string',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'name',
|
||||
@ -6936,6 +6971,13 @@
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
}),
|
||||
'type': dict({
|
||||
'enum': list([
|
||||
'tool_call',
|
||||
]),
|
||||
'title': 'Type',
|
||||
'type': 'string',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'name',
|
||||
@ -7444,6 +7486,13 @@
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
}),
|
||||
'type': dict({
|
||||
'enum': list([
|
||||
'invalid_tool_call',
|
||||
]),
|
||||
'title': 'Type',
|
||||
'type': 'string',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'name',
|
||||
@ -7568,6 +7617,13 @@
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
}),
|
||||
'type': dict({
|
||||
'enum': list([
|
||||
'tool_call',
|
||||
]),
|
||||
'title': 'Type',
|
||||
'type': 'string',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'name',
|
||||
@ -8068,6 +8124,13 @@
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
}),
|
||||
'type': dict({
|
||||
'enum': list([
|
||||
'invalid_tool_call',
|
||||
]),
|
||||
'title': 'Type',
|
||||
'type': 'string',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'name',
|
||||
@ -8203,6 +8266,13 @@
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
}),
|
||||
'type': dict({
|
||||
'enum': list([
|
||||
'tool_call',
|
||||
]),
|
||||
'title': 'Type',
|
||||
'type': 'string',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'name',
|
||||
@ -8683,6 +8753,13 @@
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
}),
|
||||
'type': dict({
|
||||
'enum': list([
|
||||
'invalid_tool_call',
|
||||
]),
|
||||
'title': 'Type',
|
||||
'type': 'string',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'name',
|
||||
@ -8785,6 +8862,13 @@
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
}),
|
||||
'type': dict({
|
||||
'enum': list([
|
||||
'tool_call',
|
||||
]),
|
||||
'title': 'Type',
|
||||
'type': 'string',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'name',
|
||||
@ -9238,6 +9322,13 @@
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
}),
|
||||
'type': dict({
|
||||
'enum': list([
|
||||
'invalid_tool_call',
|
||||
]),
|
||||
'title': 'Type',
|
||||
'type': 'string',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'name',
|
||||
@ -9340,6 +9431,13 @@
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
}),
|
||||
'type': dict({
|
||||
'enum': list([
|
||||
'tool_call',
|
||||
]),
|
||||
'title': 'Type',
|
||||
'type': 'string',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'name',
|
||||
@ -9880,6 +9978,13 @@
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
}),
|
||||
'type': dict({
|
||||
'enum': list([
|
||||
'invalid_tool_call',
|
||||
]),
|
||||
'title': 'Type',
|
||||
'type': 'string',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'name',
|
||||
@ -10004,6 +10109,13 @@
|
||||
'title': 'Name',
|
||||
'type': 'string',
|
||||
}),
|
||||
'type': dict({
|
||||
'enum': list([
|
||||
'tool_call',
|
||||
]),
|
||||
'title': 'Type',
|
||||
'type': 'string',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'name',
|
||||
|
@ -8,7 +8,7 @@ import textwrap
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Dict, List, Optional, Type, Union
|
||||
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union
|
||||
|
||||
import pytest
|
||||
from typing_extensions import Annotated, TypedDict
|
||||
@ -17,6 +17,7 @@ from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForToolRun,
|
||||
CallbackManagerForToolRun,
|
||||
)
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langchain_core.pydantic_v1 import BaseModel, ValidationError
|
||||
from langchain_core.runnables import Runnable, RunnableLambda, ensure_config
|
||||
from langchain_core.tools import (
|
||||
@ -1067,6 +1068,65 @@ def test_tool_annotated_descriptions() -> None:
|
||||
}
|
||||
|
||||
|
||||
def test_tool_call_input_tool_message_output() -> None:
    """A ToolCall-shaped dict input should produce a ToolMessage output."""
    tool_call = {
        "name": "structured_api",
        "args": {"arg1": 1, "arg2": True, "arg3": {"img": "base64string..."}},
        "id": "123",
        "type": "tool_call",
    }
    tool = _MockStructuredTool()
    expected = ToolMessage("1 True {'img': 'base64string...'}", tool_call_id="123")
    assert tool.invoke(tool_call) == expected

    # Without "type" the dict is no longer recognized as a ToolCall, so it is
    # validated as plain tool input and rejected.
    tool_call.pop("type")
    with pytest.raises(ValidationError):
        tool.invoke(tool_call)
|
||||
|
||||
|
||||
class _MockStructuredToolWithRawOutput(BaseTool):
    """Mock tool whose ``_run`` returns a ``(content, raw_output)`` tuple.

    With ``response_format="content_and_raw_output"`` the first tuple element
    becomes the ToolMessage content and the second is kept as raw_output.
    """

    name: str = "structured_api"
    args_schema: Type[BaseModel] = _MockSchema
    description: str = "A Structured Tool"
    response_format: Literal["content_and_raw_output"] = "content_and_raw_output"

    def _run(
        self, arg1: int, arg2: bool, arg3: Optional[dict] = None
    ) -> Tuple[str, dict]:
        # content is a formatted string; raw_output echoes the raw arguments.
        return f"{arg1} {arg2}", {"arg1": arg1, "arg2": arg2, "arg3": arg3}
|
||||
|
||||
|
||||
@tool("structured_api", response_format="content_and_raw_output")
def _mock_structured_tool_with_raw_output(
    arg1: int, arg2: bool, arg3: Optional[dict] = None
) -> Tuple[str, dict]:
    """A Structured Tool"""
    # NOTE: the docstring above doubles as the tool's description — don't edit it.
    # With response_format="content_and_raw_output", the first tuple element is
    # the ToolMessage content and the second is stored as raw_output.
    return f"{arg1} {arg2}", {"arg1": arg1, "arg2": arg2, "arg3": arg3}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "tool", [_MockStructuredToolWithRawOutput(), _mock_structured_tool_with_raw_output]
)
def test_tool_call_input_tool_message_with_raw_output(tool: BaseTool) -> None:
    """content_and_raw_output tools emit a ToolMessage with separate raw_output."""
    tool_call: Dict = {
        "name": "structured_api",
        "args": {"arg1": 1, "arg2": True, "arg3": {"img": "base64string..."}},
        "id": "123",
        "type": "tool_call",
    }
    expected = ToolMessage("1 True", raw_output=tool_call["args"], tool_call_id="123")
    assert tool.invoke(tool_call) == expected

    # Without "type" the dict is not treated as a ToolCall and fails validation.
    tool_call.pop("type")
    with pytest.raises(ValidationError):
        tool.invoke(tool_call)

    # Plain args input returns just the content — no ToolMessage wrapper.
    actual_content = tool.invoke(tool_call["args"])
    assert actual_content == expected.content
|
||||
|
||||
|
||||
def test_convert_from_runnable_dict() -> None:
|
||||
# Test with typed dict input
|
||||
class Args(TypedDict):
|
||||
|
26
libs/langchain/poetry.lock
generated
26
libs/langchain/poetry.lock
generated
@ -1,4 +1,4 @@
|
||||
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "aiohttp"
|
||||
@ -1760,7 +1760,7 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "langchain-core"
|
||||
version = "0.2.12"
|
||||
version = "0.2.13"
|
||||
description = "Building applications with LLMs through composability"
|
||||
optional = false
|
||||
python-versions = ">=3.8.1,<4.0"
|
||||
@ -1784,7 +1784,7 @@ url = "../core"
|
||||
|
||||
[[package]]
|
||||
name = "langchain-openai"
|
||||
version = "0.1.14"
|
||||
version = "0.1.15"
|
||||
description = "An integration package connecting OpenAI and LangChain"
|
||||
optional = true
|
||||
python-versions = ">=3.8.1,<4.0"
|
||||
@ -1792,7 +1792,7 @@ files = []
|
||||
develop = true
|
||||
|
||||
[package.dependencies]
|
||||
langchain-core = ">=0.2.2,<0.3"
|
||||
langchain-core = "^0.2.13"
|
||||
openai = "^1.32.0"
|
||||
tiktoken = ">=0.7,<1"
|
||||
|
||||
@ -1834,13 +1834,13 @@ types-requests = ">=2.31.0.2,<3.0.0.0"
|
||||
|
||||
[[package]]
|
||||
name = "langsmith"
|
||||
version = "0.1.84"
|
||||
version = "0.1.85"
|
||||
description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.8.1"
|
||||
files = [
|
||||
{file = "langsmith-0.1.84-py3-none-any.whl", hash = "sha256:01f3c6390dba26c583bac8dd0e551ce3d0509c7f55cad714db0b5c8d36e4c7ff"},
|
||||
{file = "langsmith-0.1.84.tar.gz", hash = "sha256:5220c0439838b9a5bd320fd3686be505c5083dcee22d2452006c23891153bea1"},
|
||||
{file = "langsmith-0.1.85-py3-none-any.whl", hash = "sha256:c1f94384f10cea96f7b4d33fd3db7ec180c03c7468877d50846f881d2017ff94"},
|
||||
{file = "langsmith-0.1.85.tar.gz", hash = "sha256:acff31f9e53efa48586cf8e32f65625a335c74d7c4fa306d1655ac18452296f6"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@ -2350,13 +2350,13 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "openai"
|
||||
version = "1.35.10"
|
||||
version = "1.35.13"
|
||||
description = "The official Python library for the openai API"
|
||||
optional = true
|
||||
python-versions = ">=3.7.1"
|
||||
files = [
|
||||
{file = "openai-1.35.10-py3-none-any.whl", hash = "sha256:962cb5c23224b5cbd16078308dabab97a08b0a5ad736a4fdb3dc2ffc44ac974f"},
|
||||
{file = "openai-1.35.10.tar.gz", hash = "sha256:85966949f4f960f3e4b239a659f9fd64d3a97ecc43c44dc0a044b5c7f11cccc6"},
|
||||
{file = "openai-1.35.13-py3-none-any.whl", hash = "sha256:36ec3e93e0d1f243f69be85c89b9221a471c3e450dfd9df16c9829e3cdf63e60"},
|
||||
{file = "openai-1.35.13.tar.gz", hash = "sha256:c684f3945608baf7d2dcc0ef3ee6f3e27e4c66f21076df0b47be45d57e6ae6e4"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@ -4141,13 +4141,13 @@ urllib3 = ">=2"
|
||||
|
||||
[[package]]
|
||||
name = "types-setuptools"
|
||||
version = "70.2.0.20240704"
|
||||
version = "70.3.0.20240710"
|
||||
description = "Typing stubs for setuptools"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "types-setuptools-70.2.0.20240704.tar.gz", hash = "sha256:2f8d28d16ca1607080f9fdf19595bd49c942884b2bbd6529c9b8a9a8fc8db911"},
|
||||
{file = "types_setuptools-70.2.0.20240704-py3-none-any.whl", hash = "sha256:6b892d5441c2ed58dd255724516e3df1db54892fb20597599aea66d04c3e4d7f"},
|
||||
{file = "types-setuptools-70.3.0.20240710.tar.gz", hash = "sha256:842cbf399812d2b65042c9d6ff35113bbf282dee38794779aa1f94e597bafc35"},
|
||||
{file = "types_setuptools-70.3.0.20240710-py3-none-any.whl", hash = "sha256:bd0db2a4b9f2c49ac5564be4e0fb3125c4c46b1f73eafdcbceffa5b005cceca4"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -43,6 +43,7 @@ from langchain_core.messages import (
|
||||
ToolMessage,
|
||||
)
|
||||
from langchain_core.messages.ai import UsageMetadata
|
||||
from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk
|
||||
from langchain_core.output_parsers import (
|
||||
JsonOutputKeyToolsParser,
|
||||
PydanticToolsParser,
|
||||
@ -1102,12 +1103,12 @@ def _make_message_chunk_from_anthropic_event(
|
||||
warnings.warn("Received unexpected tool content block.")
|
||||
content_block = event.content_block.model_dump()
|
||||
content_block["index"] = event.index
|
||||
tool_call_chunk = {
|
||||
"index": event.index,
|
||||
"id": event.content_block.id,
|
||||
"name": event.content_block.name,
|
||||
"args": "",
|
||||
}
|
||||
tool_call_chunk = create_tool_call_chunk(
|
||||
index=event.index,
|
||||
id=event.content_block.id,
|
||||
name=event.content_block.name,
|
||||
args="",
|
||||
)
|
||||
message_chunk = AIMessageChunk(
|
||||
content=[content_block],
|
||||
tool_call_chunks=[tool_call_chunk], # type: ignore
|
||||
|
@ -1,6 +1,7 @@
|
||||
from typing import Any, List, Optional, Type, Union, cast
|
||||
|
||||
from langchain_core.messages import AIMessage, ToolCall
|
||||
from langchain_core.messages.tool import tool_call
|
||||
from langchain_core.output_parsers import BaseGenerationOutputParser
|
||||
from langchain_core.outputs import ChatGeneration, Generation
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
@ -79,7 +80,7 @@ def extract_tool_calls(content: Union[str, List[Union[str, dict]]]) -> List[Tool
|
||||
if block["type"] != "tool_use":
|
||||
continue
|
||||
tool_calls.append(
|
||||
ToolCall(name=block["name"], args=block["input"], id=block["id"])
|
||||
tool_call(name=block["name"], args=block["input"], id=block["id"])
|
||||
)
|
||||
return tool_calls
|
||||
else:
|
||||
|
@ -365,10 +365,7 @@ async def test_astreaming() -> None:
|
||||
|
||||
|
||||
def test_tool_use() -> None:
|
||||
llm = ChatAnthropic( # type: ignore[call-arg]
|
||||
model=MODEL_NAME,
|
||||
)
|
||||
|
||||
llm = ChatAnthropic(model=MODEL_NAME) # type: ignore[call-arg]
|
||||
llm_with_tools = llm.bind_tools(
|
||||
[
|
||||
{
|
||||
@ -478,6 +475,7 @@ def test_anthropic_with_empty_text_block() -> None:
|
||||
"name": "type_letter",
|
||||
"args": {"letter": "d"},
|
||||
"id": "toolu_01V6d6W32QGGSmQm4BT98EKk",
|
||||
"type": "tool_call",
|
||||
},
|
||||
],
|
||||
),
|
||||
|
@ -33,8 +33,20 @@ class _Foo2(BaseModel):
|
||||
def test_tools_output_parser() -> None:
|
||||
output_parser = ToolsOutputParser()
|
||||
expected = [
|
||||
{"name": "_Foo1", "args": {"bar": 0}, "id": "1", "index": 1},
|
||||
{"name": "_Foo2", "args": {"baz": "a"}, "id": "2", "index": 3},
|
||||
{
|
||||
"name": "_Foo1",
|
||||
"args": {"bar": 0},
|
||||
"id": "1",
|
||||
"index": 1,
|
||||
"type": "tool_call",
|
||||
},
|
||||
{
|
||||
"name": "_Foo2",
|
||||
"args": {"baz": "a"},
|
||||
"id": "2",
|
||||
"index": 3,
|
||||
"type": "tool_call",
|
||||
},
|
||||
]
|
||||
actual = output_parser.parse_result(_RESULT)
|
||||
assert expected == actual
|
||||
@ -56,7 +68,13 @@ def test_tools_output_parser_args_only() -> None:
|
||||
|
||||
def test_tools_output_parser_first_tool_only() -> None:
|
||||
output_parser = ToolsOutputParser(first_tool_only=True)
|
||||
expected: Any = {"name": "_Foo1", "args": {"bar": 0}, "id": "1", "index": 1}
|
||||
expected: Any = {
|
||||
"name": "_Foo1",
|
||||
"args": {"bar": 0},
|
||||
"id": "1",
|
||||
"index": 1,
|
||||
"type": "tool_call",
|
||||
}
|
||||
actual = output_parser.parse_result(_RESULT)
|
||||
assert expected == actual
|
||||
|
||||
@ -81,7 +99,14 @@ def test_tools_output_parser_empty_content() -> None:
|
||||
)
|
||||
message = AIMessage(
|
||||
"",
|
||||
tool_calls=[{"name": "ChartType", "args": {"chart_type": "pie"}, "id": "foo"}],
|
||||
tool_calls=[
|
||||
{
|
||||
"name": "ChartType",
|
||||
"args": {"chart_type": "pie"},
|
||||
"id": "foo",
|
||||
"type": "tool_call",
|
||||
}
|
||||
],
|
||||
)
|
||||
actual = output_parser.invoke(message)
|
||||
expected = ChartType(chart_type="pie")
|
||||
|
@ -9,10 +9,11 @@ import json
|
||||
import os
|
||||
import re
|
||||
import urllib
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from io import BytesIO
|
||||
from typing import Any, BinaryIO, Callable, List, Optional
|
||||
from typing import Any, BinaryIO, Callable, List, Literal, Optional, Tuple
|
||||
from uuid import uuid4
|
||||
|
||||
import requests
|
||||
@ -126,6 +127,8 @@ class SessionsPythonREPLTool(BaseTool):
|
||||
session_id: str = str(uuid4())
|
||||
"""The session ID to use for the code interpreter. Defaults to a random UUID."""
|
||||
|
||||
response_format: Literal["content_and_raw_output"] = "content_and_raw_output"
|
||||
|
||||
def _build_url(self, path: str) -> str:
|
||||
pool_management_endpoint = self.pool_management_endpoint
|
||||
if not pool_management_endpoint:
|
||||
@ -164,16 +167,16 @@ class SessionsPythonREPLTool(BaseTool):
|
||||
properties = response_json.get("properties", {})
|
||||
return properties
|
||||
|
||||
def _run(self, python_code: str) -> Any:
|
||||
def _run(self, python_code: str, **kwargs: Any) -> Tuple[str, dict]:
|
||||
response = self.execute(python_code)
|
||||
|
||||
# if the result is an image, remove the base64 data
|
||||
result = response.get("result")
|
||||
result = deepcopy(response.get("result"))
|
||||
if isinstance(result, dict):
|
||||
if result.get("type") == "image" and "base64_data" in result:
|
||||
result.pop("base64_data")
|
||||
|
||||
return json.dumps(
|
||||
content = json.dumps(
|
||||
{
|
||||
"result": result,
|
||||
"stdout": response.get("stdout"),
|
||||
@ -181,6 +184,7 @@ class SessionsPythonREPLTool(BaseTool):
|
||||
},
|
||||
indent=2,
|
||||
)
|
||||
return content, response
|
||||
|
||||
def upload_file(
|
||||
self,
|
||||
|
@ -54,6 +54,12 @@ from langchain_core.messages import (
|
||||
ToolMessage,
|
||||
ToolMessageChunk,
|
||||
)
|
||||
from langchain_core.messages.tool import (
|
||||
ToolCallChunk,
|
||||
)
|
||||
from langchain_core.messages.tool import (
|
||||
tool_call_chunk as create_tool_call_chunk,
|
||||
)
|
||||
from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
|
||||
from langchain_core.output_parsers.base import OutputParserLike
|
||||
from langchain_core.output_parsers.openai_tools import (
|
||||
@ -199,6 +205,7 @@ def _convert_chunk_to_message_chunk(
|
||||
role = cast(str, _dict.get("role"))
|
||||
content = cast(str, _dict.get("content") or "")
|
||||
additional_kwargs: Dict = {}
|
||||
tool_call_chunks: List[ToolCallChunk] = []
|
||||
if _dict.get("function_call"):
|
||||
function_call = dict(_dict["function_call"])
|
||||
if "name" in function_call and function_call["name"] is None:
|
||||
@ -206,21 +213,18 @@ def _convert_chunk_to_message_chunk(
|
||||
additional_kwargs["function_call"] = function_call
|
||||
if raw_tool_calls := _dict.get("tool_calls"):
|
||||
additional_kwargs["tool_calls"] = raw_tool_calls
|
||||
try:
|
||||
tool_call_chunks = [
|
||||
{
|
||||
"name": rtc["function"].get("name"),
|
||||
"args": rtc["function"].get("arguments"),
|
||||
"id": rtc.get("id"),
|
||||
"index": rtc["index"],
|
||||
}
|
||||
for rtc in raw_tool_calls
|
||||
]
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
tool_call_chunks = []
|
||||
|
||||
for rtc in raw_tool_calls:
|
||||
try:
|
||||
tool_call_chunks.append(
|
||||
create_tool_call_chunk(
|
||||
name=rtc["function"].get("name"),
|
||||
args=rtc["function"].get("arguments"),
|
||||
id=rtc.get("id"),
|
||||
index=rtc.get("index"),
|
||||
)
|
||||
)
|
||||
except KeyError:
|
||||
pass
|
||||
if role == "user" or default_class == HumanMessageChunk:
|
||||
return HumanMessageChunk(content=content)
|
||||
elif role == "assistant" or default_class == AIMessageChunk:
|
||||
@ -237,7 +241,7 @@ def _convert_chunk_to_message_chunk(
|
||||
return AIMessageChunk(
|
||||
content=content,
|
||||
additional_kwargs=additional_kwargs,
|
||||
tool_call_chunks=tool_call_chunks, # type: ignore[arg-type]
|
||||
tool_call_chunks=tool_call_chunks,
|
||||
usage_metadata=usage_metadata, # type: ignore[arg-type]
|
||||
)
|
||||
elif role == "system" or default_class == SystemMessageChunk:
|
||||
|
@ -53,6 +53,7 @@ from langchain_core.messages import (
|
||||
ToolMessage,
|
||||
ToolMessageChunk,
|
||||
)
|
||||
from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk
|
||||
from langchain_core.output_parsers import (
|
||||
JsonOutputParser,
|
||||
PydanticOutputParser,
|
||||
@ -511,19 +512,19 @@ class ChatGroq(BaseChatModel):
|
||||
generation = chat_result.generations[0]
|
||||
message = cast(AIMessage, generation.message)
|
||||
tool_call_chunks = [
|
||||
{
|
||||
"name": rtc["function"].get("name"),
|
||||
"args": rtc["function"].get("arguments"),
|
||||
"id": rtc.get("id"),
|
||||
"index": rtc.get("index"),
|
||||
}
|
||||
create_tool_call_chunk(
|
||||
name=rtc["function"].get("name"),
|
||||
args=rtc["function"].get("arguments"),
|
||||
id=rtc.get("id"),
|
||||
index=rtc.get("index"),
|
||||
)
|
||||
for rtc in message.additional_kwargs.get("tool_calls", [])
|
||||
]
|
||||
chunk_ = ChatGenerationChunk(
|
||||
message=AIMessageChunk(
|
||||
content=message.content,
|
||||
additional_kwargs=message.additional_kwargs,
|
||||
tool_call_chunks=tool_call_chunks, # type: ignore[arg-type]
|
||||
tool_call_chunks=tool_call_chunks,
|
||||
usage_metadata=message.usage_metadata,
|
||||
),
|
||||
generation_info=generation.generation_info,
|
||||
|
@ -77,6 +77,7 @@ def test__convert_dict_to_message_tool_call() -> None:
|
||||
name="GenerateUsername",
|
||||
args={"name": "Sally", "hair_color": "green"},
|
||||
id="call_wm0JY6CdwOMZ4eTxHWUThDNz",
|
||||
type="tool_call",
|
||||
)
|
||||
],
|
||||
)
|
||||
@ -112,6 +113,7 @@ def test__convert_dict_to_message_tool_call() -> None:
|
||||
args="oops",
|
||||
id="call_wm0JY6CdwOMZ4eTxHWUThDNz",
|
||||
error="Function GenerateUsername arguments:\n\noops\n\nare not valid JSON. Received JSONDecodeError Expecting value: line 1 column 1 (char 0)", # noqa: E501
|
||||
type="invalid_tool_call",
|
||||
),
|
||||
],
|
||||
tool_calls=[
|
||||
@ -119,6 +121,7 @@ def test__convert_dict_to_message_tool_call() -> None:
|
||||
name="GenerateUsername",
|
||||
args={"name": "Sally", "hair_color": "green"},
|
||||
id="call_abc123",
|
||||
type="tool_call",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
@ -42,10 +42,12 @@ from langchain_core.messages import (
|
||||
HumanMessageChunk,
|
||||
SystemMessage,
|
||||
SystemMessageChunk,
|
||||
ToolCallChunk,
|
||||
ToolMessage,
|
||||
ToolMessageChunk,
|
||||
convert_to_messages,
|
||||
)
|
||||
from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk
|
||||
from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
|
||||
from langchain_core.output_parsers.base import OutputParserLike
|
||||
from langchain_core.output_parsers.openai_tools import (
|
||||
@ -174,6 +176,7 @@ def _convert_delta_to_message_chunk(
|
||||
role = cast(str, _dict.get("role"))
|
||||
content = cast(str, _dict.get("content") or "")
|
||||
additional_kwargs: Dict = {}
|
||||
tool_call_chunks: List[ToolCallChunk] = []
|
||||
if _dict.get("function_call"):
|
||||
function_call = dict(_dict["function_call"])
|
||||
if "name" in function_call and function_call["name"] is None:
|
||||
@ -181,21 +184,18 @@ def _convert_delta_to_message_chunk(
|
||||
additional_kwargs["function_call"] = function_call
|
||||
if raw_tool_calls := _dict.get("tool_calls"):
|
||||
additional_kwargs["tool_calls"] = raw_tool_calls
|
||||
try:
|
||||
tool_call_chunks = [
|
||||
{
|
||||
"name": rtc["function"].get("name"),
|
||||
"args": rtc["function"].get("arguments"),
|
||||
"id": rtc.get("id"),
|
||||
"index": rtc["index"],
|
||||
}
|
||||
for rtc in raw_tool_calls
|
||||
]
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
tool_call_chunks = []
|
||||
|
||||
for rtc in raw_tool_calls:
|
||||
try:
|
||||
tool_call_chunks.append(
|
||||
create_tool_call_chunk(
|
||||
name=rtc["function"].get("name"),
|
||||
args=rtc["function"].get("arguments"),
|
||||
id=rtc.get("id"),
|
||||
index=rtc.get("index"),
|
||||
)
|
||||
)
|
||||
except KeyError:
|
||||
pass
|
||||
if role == "user" or default_class == HumanMessageChunk:
|
||||
return HumanMessageChunk(content=content)
|
||||
elif role == "assistant" or default_class == AIMessageChunk:
|
||||
|
@ -50,6 +50,7 @@ from langchain_core.messages import (
|
||||
ToolCall,
|
||||
ToolMessage,
|
||||
)
|
||||
from langchain_core.messages.tool import tool_call_chunk
|
||||
from langchain_core.output_parsers import (
|
||||
JsonOutputParser,
|
||||
PydanticOutputParser,
|
||||
@ -103,19 +104,10 @@ def _convert_mistral_chat_message_to_message(
|
||||
dict, parse_tool_call(raw_tool_call, return_id=True)
|
||||
)
|
||||
if not parsed["id"]:
|
||||
tool_call_id = uuid.uuid4().hex[:]
|
||||
tool_calls.append(
|
||||
{
|
||||
**parsed,
|
||||
**{"id": tool_call_id},
|
||||
},
|
||||
)
|
||||
else:
|
||||
tool_calls.append(parsed)
|
||||
parsed["id"] = uuid.uuid4().hex[:]
|
||||
tool_calls.append(parsed)
|
||||
except Exception as e:
|
||||
invalid_tool_calls.append(
|
||||
dict(make_invalid_tool_call(raw_tool_call, str(e)))
|
||||
)
|
||||
invalid_tool_calls.append(make_invalid_tool_call(raw_tool_call, str(e)))
|
||||
return AIMessage(
|
||||
content=content,
|
||||
additional_kwargs=additional_kwargs,
|
||||
@ -206,12 +198,12 @@ def _convert_chunk_to_message_chunk(
|
||||
else:
|
||||
tool_call_id = raw_tool_call.get("id")
|
||||
tool_call_chunks.append(
|
||||
{
|
||||
"name": raw_tool_call["function"].get("name"),
|
||||
"args": raw_tool_call["function"].get("arguments"),
|
||||
"id": tool_call_id,
|
||||
"index": raw_tool_call.get("index"),
|
||||
}
|
||||
tool_call_chunk(
|
||||
name=raw_tool_call["function"].get("name"),
|
||||
args=raw_tool_call["function"].get("arguments"),
|
||||
id=tool_call_id,
|
||||
index=raw_tool_call.get("index"),
|
||||
)
|
||||
)
|
||||
except KeyError:
|
||||
pass
|
||||
|
@ -144,6 +144,7 @@ def test__convert_dict_to_message_tool_call() -> None:
|
||||
name="GenerateUsername",
|
||||
args={"name": "Sally", "hair_color": "green"},
|
||||
id="abc123",
|
||||
type="tool_call",
|
||||
)
|
||||
],
|
||||
)
|
||||
@ -178,6 +179,7 @@ def test__convert_dict_to_message_tool_call() -> None:
|
||||
args="oops",
|
||||
error="Function GenerateUsername arguments:\n\noops\n\nare not valid JSON. Received JSONDecodeError Expecting value: line 1 column 1 (char 0)", # noqa: E501
|
||||
id="abc123",
|
||||
type="invalid_tool_call",
|
||||
),
|
||||
],
|
||||
tool_calls=[
|
||||
@ -185,6 +187,7 @@ def test__convert_dict_to_message_tool_call() -> None:
|
||||
name="GenerateUsername",
|
||||
args={"name": "Sally", "hair_color": "green"},
|
||||
id="def456",
|
||||
type="tool_call",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
@ -63,6 +63,7 @@ from langchain_core.messages import (
|
||||
ToolMessageChunk,
|
||||
)
|
||||
from langchain_core.messages.ai import UsageMetadata
|
||||
from langchain_core.messages.tool import tool_call_chunk
|
||||
from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
|
||||
from langchain_core.output_parsers.base import OutputParserLike
|
||||
from langchain_core.output_parsers.openai_tools import (
|
||||
@ -244,12 +245,12 @@ def _convert_delta_to_message_chunk(
|
||||
additional_kwargs["tool_calls"] = raw_tool_calls
|
||||
try:
|
||||
tool_call_chunks = [
|
||||
{
|
||||
"name": rtc["function"].get("name"),
|
||||
"args": rtc["function"].get("arguments"),
|
||||
"id": rtc.get("id"),
|
||||
"index": rtc["index"],
|
||||
}
|
||||
tool_call_chunk(
|
||||
name=rtc["function"].get("name"),
|
||||
args=rtc["function"].get("arguments"),
|
||||
id=rtc.get("id"),
|
||||
index=rtc["index"],
|
||||
)
|
||||
for rtc in raw_tool_calls
|
||||
]
|
||||
except KeyError:
|
||||
|
@ -117,6 +117,7 @@ def test__convert_dict_to_message_tool_call() -> None:
|
||||
name="GenerateUsername",
|
||||
args={"name": "Sally", "hair_color": "green"},
|
||||
id="call_wm0JY6CdwOMZ4eTxHWUThDNz",
|
||||
type="tool_call",
|
||||
)
|
||||
],
|
||||
)
|
||||
@ -151,6 +152,7 @@ def test__convert_dict_to_message_tool_call() -> None:
|
||||
args="oops",
|
||||
id="call_wm0JY6CdwOMZ4eTxHWUThDNz",
|
||||
error="Function GenerateUsername arguments:\n\noops\n\nare not valid JSON. Received JSONDecodeError Expecting value: line 1 column 1 (char 0)", # noqa: E501
|
||||
type="invalid_tool_call",
|
||||
)
|
||||
],
|
||||
tool_calls=[
|
||||
@ -158,6 +160,7 @@ def test__convert_dict_to_message_tool_call() -> None:
|
||||
name="GenerateUsername",
|
||||
args={"name": "Sally", "hair_color": "green"},
|
||||
id="call_abc123",
|
||||
type="tool_call",
|
||||
)
|
||||
],
|
||||
)
|
||||
@ -353,7 +356,10 @@ def test_get_num_tokens_from_messages() -> None:
|
||||
),
|
||||
AIMessage("a nice bird"),
|
||||
AIMessage(
|
||||
"", tool_calls=[ToolCall(id="foo", name="bar", args={"arg1": "arg1"})]
|
||||
"",
|
||||
tool_calls=[
|
||||
ToolCall(id="foo", name="bar", args={"arg1": "arg1"}, type="tool_call")
|
||||
],
|
||||
),
|
||||
AIMessage(
|
||||
"",
|
||||
@ -362,7 +368,10 @@ def test_get_num_tokens_from_messages() -> None:
|
||||
},
|
||||
),
|
||||
AIMessage(
|
||||
"text", tool_calls=[ToolCall(id="foo", name="bar", args={"arg1": "arg1"})]
|
||||
"text",
|
||||
tool_calls=[
|
||||
ToolCall(id="foo", name="bar", args={"arg1": "arg1"}, type="tool_call")
|
||||
],
|
||||
),
|
||||
ToolMessage("foobar", tool_call_id="foo"),
|
||||
]
|
||||
|
Loading…
Reference in New Issue
Block a user