core[minor], integrations...[patch]: Support ToolCall as Tool input and ToolMessage as Tool output (#24038)

Changes:
- ToolCall, InvalidToolCall and ToolCallChunk can all accept a "type"
parameter now
- LLM integration packages add "type" to all the above
- Tool supports ToolCall inputs that have "type" specified
- Tool outputs ToolMessage when a ToolCall is passed as input
- Tools can separately specify ToolMessage.content and
ToolMessage.raw_output
- Tools emit events for validation errors (using on_tool_error and
on_tool_end)

Example:
```python
@tool("structured_api", response_format="content_and_raw_output")
def _mock_structured_tool_with_raw_output(
    arg1: int, arg2: bool, arg3: Optional[dict] = None
) -> Tuple[str, dict]:
    """A Structured Tool"""
    return f"{arg1} {arg2}", {"arg1": arg1, "arg2": arg2, "arg3": arg3}


def test_tool_call_input_tool_message_with_raw_output() -> None:
    tool_call: Dict = {
        "name": "structured_api",
        "args": {"arg1": 1, "arg2": True, "arg3": {"img": "base64string..."}},
        "id": "123",
        "type": "tool_call",
    }
    expected = ToolMessage("1 True", raw_output=tool_call["args"], tool_call_id="123")
    tool = _mock_structured_tool_with_raw_output
    actual = tool.invoke(tool_call)
    assert actual == expected

    tool_call.pop("type")
    with pytest.raises(ValidationError):
        tool.invoke(tool_call)

    actual_content = tool.invoke(tool_call["args"])
    assert actual_content == expected.content
```

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
Bagatur 2024-07-11 14:54:02 -07:00 committed by GitHub
parent eeb996034b
commit 5fd1e67808
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
22 changed files with 647 additions and 327 deletions

View File

@ -39,7 +39,7 @@ def get_non_abstract_subclasses(cls: Type[BaseTool]) -> List[Type[BaseTool]]:
def test_all_subclasses_accept_run_manager(cls: Type[BaseTool]) -> None:
"""Test that tools defined in this repo accept a run manager argument."""
# This wouldn't be necessary if the BaseTool had a strict API.
if cls._run is not BaseTool._arun:
if cls._run is not BaseTool._run:
run_func = cls._run
params = inspect.signature(run_func).parameters
assert "run_manager" in params

View File

@ -1,7 +1,7 @@
import json
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
from typing_extensions import TypedDict
from typing_extensions import NotRequired, TypedDict
from langchain_core.messages.base import BaseMessage, BaseMessageChunk, merge_content
from langchain_core.utils._merge import merge_dicts, merge_obj
@ -146,6 +146,11 @@ class ToolCall(TypedDict):
An identifier is needed to associate a tool call request with a tool
call result in events when multiple concurrent tool calls are made.
"""
type: NotRequired[Literal["tool_call"]]
def tool_call(*, name: str, args: Dict[str, Any], id: Optional[str]) -> ToolCall:
return ToolCall(name=name, args=args, id=id, type="tool_call")
class ToolCallChunk(TypedDict):
@ -176,6 +181,19 @@ class ToolCallChunk(TypedDict):
"""An identifier associated with the tool call."""
index: Optional[int]
"""The index of the tool call in a sequence."""
type: NotRequired[Literal["tool_call_chunk"]]
def tool_call_chunk(
*,
name: Optional[str] = None,
args: Optional[str] = None,
id: Optional[str] = None,
index: Optional[int] = None,
) -> ToolCallChunk:
return ToolCallChunk(
name=name, args=args, id=id, index=index, type="tool_call_chunk"
)
class InvalidToolCall(TypedDict):
@ -193,6 +211,19 @@ class InvalidToolCall(TypedDict):
"""An identifier associated with the tool call."""
error: Optional[str]
"""An error message associated with the tool call."""
type: NotRequired[Literal["invalid_tool_call"]]
def invalid_tool_call(
*,
name: Optional[str] = None,
args: Optional[str] = None,
id: Optional[str] = None,
error: Optional[str] = None,
) -> InvalidToolCall:
return InvalidToolCall(
name=name, args=args, id=id, error=error, type="invalid_tool_call"
)
def default_tool_parser(

View File

@ -5,6 +5,12 @@ from typing import Any, Dict, List, Optional, Type
from langchain_core.exceptions import OutputParserException
from langchain_core.messages import AIMessage, InvalidToolCall
from langchain_core.messages.tool import (
invalid_tool_call,
)
from langchain_core.messages.tool import (
tool_call as create_tool_call,
)
from langchain_core.output_parsers.transform import BaseCumulativeTransformOutputParser
from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.pydantic_v1 import BaseModel, ValidationError
@ -59,6 +65,7 @@ def parse_tool_call(
}
if return_id:
parsed["id"] = raw_tool_call.get("id")
parsed = create_tool_call(**parsed) # type: ignore
return parsed
@ -75,7 +82,7 @@ def make_invalid_tool_call(
Returns:
An InvalidToolCall instance with the error message.
"""
return InvalidToolCall(
return invalid_tool_call(
name=raw_tool_call["function"]["name"],
args=raw_tool_call["function"]["arguments"],
id=raw_tool_call.get("id"),

View File

@ -21,6 +21,7 @@ from __future__ import annotations
import asyncio
import inspect
import json
import textwrap
import uuid
import warnings
@ -34,6 +35,7 @@ from typing import (
Callable,
Dict,
List,
Literal,
Optional,
Sequence,
Tuple,
@ -42,7 +44,7 @@ from typing import (
get_type_hints,
)
from typing_extensions import Annotated, get_args, get_origin
from typing_extensions import Annotated, cast, get_args, get_origin
from langchain_core._api import deprecated
from langchain_core.callbacks import (
@ -56,6 +58,7 @@ from langchain_core.callbacks.manager import (
Callbacks,
)
from langchain_core.load.serializable import Serializable
from langchain_core.messages.tool import ToolCall, ToolMessage
from langchain_core.prompts import (
BasePromptTemplate,
PromptTemplate,
@ -306,7 +309,7 @@ class ToolException(Exception):
pass
class BaseTool(RunnableSerializable[Union[str, Dict], Any]):
class BaseTool(RunnableSerializable[Union[str, Dict, ToolCall], Any]):
"""Interface LangChain tools must implement."""
def __init_subclass__(cls, **kwargs: Any) -> None:
@ -378,6 +381,14 @@ class ChildTool(BaseTool):
] = False
"""Handle the content of the ValidationError thrown."""
response_format: Literal["content", "content_and_raw_output"] = "content"
"""The tool response format.
If "content" then the output of the tool is interpreted as the contents of a
ToolMessage. If "content_and_raw_output" then the output is expected to be a
two-tuple corresponding to the (content, raw_output) of a ToolMessage.
"""
class Config(Serializable.Config):
"""Configuration for this pydantic object."""
@ -410,46 +421,25 @@ class ChildTool(BaseTool):
def invoke(
self,
input: Union[str, Dict],
input: Union[str, Dict, ToolCall],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Any:
config = ensure_config(config)
return self.run(
input,
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
run_name=config.get("run_name"),
run_id=config.pop("run_id", None),
config=config,
**kwargs,
)
tool_input, kwargs = _prep_run_args(input, config, **kwargs)
return self.run(tool_input, **kwargs)
async def ainvoke(
self,
input: Union[str, Dict],
input: Union[str, Dict, ToolCall],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Any:
config = ensure_config(config)
return await self.arun(
input,
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
run_name=config.get("run_name"),
run_id=config.pop("run_id", None),
config=config,
**kwargs,
)
tool_input, kwargs = _prep_run_args(input, config, **kwargs)
return await self.arun(tool_input, **kwargs)
# --- Tool ---
def _parse_input(
self,
tool_input: Union[str, Dict],
) -> Union[str, Dict[str, Any]]:
def _parse_input(self, tool_input: Union[str, Dict]) -> Union[str, Dict[str, Any]]:
"""Convert tool input to pydantic model."""
input_args = self.args_schema
if isinstance(tool_input, str):
@ -465,7 +455,7 @@ class ChildTool(BaseTool):
for k, v in result.dict().items()
if k in tool_input
}
return tool_input
return tool_input
@root_validator(pre=True)
def raise_deprecation(cls, values: Dict) -> Dict:
@ -479,30 +469,27 @@ class ChildTool(BaseTool):
return values
@abstractmethod
def _run(
self,
*args: Any,
**kwargs: Any,
) -> Any:
def _run(self, *args: Any, **kwargs: Any) -> Any:
"""Use the tool.
Add run_manager: Optional[CallbackManagerForToolRun] = None
to child implementations to enable tracing,
to child implementations to enable tracing.
"""
async def _arun(
self,
*args: Any,
**kwargs: Any,
) -> Any:
async def _arun(self, *args: Any, **kwargs: Any) -> Any:
"""Use the tool asynchronously.
Add run_manager: Optional[AsyncCallbackManagerForToolRun] = None
to child implementations to enable tracing,
to child implementations to enable tracing.
"""
if kwargs.get("run_manager") and signature(self._run).parameters.get(
"run_manager"
):
kwargs["run_manager"] = kwargs["run_manager"].get_sync()
return await run_in_executor(None, self._run, *args, **kwargs)
def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]:
tool_input = self._parse_input(tool_input)
# For backwards compatibility, if run_input is a string,
# pass as a positional argument.
if isinstance(tool_input, str):
@ -523,24 +510,20 @@ class ChildTool(BaseTool):
run_name: Optional[str] = None,
run_id: Optional[uuid.UUID] = None,
config: Optional[RunnableConfig] = None,
tool_call_id: Optional[str] = None,
**kwargs: Any,
) -> Any:
"""Run the tool."""
if not self.verbose and verbose is not None:
verbose_ = verbose
else:
verbose_ = self.verbose
callback_manager = CallbackManager.configure(
callbacks,
self.callbacks,
verbose_,
self.verbose or bool(verbose),
tags,
self.tags,
metadata,
self.metadata,
)
# TODO: maybe also pass through run_manager is _run supports kwargs
new_arg_supported = signature(self._run).parameters.get("run_manager")
run_manager = callback_manager.on_tool_start(
{"name": self.name, "description": self.description},
tool_input if isinstance(tool_input, str) else str(tool_input),
@ -550,67 +533,52 @@ class ChildTool(BaseTool):
# Inputs by definition should always be dicts.
# For now, it's unclear whether this assumption is ever violated,
# but if it is we will send a `None` value to the callback instead
# And will need to address issue via a patch.
inputs=None if isinstance(tool_input, str) else tool_input,
# TODO: will need to address issue via a patch.
inputs=tool_input if isinstance(tool_input, dict) else None,
**kwargs,
)
content = None
raw_output = None
error_to_raise: Union[Exception, KeyboardInterrupt, None] = None
try:
child_config = patch_config(
config,
callbacks=run_manager.get_child(),
)
child_config = patch_config(config, callbacks=run_manager.get_child())
context = copy_context()
context.run(_set_config_context, child_config)
parsed_input = self._parse_input(tool_input)
tool_args, tool_kwargs = self._to_args_and_kwargs(parsed_input)
observation = (
context.run(
self._run, *tool_args, run_manager=run_manager, **tool_kwargs
)
if new_arg_supported
else context.run(self._run, *tool_args, **tool_kwargs)
)
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input)
if signature(self._run).parameters.get("run_manager"):
tool_kwargs["run_manager"] = run_manager
response = context.run(self._run, *tool_args, **tool_kwargs)
if self.response_format == "content_and_raw_output":
if not isinstance(response, tuple) or len(response) != 2:
raise ValueError(
"Since response_format='content_and_raw_output' "
"a two-tuple of the message content and raw tool output is "
f"expected. Instead generated response of type: "
f"{type(response)}."
)
content, raw_output = response
else:
content = response
except ValidationError as e:
if not self.handle_validation_error:
raise e
elif isinstance(self.handle_validation_error, bool):
observation = "Tool input validation error"
elif isinstance(self.handle_validation_error, str):
observation = self.handle_validation_error
elif callable(self.handle_validation_error):
observation = self.handle_validation_error(e)
error_to_raise = e
else:
raise ValueError(
f"Got unexpected type of `handle_validation_error`. Expected bool, "
f"str or callable. Received: {self.handle_validation_error}"
)
return observation
content = _handle_validation_error(e, flag=self.handle_validation_error)
except ToolException as e:
if not self.handle_tool_error:
run_manager.on_tool_error(e)
raise e
elif isinstance(self.handle_tool_error, bool):
if e.args:
observation = e.args[0]
else:
observation = "Tool execution error"
elif isinstance(self.handle_tool_error, str):
observation = self.handle_tool_error
elif callable(self.handle_tool_error):
observation = self.handle_tool_error(e)
error_to_raise = e
else:
raise ValueError(
f"Got unexpected type of `handle_tool_error`. Expected bool, str "
f"or callable. Received: {self.handle_tool_error}"
)
run_manager.on_tool_end(observation, color="red", name=self.name, **kwargs)
return observation
content = _handle_tool_error(e, flag=self.handle_tool_error)
except (Exception, KeyboardInterrupt) as e:
run_manager.on_tool_error(e)
raise e
else:
run_manager.on_tool_end(observation, color=color, name=self.name, **kwargs)
return observation
error_to_raise = e
if error_to_raise:
run_manager.on_tool_error(error_to_raise)
raise error_to_raise
output = _format_output(content, raw_output, tool_call_id)
run_manager.on_tool_end(output, color=color, name=self.name, **kwargs)
return output
async def arun(
self,
@ -625,99 +593,80 @@ class ChildTool(BaseTool):
run_name: Optional[str] = None,
run_id: Optional[uuid.UUID] = None,
config: Optional[RunnableConfig] = None,
tool_call_id: Optional[str] = None,
**kwargs: Any,
) -> Any:
"""Run the tool asynchronously."""
if not self.verbose and verbose is not None:
verbose_ = verbose
else:
verbose_ = self.verbose
callback_manager = AsyncCallbackManager.configure(
callbacks,
self.callbacks,
verbose_,
self.verbose or bool(verbose),
tags,
self.tags,
metadata,
self.metadata,
)
new_arg_supported = signature(self._arun).parameters.get("run_manager")
run_manager = await callback_manager.on_tool_start(
{"name": self.name, "description": self.description},
tool_input if isinstance(tool_input, str) else str(tool_input),
color=start_color,
name=run_name,
inputs=tool_input,
run_id=run_id,
# Inputs by definition should always be dicts.
# For now, it's unclear whether this assumption is ever violated,
# but if it is we will send a `None` value to the callback instead
# TODO: will need to address issue via a patch.
inputs=tool_input if isinstance(tool_input, dict) else None,
**kwargs,
)
content = None
raw_output = None
error_to_raise: Optional[Union[Exception, KeyboardInterrupt]] = None
try:
parsed_input = self._parse_input(tool_input)
# We then call the tool on the tool input to get an observation
tool_args, tool_kwargs = self._to_args_and_kwargs(parsed_input)
child_config = patch_config(
config,
callbacks=run_manager.get_child(),
)
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input)
child_config = patch_config(config, callbacks=run_manager.get_child())
context = copy_context()
context.run(_set_config_context, child_config)
coro = (
context.run(
self._arun, *tool_args, run_manager=run_manager, **tool_kwargs
)
if new_arg_supported
else context.run(self._arun, *tool_args, **tool_kwargs)
)
if self.__class__._arun is BaseTool._arun or signature(
self._arun
).parameters.get("run_manager"):
tool_kwargs["run_manager"] = run_manager
coro = context.run(self._arun, *tool_args, **tool_kwargs)
if accepts_context(asyncio.create_task):
observation = await asyncio.create_task(coro, context=context) # type: ignore
response = await asyncio.create_task(coro, context=context) # type: ignore
else:
observation = await coro
response = await coro
if self.response_format == "content_and_raw_output":
if not isinstance(response, tuple) or len(response) != 2:
raise ValueError(
"Since response_format='content_and_raw_output' "
"a two-tuple of the message content and raw tool output is "
f"expected. Instead generated response of type: "
f"{type(response)}."
)
content, raw_output = response
else:
content = response
except ValidationError as e:
if not self.handle_validation_error:
raise e
elif isinstance(self.handle_validation_error, bool):
observation = "Tool input validation error"
elif isinstance(self.handle_validation_error, str):
observation = self.handle_validation_error
elif callable(self.handle_validation_error):
observation = self.handle_validation_error(e)
error_to_raise = e
else:
raise ValueError(
f"Got unexpected type of `handle_validation_error`. Expected bool, "
f"str or callable. Received: {self.handle_validation_error}"
)
return observation
content = _handle_validation_error(e, flag=self.handle_validation_error)
except ToolException as e:
if not self.handle_tool_error:
await run_manager.on_tool_error(e)
raise e
elif isinstance(self.handle_tool_error, bool):
if e.args:
observation = e.args[0]
else:
observation = "Tool execution error"
elif isinstance(self.handle_tool_error, str):
observation = self.handle_tool_error
elif callable(self.handle_tool_error):
observation = self.handle_tool_error(e)
error_to_raise = e
else:
raise ValueError(
f"Got unexpected type of `handle_tool_error`. Expected bool, str "
f"or callable. Received: {self.handle_tool_error}"
)
await run_manager.on_tool_end(
observation, color="red", name=self.name, **kwargs
)
return observation
content = _handle_tool_error(e, flag=self.handle_tool_error)
except (Exception, KeyboardInterrupt) as e:
await run_manager.on_tool_error(e)
raise e
else:
await run_manager.on_tool_end(
observation, color=color, name=self.name, **kwargs
)
return observation
error_to_raise = e
if error_to_raise:
await run_manager.on_tool_error(error_to_raise)
raise error_to_raise
output = _format_output(content, raw_output, tool_call_id)
await run_manager.on_tool_end(output, color=color, name=self.name, **kwargs)
return output
@deprecated("0.1.47", alternative="invoke", removal="0.3.0")
def __call__(self, tool_input: str, callbacks: Callbacks = None) -> str:
@ -738,7 +687,7 @@ class Tool(BaseTool):
async def ainvoke(
self,
input: Union[str, Dict],
input: Union[str, Dict, ToolCall],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Any:
@ -780,17 +729,10 @@ class Tool(BaseTool):
) -> Any:
"""Use the tool."""
if self.func:
new_argument_supported = signature(self.func).parameters.get("callbacks")
return (
self.func(
*args,
callbacks=run_manager.get_child() if run_manager else None,
**kwargs,
)
if new_argument_supported
else self.func(*args, **kwargs)
)
raise NotImplementedError("Tool does not support sync")
if run_manager and signature(self.func).parameters.get("callbacks"):
kwargs["callbacks"] = run_manager.get_child()
return self.func(*args, **kwargs)
raise NotImplementedError("Tool does not support sync invocation.")
async def _arun(
self,
@ -800,26 +742,13 @@ class Tool(BaseTool):
) -> Any:
"""Use the tool asynchronously."""
if self.coroutine:
new_argument_supported = signature(self.coroutine).parameters.get(
"callbacks"
)
return (
await self.coroutine(
*args,
callbacks=run_manager.get_child() if run_manager else None,
**kwargs,
)
if new_argument_supported
else await self.coroutine(*args, **kwargs)
)
else:
return await run_in_executor(
None,
self._run,
run_manager=run_manager.get_sync() if run_manager else None,
*args,
**kwargs,
)
if run_manager and signature(self.coroutine).parameters.get("callbacks"):
kwargs["callbacks"] = run_manager.get_child()
return await self.coroutine(*args, **kwargs)
# NOTE: this code is unreachable since _arun is only called if coroutine is not
# None.
return await super()._arun(*args, run_manager=run_manager, **kwargs)
# TODO: this is for backwards compatibility, remove in future
def __init__(
@ -870,9 +799,10 @@ class StructuredTool(BaseTool):
# --- Runnable ---
# TODO: Is this needed?
async def ainvoke(
self,
input: Union[str, Dict],
input: Union[str, Dict, ToolCall],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Any:
@ -897,45 +827,26 @@ class StructuredTool(BaseTool):
) -> Any:
"""Use the tool."""
if self.func:
new_argument_supported = signature(self.func).parameters.get("callbacks")
return (
self.func(
*args,
callbacks=run_manager.get_child() if run_manager else None,
**kwargs,
)
if new_argument_supported
else self.func(*args, **kwargs)
)
raise NotImplementedError("Tool does not support sync")
if run_manager and signature(self.func).parameters.get("callbacks"):
kwargs["callbacks"] = run_manager.get_child()
return self.func(*args, **kwargs)
raise NotImplementedError("StructuredTool does not support sync invocation.")
async def _arun(
self,
*args: Any,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
**kwargs: Any,
) -> str:
) -> Any:
"""Use the tool asynchronously."""
if self.coroutine:
new_argument_supported = signature(self.coroutine).parameters.get(
"callbacks"
)
return (
await self.coroutine(
*args,
callbacks=run_manager.get_child() if run_manager else None,
**kwargs,
)
if new_argument_supported
else await self.coroutine(*args, **kwargs)
)
return await run_in_executor(
None,
self._run,
run_manager=run_manager.get_sync() if run_manager else None,
*args,
**kwargs,
)
if run_manager and signature(self.coroutine).parameters.get("callbacks"):
kwargs["callbacks"] = run_manager.get_child()
return await self.coroutine(*args, **kwargs)
# NOTE: this code is unreachable since _arun is only called if coroutine is not
# None.
return await super()._arun(*args, run_manager=run_manager, **kwargs)
@classmethod
def from_function(
@ -947,6 +858,8 @@ class StructuredTool(BaseTool):
return_direct: bool = False,
args_schema: Optional[Type[BaseModel]] = None,
infer_schema: bool = True,
*,
response_format: Literal["content", "content_and_raw_output"] = "content",
parse_docstring: bool = False,
error_on_invalid_docstring: bool = False,
**kwargs: Any,
@ -963,6 +876,10 @@ class StructuredTool(BaseTool):
return_direct: Whether to return the result directly or as a callback
args_schema: The schema of the tool's input arguments
infer_schema: Whether to infer the schema from the function's signature
response_format: The tool response format. If "content" then the output of
the tool is interpreted as the contents of a ToolMessage. If
"content_and_raw_output" then the output is expected to be a two-tuple
corresponding to the (content, raw_output) of a ToolMessage.
parse_docstring: if ``infer_schema`` and ``parse_docstring``, will attempt
to parse parameter descriptions from Google Style function docstrings.
error_on_invalid_docstring: if ``parse_docstring`` is provided, configures
@ -1020,6 +937,7 @@ class StructuredTool(BaseTool):
args_schema=_args_schema, # type: ignore[arg-type]
description=description_,
return_direct=return_direct,
response_format=response_format,
**kwargs,
)
@ -1029,6 +947,7 @@ def tool(
return_direct: bool = False,
args_schema: Optional[Type[BaseModel]] = None,
infer_schema: bool = True,
response_format: Literal["content", "content_and_raw_output"] = "content",
parse_docstring: bool = False,
error_on_invalid_docstring: bool = True,
) -> Callable:
@ -1042,6 +961,10 @@ def tool(
infer_schema: Whether to infer the schema of the arguments from
the function's signature. This also makes the resultant tool
accept a dictionary input to its `run()` function.
response_format: The tool response format. If "content" then the output of
the tool is interpreted as the contents of a ToolMessage. If
"content_and_raw_output" then the output is expected to be a two-tuple
corresponding to the (content, raw_output) of a ToolMessage.
parse_docstring: if ``infer_schema`` and ``parse_docstring``, will attempt to
parse parameter descriptions from Google Style function docstrings.
error_on_invalid_docstring: if ``parse_docstring`` is provided, configures
@ -1064,8 +987,12 @@ def tool(
# Searches the API for the query.
return
.. versionadded:: 0.2.14
Parse Google-style docstrings:
@tool(response_format="content_and_raw_output")
def search_api(query: str) -> Tuple[str, dict]:
return "partial json of results", {"full": "object of results"}
.. versionadded:: 0.2.14
Parse Google-style docstrings:
.. code-block:: python
@ -1179,6 +1106,7 @@ def tool(
return_direct=return_direct,
args_schema=schema,
infer_schema=infer_schema,
response_format=response_format,
parse_docstring=parse_docstring,
error_on_invalid_docstring=error_on_invalid_docstring,
)
@ -1195,6 +1123,7 @@ def tool(
description=f"{tool_name} tool",
return_direct=return_direct,
coroutine=coroutine,
response_format=response_format,
)
return _make_tool
@ -1350,6 +1279,103 @@ class BaseToolkit(BaseModel, ABC):
"""Get the tools in the toolkit."""
def _is_tool_call(x: Any) -> bool:
return isinstance(x, dict) and x.get("type") == "tool_call"
def _handle_validation_error(
e: ValidationError,
*,
flag: Union[Literal[True], str, Callable[[ValidationError], str]],
) -> str:
if isinstance(flag, bool):
content = "Tool input validation error"
elif isinstance(flag, str):
content = flag
elif callable(flag):
content = flag(e)
else:
raise ValueError(
f"Got unexpected type of `handle_validation_error`. Expected bool, "
f"str or callable. Received: {flag}"
)
return content
def _handle_tool_error(
e: ToolException,
*,
flag: Optional[Union[Literal[True], str, Callable[[ToolException], str]]],
) -> str:
if isinstance(flag, bool):
if e.args:
content = e.args[0]
else:
content = "Tool execution error"
elif isinstance(flag, str):
content = flag
elif callable(flag):
content = flag(e)
else:
raise ValueError(
f"Got unexpected type of `handle_tool_error`. Expected bool, str "
f"or callable. Received: {flag}"
)
return content
def _prep_run_args(
input: Union[str, dict, ToolCall],
config: Optional[RunnableConfig],
**kwargs: Any,
) -> Tuple[Union[str, Dict], Dict]:
config = ensure_config(config)
if _is_tool_call(input):
tool_call_id: Optional[str] = cast(ToolCall, input)["id"]
tool_input: Union[str, dict] = cast(ToolCall, input)["args"]
else:
tool_call_id = None
tool_input = cast(Union[str, dict], input)
return (
tool_input,
dict(
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
run_name=config.get("run_name"),
run_id=config.pop("run_id", None),
config=config,
tool_call_id=tool_call_id,
**kwargs,
),
)
def _format_output(
content: Any, raw_output: Any, tool_call_id: Optional[str]
) -> Union[ToolMessage, Any]:
if tool_call_id:
# NOTE: This will fail to stringify lists which aren't actually content blocks
# but whose first element happens to be a string or dict. Tools should avoid
# returning such contents.
if not isinstance(content, str) and not (
isinstance(content, list)
and content
and isinstance(content[0], (str, dict))
):
content = _stringify(content)
return ToolMessage(content, raw_output=raw_output, tool_call_id=tool_call_id)
else:
return content
def _stringify(content: Any) -> str:
try:
return json.dumps(content)
except Exception:
return str(content)
def _get_description_from_runnable(runnable: Runnable) -> str:
"""Generate a placeholder description of a runnable."""
input_schema = runnable.input_schema.schema()

View File

@ -317,6 +317,13 @@
'title': 'Name',
'type': 'string',
}),
'type': dict({
'enum': list([
'invalid_tool_call',
]),
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
'name',
@ -419,6 +426,13 @@
'title': 'Name',
'type': 'string',
}),
'type': dict({
'enum': list([
'tool_call',
]),
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
'name',
@ -908,6 +922,13 @@
'title': 'Name',
'type': 'string',
}),
'type': dict({
'enum': list([
'invalid_tool_call',
]),
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
'name',
@ -1010,6 +1031,13 @@
'title': 'Name',
'type': 'string',
}),
'type': dict({
'enum': list([
'tool_call',
]),
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
'name',

View File

@ -674,6 +674,13 @@
'title': 'Name',
'type': 'string',
}),
'type': dict({
'enum': list([
'invalid_tool_call',
]),
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
'name',
@ -776,6 +783,13 @@
'title': 'Name',
'type': 'string',
}),
'type': dict({
'enum': list([
'tool_call',
]),
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
'name',

View File

@ -5577,6 +5577,13 @@
'title': 'Name',
'type': 'string',
}),
'type': dict({
'enum': list([
'invalid_tool_call',
]),
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
'name',
@ -5701,6 +5708,13 @@
'title': 'Name',
'type': 'string',
}),
'type': dict({
'enum': list([
'tool_call',
]),
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
'name',
@ -6237,6 +6251,13 @@
'title': 'Name',
'type': 'string',
}),
'type': dict({
'enum': list([
'invalid_tool_call',
]),
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
'name',
@ -6361,6 +6382,13 @@
'title': 'Name',
'type': 'string',
}),
'type': dict({
'enum': list([
'tool_call',
]),
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
'name',
@ -6834,6 +6862,13 @@
'title': 'Name',
'type': 'string',
}),
'type': dict({
'enum': list([
'invalid_tool_call',
]),
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
'name',
@ -6936,6 +6971,13 @@
'title': 'Name',
'type': 'string',
}),
'type': dict({
'enum': list([
'tool_call',
]),
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
'name',
@ -7444,6 +7486,13 @@
'title': 'Name',
'type': 'string',
}),
'type': dict({
'enum': list([
'invalid_tool_call',
]),
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
'name',
@ -7568,6 +7617,13 @@
'title': 'Name',
'type': 'string',
}),
'type': dict({
'enum': list([
'tool_call',
]),
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
'name',
@ -8068,6 +8124,13 @@
'title': 'Name',
'type': 'string',
}),
'type': dict({
'enum': list([
'invalid_tool_call',
]),
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
'name',
@ -8203,6 +8266,13 @@
'title': 'Name',
'type': 'string',
}),
'type': dict({
'enum': list([
'tool_call',
]),
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
'name',
@ -8683,6 +8753,13 @@
'title': 'Name',
'type': 'string',
}),
'type': dict({
'enum': list([
'invalid_tool_call',
]),
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
'name',
@ -8785,6 +8862,13 @@
'title': 'Name',
'type': 'string',
}),
'type': dict({
'enum': list([
'tool_call',
]),
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
'name',
@ -9238,6 +9322,13 @@
'title': 'Name',
'type': 'string',
}),
'type': dict({
'enum': list([
'invalid_tool_call',
]),
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
'name',
@ -9340,6 +9431,13 @@
'title': 'Name',
'type': 'string',
}),
'type': dict({
'enum': list([
'tool_call',
]),
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
'name',
@ -9880,6 +9978,13 @@
'title': 'Name',
'type': 'string',
}),
'type': dict({
'enum': list([
'invalid_tool_call',
]),
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
'name',
@ -10004,6 +10109,13 @@
'title': 'Name',
'type': 'string',
}),
'type': dict({
'enum': list([
'tool_call',
]),
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
'name',

View File

@ -8,7 +8,7 @@ import textwrap
from datetime import datetime
from enum import Enum
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Type, Union
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union
import pytest
from typing_extensions import Annotated, TypedDict
@ -17,6 +17,7 @@ from langchain_core.callbacks import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from langchain_core.messages import ToolMessage
from langchain_core.pydantic_v1 import BaseModel, ValidationError
from langchain_core.runnables import Runnable, RunnableLambda, ensure_config
from langchain_core.tools import (
@ -1067,6 +1068,65 @@ def test_tool_annotated_descriptions() -> None:
}
def test_tool_call_input_tool_message_output() -> None:
tool_call = {
"name": "structured_api",
"args": {"arg1": 1, "arg2": True, "arg3": {"img": "base64string..."}},
"id": "123",
"type": "tool_call",
}
tool = _MockStructuredTool()
expected = ToolMessage("1 True {'img': 'base64string...'}", tool_call_id="123")
actual = tool.invoke(tool_call)
assert actual == expected
tool_call.pop("type")
with pytest.raises(ValidationError):
tool.invoke(tool_call)
class _MockStructuredToolWithRawOutput(BaseTool):
name: str = "structured_api"
args_schema: Type[BaseModel] = _MockSchema
description: str = "A Structured Tool"
response_format: Literal["content_and_raw_output"] = "content_and_raw_output"
def _run(
self, arg1: int, arg2: bool, arg3: Optional[dict] = None
) -> Tuple[str, dict]:
return f"{arg1} {arg2}", {"arg1": arg1, "arg2": arg2, "arg3": arg3}
@tool("structured_api", response_format="content_and_raw_output")
def _mock_structured_tool_with_raw_output(
arg1: int, arg2: bool, arg3: Optional[dict] = None
) -> Tuple[str, dict]:
"""A Structured Tool"""
return f"{arg1} {arg2}", {"arg1": arg1, "arg2": arg2, "arg3": arg3}
@pytest.mark.parametrize(
"tool", [_MockStructuredToolWithRawOutput(), _mock_structured_tool_with_raw_output]
)
def test_tool_call_input_tool_message_with_raw_output(tool: BaseTool) -> None:
tool_call: Dict = {
"name": "structured_api",
"args": {"arg1": 1, "arg2": True, "arg3": {"img": "base64string..."}},
"id": "123",
"type": "tool_call",
}
expected = ToolMessage("1 True", raw_output=tool_call["args"], tool_call_id="123")
actual = tool.invoke(tool_call)
assert actual == expected
tool_call.pop("type")
with pytest.raises(ValidationError):
tool.invoke(tool_call)
actual_content = tool.invoke(tool_call["args"])
assert actual_content == expected.content
def test_convert_from_runnable_dict() -> None:
# Test with typed dict input
class Args(TypedDict):

View File

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
[[package]]
name = "aiohttp"
@ -1760,7 +1760,7 @@ files = [
[[package]]
name = "langchain-core"
version = "0.2.12"
version = "0.2.13"
description = "Building applications with LLMs through composability"
optional = false
python-versions = ">=3.8.1,<4.0"
@ -1784,7 +1784,7 @@ url = "../core"
[[package]]
name = "langchain-openai"
version = "0.1.14"
version = "0.1.15"
description = "An integration package connecting OpenAI and LangChain"
optional = true
python-versions = ">=3.8.1,<4.0"
@ -1792,7 +1792,7 @@ files = []
develop = true
[package.dependencies]
langchain-core = ">=0.2.2,<0.3"
langchain-core = "^0.2.13"
openai = "^1.32.0"
tiktoken = ">=0.7,<1"
@ -1834,13 +1834,13 @@ types-requests = ">=2.31.0.2,<3.0.0.0"
[[package]]
name = "langsmith"
version = "0.1.84"
version = "0.1.85"
description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
optional = false
python-versions = "<4.0,>=3.8.1"
files = [
{file = "langsmith-0.1.84-py3-none-any.whl", hash = "sha256:01f3c6390dba26c583bac8dd0e551ce3d0509c7f55cad714db0b5c8d36e4c7ff"},
{file = "langsmith-0.1.84.tar.gz", hash = "sha256:5220c0439838b9a5bd320fd3686be505c5083dcee22d2452006c23891153bea1"},
{file = "langsmith-0.1.85-py3-none-any.whl", hash = "sha256:c1f94384f10cea96f7b4d33fd3db7ec180c03c7468877d50846f881d2017ff94"},
{file = "langsmith-0.1.85.tar.gz", hash = "sha256:acff31f9e53efa48586cf8e32f65625a335c74d7c4fa306d1655ac18452296f6"},
]
[package.dependencies]
@ -2350,13 +2350,13 @@ files = [
[[package]]
name = "openai"
version = "1.35.10"
version = "1.35.13"
description = "The official Python library for the openai API"
optional = true
python-versions = ">=3.7.1"
files = [
{file = "openai-1.35.10-py3-none-any.whl", hash = "sha256:962cb5c23224b5cbd16078308dabab97a08b0a5ad736a4fdb3dc2ffc44ac974f"},
{file = "openai-1.35.10.tar.gz", hash = "sha256:85966949f4f960f3e4b239a659f9fd64d3a97ecc43c44dc0a044b5c7f11cccc6"},
{file = "openai-1.35.13-py3-none-any.whl", hash = "sha256:36ec3e93e0d1f243f69be85c89b9221a471c3e450dfd9df16c9829e3cdf63e60"},
{file = "openai-1.35.13.tar.gz", hash = "sha256:c684f3945608baf7d2dcc0ef3ee6f3e27e4c66f21076df0b47be45d57e6ae6e4"},
]
[package.dependencies]
@ -4141,13 +4141,13 @@ urllib3 = ">=2"
[[package]]
name = "types-setuptools"
version = "70.2.0.20240704"
version = "70.3.0.20240710"
description = "Typing stubs for setuptools"
optional = false
python-versions = ">=3.8"
files = [
{file = "types-setuptools-70.2.0.20240704.tar.gz", hash = "sha256:2f8d28d16ca1607080f9fdf19595bd49c942884b2bbd6529c9b8a9a8fc8db911"},
{file = "types_setuptools-70.2.0.20240704-py3-none-any.whl", hash = "sha256:6b892d5441c2ed58dd255724516e3df1db54892fb20597599aea66d04c3e4d7f"},
{file = "types-setuptools-70.3.0.20240710.tar.gz", hash = "sha256:842cbf399812d2b65042c9d6ff35113bbf282dee38794779aa1f94e597bafc35"},
{file = "types_setuptools-70.3.0.20240710-py3-none-any.whl", hash = "sha256:bd0db2a4b9f2c49ac5564be4e0fb3125c4c46b1f73eafdcbceffa5b005cceca4"},
]
[[package]]

View File

@ -43,6 +43,7 @@ from langchain_core.messages import (
ToolMessage,
)
from langchain_core.messages.ai import UsageMetadata
from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk
from langchain_core.output_parsers import (
JsonOutputKeyToolsParser,
PydanticToolsParser,
@ -1102,12 +1103,12 @@ def _make_message_chunk_from_anthropic_event(
warnings.warn("Received unexpected tool content block.")
content_block = event.content_block.model_dump()
content_block["index"] = event.index
tool_call_chunk = {
"index": event.index,
"id": event.content_block.id,
"name": event.content_block.name,
"args": "",
}
tool_call_chunk = create_tool_call_chunk(
index=event.index,
id=event.content_block.id,
name=event.content_block.name,
args="",
)
message_chunk = AIMessageChunk(
content=[content_block],
tool_call_chunks=[tool_call_chunk], # type: ignore

View File

@ -1,6 +1,7 @@
from typing import Any, List, Optional, Type, Union, cast
from langchain_core.messages import AIMessage, ToolCall
from langchain_core.messages.tool import tool_call
from langchain_core.output_parsers import BaseGenerationOutputParser
from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.pydantic_v1 import BaseModel
@ -79,7 +80,7 @@ def extract_tool_calls(content: Union[str, List[Union[str, dict]]]) -> List[Tool
if block["type"] != "tool_use":
continue
tool_calls.append(
ToolCall(name=block["name"], args=block["input"], id=block["id"])
tool_call(name=block["name"], args=block["input"], id=block["id"])
)
return tool_calls
else:

View File

@ -365,10 +365,7 @@ async def test_astreaming() -> None:
def test_tool_use() -> None:
llm = ChatAnthropic( # type: ignore[call-arg]
model=MODEL_NAME,
)
llm = ChatAnthropic(model=MODEL_NAME) # type: ignore[call-arg]
llm_with_tools = llm.bind_tools(
[
{
@ -478,6 +475,7 @@ def test_anthropic_with_empty_text_block() -> None:
"name": "type_letter",
"args": {"letter": "d"},
"id": "toolu_01V6d6W32QGGSmQm4BT98EKk",
"type": "tool_call",
},
],
),

View File

@ -33,8 +33,20 @@ class _Foo2(BaseModel):
def test_tools_output_parser() -> None:
output_parser = ToolsOutputParser()
expected = [
{"name": "_Foo1", "args": {"bar": 0}, "id": "1", "index": 1},
{"name": "_Foo2", "args": {"baz": "a"}, "id": "2", "index": 3},
{
"name": "_Foo1",
"args": {"bar": 0},
"id": "1",
"index": 1,
"type": "tool_call",
},
{
"name": "_Foo2",
"args": {"baz": "a"},
"id": "2",
"index": 3,
"type": "tool_call",
},
]
actual = output_parser.parse_result(_RESULT)
assert expected == actual
@ -56,7 +68,13 @@ def test_tools_output_parser_args_only() -> None:
def test_tools_output_parser_first_tool_only() -> None:
output_parser = ToolsOutputParser(first_tool_only=True)
expected: Any = {"name": "_Foo1", "args": {"bar": 0}, "id": "1", "index": 1}
expected: Any = {
"name": "_Foo1",
"args": {"bar": 0},
"id": "1",
"index": 1,
"type": "tool_call",
}
actual = output_parser.parse_result(_RESULT)
assert expected == actual
@ -81,7 +99,14 @@ def test_tools_output_parser_empty_content() -> None:
)
message = AIMessage(
"",
tool_calls=[{"name": "ChartType", "args": {"chart_type": "pie"}, "id": "foo"}],
tool_calls=[
{
"name": "ChartType",
"args": {"chart_type": "pie"},
"id": "foo",
"type": "tool_call",
}
],
)
actual = output_parser.invoke(message)
expected = ChartType(chart_type="pie")

View File

@ -9,10 +9,11 @@ import json
import os
import re
import urllib
from copy import deepcopy
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from io import BytesIO
from typing import Any, BinaryIO, Callable, List, Optional
from typing import Any, BinaryIO, Callable, List, Literal, Optional, Tuple
from uuid import uuid4
import requests
@ -126,6 +127,8 @@ class SessionsPythonREPLTool(BaseTool):
session_id: str = str(uuid4())
"""The session ID to use for the code interpreter. Defaults to a random UUID."""
response_format: Literal["content_and_raw_output"] = "content_and_raw_output"
def _build_url(self, path: str) -> str:
pool_management_endpoint = self.pool_management_endpoint
if not pool_management_endpoint:
@ -164,16 +167,16 @@ class SessionsPythonREPLTool(BaseTool):
properties = response_json.get("properties", {})
return properties
def _run(self, python_code: str) -> Any:
def _run(self, python_code: str, **kwargs: Any) -> Tuple[str, dict]:
response = self.execute(python_code)
# if the result is an image, remove the base64 data
result = response.get("result")
result = deepcopy(response.get("result"))
if isinstance(result, dict):
if result.get("type") == "image" and "base64_data" in result:
result.pop("base64_data")
return json.dumps(
content = json.dumps(
{
"result": result,
"stdout": response.get("stdout"),
@ -181,6 +184,7 @@ class SessionsPythonREPLTool(BaseTool):
},
indent=2,
)
return content, response
def upload_file(
self,

View File

@ -54,6 +54,12 @@ from langchain_core.messages import (
ToolMessage,
ToolMessageChunk,
)
from langchain_core.messages.tool import (
ToolCallChunk,
)
from langchain_core.messages.tool import (
tool_call_chunk as create_tool_call_chunk,
)
from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
from langchain_core.output_parsers.base import OutputParserLike
from langchain_core.output_parsers.openai_tools import (
@ -199,6 +205,7 @@ def _convert_chunk_to_message_chunk(
role = cast(str, _dict.get("role"))
content = cast(str, _dict.get("content") or "")
additional_kwargs: Dict = {}
tool_call_chunks: List[ToolCallChunk] = []
if _dict.get("function_call"):
function_call = dict(_dict["function_call"])
if "name" in function_call and function_call["name"] is None:
@ -206,21 +213,18 @@ def _convert_chunk_to_message_chunk(
additional_kwargs["function_call"] = function_call
if raw_tool_calls := _dict.get("tool_calls"):
additional_kwargs["tool_calls"] = raw_tool_calls
try:
tool_call_chunks = [
{
"name": rtc["function"].get("name"),
"args": rtc["function"].get("arguments"),
"id": rtc.get("id"),
"index": rtc["index"],
}
for rtc in raw_tool_calls
]
except KeyError:
pass
else:
tool_call_chunks = []
for rtc in raw_tool_calls:
try:
tool_call_chunks.append(
create_tool_call_chunk(
name=rtc["function"].get("name"),
args=rtc["function"].get("arguments"),
id=rtc.get("id"),
index=rtc.get("index"),
)
)
except KeyError:
pass
if role == "user" or default_class == HumanMessageChunk:
return HumanMessageChunk(content=content)
elif role == "assistant" or default_class == AIMessageChunk:
@ -237,7 +241,7 @@ def _convert_chunk_to_message_chunk(
return AIMessageChunk(
content=content,
additional_kwargs=additional_kwargs,
tool_call_chunks=tool_call_chunks, # type: ignore[arg-type]
tool_call_chunks=tool_call_chunks,
usage_metadata=usage_metadata, # type: ignore[arg-type]
)
elif role == "system" or default_class == SystemMessageChunk:

View File

@ -53,6 +53,7 @@ from langchain_core.messages import (
ToolMessage,
ToolMessageChunk,
)
from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk
from langchain_core.output_parsers import (
JsonOutputParser,
PydanticOutputParser,
@ -511,19 +512,19 @@ class ChatGroq(BaseChatModel):
generation = chat_result.generations[0]
message = cast(AIMessage, generation.message)
tool_call_chunks = [
{
"name": rtc["function"].get("name"),
"args": rtc["function"].get("arguments"),
"id": rtc.get("id"),
"index": rtc.get("index"),
}
create_tool_call_chunk(
name=rtc["function"].get("name"),
args=rtc["function"].get("arguments"),
id=rtc.get("id"),
index=rtc.get("index"),
)
for rtc in message.additional_kwargs.get("tool_calls", [])
]
chunk_ = ChatGenerationChunk(
message=AIMessageChunk(
content=message.content,
additional_kwargs=message.additional_kwargs,
tool_call_chunks=tool_call_chunks, # type: ignore[arg-type]
tool_call_chunks=tool_call_chunks,
usage_metadata=message.usage_metadata,
),
generation_info=generation.generation_info,

View File

@ -77,6 +77,7 @@ def test__convert_dict_to_message_tool_call() -> None:
name="GenerateUsername",
args={"name": "Sally", "hair_color": "green"},
id="call_wm0JY6CdwOMZ4eTxHWUThDNz",
type="tool_call",
)
],
)
@ -112,6 +113,7 @@ def test__convert_dict_to_message_tool_call() -> None:
args="oops",
id="call_wm0JY6CdwOMZ4eTxHWUThDNz",
error="Function GenerateUsername arguments:\n\noops\n\nare not valid JSON. Received JSONDecodeError Expecting value: line 1 column 1 (char 0)", # noqa: E501
type="invalid_tool_call",
),
],
tool_calls=[
@ -119,6 +121,7 @@ def test__convert_dict_to_message_tool_call() -> None:
name="GenerateUsername",
args={"name": "Sally", "hair_color": "green"},
id="call_abc123",
type="tool_call",
),
],
)

View File

@ -42,10 +42,12 @@ from langchain_core.messages import (
HumanMessageChunk,
SystemMessage,
SystemMessageChunk,
ToolCallChunk,
ToolMessage,
ToolMessageChunk,
convert_to_messages,
)
from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk
from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
from langchain_core.output_parsers.base import OutputParserLike
from langchain_core.output_parsers.openai_tools import (
@ -174,6 +176,7 @@ def _convert_delta_to_message_chunk(
role = cast(str, _dict.get("role"))
content = cast(str, _dict.get("content") or "")
additional_kwargs: Dict = {}
tool_call_chunks: List[ToolCallChunk] = []
if _dict.get("function_call"):
function_call = dict(_dict["function_call"])
if "name" in function_call and function_call["name"] is None:
@ -181,21 +184,18 @@ def _convert_delta_to_message_chunk(
additional_kwargs["function_call"] = function_call
if raw_tool_calls := _dict.get("tool_calls"):
additional_kwargs["tool_calls"] = raw_tool_calls
try:
tool_call_chunks = [
{
"name": rtc["function"].get("name"),
"args": rtc["function"].get("arguments"),
"id": rtc.get("id"),
"index": rtc["index"],
}
for rtc in raw_tool_calls
]
except KeyError:
pass
else:
tool_call_chunks = []
for rtc in raw_tool_calls:
try:
tool_call_chunks.append(
create_tool_call_chunk(
name=rtc["function"].get("name"),
args=rtc["function"].get("arguments"),
id=rtc.get("id"),
index=rtc.get("index"),
)
)
except KeyError:
pass
if role == "user" or default_class == HumanMessageChunk:
return HumanMessageChunk(content=content)
elif role == "assistant" or default_class == AIMessageChunk:

View File

@ -50,6 +50,7 @@ from langchain_core.messages import (
ToolCall,
ToolMessage,
)
from langchain_core.messages.tool import tool_call_chunk
from langchain_core.output_parsers import (
JsonOutputParser,
PydanticOutputParser,
@ -103,19 +104,10 @@ def _convert_mistral_chat_message_to_message(
dict, parse_tool_call(raw_tool_call, return_id=True)
)
if not parsed["id"]:
tool_call_id = uuid.uuid4().hex[:]
tool_calls.append(
{
**parsed,
**{"id": tool_call_id},
},
)
else:
tool_calls.append(parsed)
parsed["id"] = uuid.uuid4().hex[:]
tool_calls.append(parsed)
except Exception as e:
invalid_tool_calls.append(
dict(make_invalid_tool_call(raw_tool_call, str(e)))
)
invalid_tool_calls.append(make_invalid_tool_call(raw_tool_call, str(e)))
return AIMessage(
content=content,
additional_kwargs=additional_kwargs,
@ -206,12 +198,12 @@ def _convert_chunk_to_message_chunk(
else:
tool_call_id = raw_tool_call.get("id")
tool_call_chunks.append(
{
"name": raw_tool_call["function"].get("name"),
"args": raw_tool_call["function"].get("arguments"),
"id": tool_call_id,
"index": raw_tool_call.get("index"),
}
tool_call_chunk(
name=raw_tool_call["function"].get("name"),
args=raw_tool_call["function"].get("arguments"),
id=tool_call_id,
index=raw_tool_call.get("index"),
)
)
except KeyError:
pass

View File

@ -144,6 +144,7 @@ def test__convert_dict_to_message_tool_call() -> None:
name="GenerateUsername",
args={"name": "Sally", "hair_color": "green"},
id="abc123",
type="tool_call",
)
],
)
@ -178,6 +179,7 @@ def test__convert_dict_to_message_tool_call() -> None:
args="oops",
error="Function GenerateUsername arguments:\n\noops\n\nare not valid JSON. Received JSONDecodeError Expecting value: line 1 column 1 (char 0)", # noqa: E501
id="abc123",
type="invalid_tool_call",
),
],
tool_calls=[
@ -185,6 +187,7 @@ def test__convert_dict_to_message_tool_call() -> None:
name="GenerateUsername",
args={"name": "Sally", "hair_color": "green"},
id="def456",
type="tool_call",
),
],
)

View File

@ -63,6 +63,7 @@ from langchain_core.messages import (
ToolMessageChunk,
)
from langchain_core.messages.ai import UsageMetadata
from langchain_core.messages.tool import tool_call_chunk
from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
from langchain_core.output_parsers.base import OutputParserLike
from langchain_core.output_parsers.openai_tools import (
@ -244,12 +245,12 @@ def _convert_delta_to_message_chunk(
additional_kwargs["tool_calls"] = raw_tool_calls
try:
tool_call_chunks = [
{
"name": rtc["function"].get("name"),
"args": rtc["function"].get("arguments"),
"id": rtc.get("id"),
"index": rtc["index"],
}
tool_call_chunk(
name=rtc["function"].get("name"),
args=rtc["function"].get("arguments"),
id=rtc.get("id"),
index=rtc["index"],
)
for rtc in raw_tool_calls
]
except KeyError:

View File

@ -117,6 +117,7 @@ def test__convert_dict_to_message_tool_call() -> None:
name="GenerateUsername",
args={"name": "Sally", "hair_color": "green"},
id="call_wm0JY6CdwOMZ4eTxHWUThDNz",
type="tool_call",
)
],
)
@ -151,6 +152,7 @@ def test__convert_dict_to_message_tool_call() -> None:
args="oops",
id="call_wm0JY6CdwOMZ4eTxHWUThDNz",
error="Function GenerateUsername arguments:\n\noops\n\nare not valid JSON. Received JSONDecodeError Expecting value: line 1 column 1 (char 0)", # noqa: E501
type="invalid_tool_call",
)
],
tool_calls=[
@ -158,6 +160,7 @@ def test__convert_dict_to_message_tool_call() -> None:
name="GenerateUsername",
args={"name": "Sally", "hair_color": "green"},
id="call_abc123",
type="tool_call",
)
],
)
@ -353,7 +356,10 @@ def test_get_num_tokens_from_messages() -> None:
),
AIMessage("a nice bird"),
AIMessage(
"", tool_calls=[ToolCall(id="foo", name="bar", args={"arg1": "arg1"})]
"",
tool_calls=[
ToolCall(id="foo", name="bar", args={"arg1": "arg1"}, type="tool_call")
],
),
AIMessage(
"",
@ -362,7 +368,10 @@ def test_get_num_tokens_from_messages() -> None:
},
),
AIMessage(
"text", tool_calls=[ToolCall(id="foo", name="bar", args={"arg1": "arg1"})]
"text",
tool_calls=[
ToolCall(id="foo", name="bar", args={"arg1": "arg1"}, type="tool_call")
],
),
ToolMessage("foobar", tool_call_id="foo"),
]