mirror of https://github.com/hwchase17/langchain
cohere[patch]: Add cohere tools agent (#19602)
**Description**: Adds a cohere tools agent and related notebook. --------- Co-authored-by: BeatrixCohere <128378696+BeatrixCohere@users.noreply.github.com> Co-authored-by: Erick Friis <erick@langchain.dev>pull/19689/head
parent
5c41f4083e
commit
3685f8ceac
@ -0,0 +1,168 @@
|
||||
from typing import Any, Dict, List, Sequence, Tuple, Type, Union
|
||||
|
||||
from cohere.types import Tool, ToolParameterDefinitionsValue
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.output_parsers import BaseOutputParser
|
||||
from langchain_core.outputs import Generation
|
||||
from langchain_core.outputs.chat_generation import ChatGeneration
|
||||
from langchain_core.prompts.chat import ChatPromptTemplate
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.runnables import Runnable, RunnablePassthrough
|
||||
from langchain_core.runnables.base import RunnableLambda
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils.function_calling import convert_to_openai_function
|
||||
|
||||
|
||||
def create_cohere_tools_agent(
|
||||
llm: BaseLanguageModel, tools: Sequence[BaseTool], prompt: ChatPromptTemplate
|
||||
) -> Runnable:
|
||||
def llm_with_tools(input_: Dict) -> Runnable:
|
||||
tool_results = (
|
||||
input_["tool_results"] if len(input_["tool_results"]) > 0 else None
|
||||
)
|
||||
tools_ = input_["tools"] if len(input_["tools"]) > 0 else None
|
||||
return RunnableLambda(lambda x: x["input"]) | llm.bind(
|
||||
tools=tools_, tool_results=tool_results
|
||||
)
|
||||
|
||||
agent = (
|
||||
RunnablePassthrough.assign(
|
||||
# Intermediate steps are in tool results.
|
||||
# Edit below to change the prompt parameters.
|
||||
input=lambda x: prompt.format_messages(
|
||||
input=x["input"], agent_scratchpad=[]
|
||||
),
|
||||
tools=lambda x: _format_to_cohere_tools(tools),
|
||||
tool_results=lambda x: _format_to_cohere_tools_messages(
|
||||
x["intermediate_steps"]
|
||||
),
|
||||
)
|
||||
| llm_with_tools
|
||||
| _CohereToolsAgentOutputParser()
|
||||
)
|
||||
return agent
|
||||
|
||||
|
||||
def _format_to_cohere_tools(
|
||||
tools: Sequence[Union[Dict[str, Any], BaseTool, Type[BaseModel]]],
|
||||
) -> List[Dict[str, Any]]:
|
||||
return [_convert_to_cohere_tool(tool) for tool in tools]
|
||||
|
||||
|
||||
def _format_to_cohere_tools_messages(
|
||||
intermediate_steps: Sequence[Tuple[AgentAction, str]],
|
||||
) -> list:
|
||||
"""Convert (AgentAction, tool output) tuples into tool messages."""
|
||||
if len(intermediate_steps) == 0:
|
||||
return []
|
||||
tool_results = []
|
||||
for agent_action, observation in intermediate_steps:
|
||||
tool_results.append(
|
||||
{
|
||||
"call": {
|
||||
"name": agent_action.tool,
|
||||
"parameters": agent_action.tool_input,
|
||||
},
|
||||
"outputs": [{"answer": observation}],
|
||||
}
|
||||
)
|
||||
|
||||
return tool_results
|
||||
|
||||
|
||||
def _convert_to_cohere_tool(
|
||||
tool: Union[Dict[str, Any], BaseTool, Type[BaseModel]],
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert a BaseTool instance, JSON schema dict, or BaseModel type to a Cohere tool.
|
||||
"""
|
||||
if isinstance(tool, BaseTool):
|
||||
return Tool(
|
||||
name=tool.name,
|
||||
description=tool.description,
|
||||
parameter_definitions={
|
||||
param_name: ToolParameterDefinitionsValue(
|
||||
description=param_definition.get("description"),
|
||||
type=param_definition.get("type"),
|
||||
required="default" not in param_definition,
|
||||
)
|
||||
for param_name, param_definition in tool.args.items()
|
||||
},
|
||||
).dict()
|
||||
elif isinstance(tool, dict):
|
||||
if not all(k in tool for k in ("title", "description", "properties")):
|
||||
raise ValueError(
|
||||
"Unsupported dict type. Tool must be passed in as a BaseTool instance, JSON schema dict, or BaseModel type." # noqa: E501
|
||||
)
|
||||
return Tool(
|
||||
name=tool.get("title"),
|
||||
description=tool.get("description"),
|
||||
parameter_definitions={
|
||||
param_name: ToolParameterDefinitionsValue(
|
||||
description=param_definition.get("description"),
|
||||
type=param_definition.get("type"),
|
||||
required="default" not in param_definition,
|
||||
)
|
||||
for param_name, param_definition in tool.get("properties", {}).items()
|
||||
},
|
||||
).dict()
|
||||
elif issubclass(tool, BaseModel):
|
||||
as_json_schema_function = convert_to_openai_function(tool)
|
||||
parameters = as_json_schema_function.get("parameters", {})
|
||||
properties = parameters.get("properties", {})
|
||||
return Tool(
|
||||
name=as_json_schema_function.get("name"),
|
||||
description=as_json_schema_function.get(
|
||||
# The Cohere API requires the description field.
|
||||
"description",
|
||||
as_json_schema_function.get("name"),
|
||||
),
|
||||
parameter_definitions={
|
||||
param_name: ToolParameterDefinitionsValue(
|
||||
description=param_definition.get("description"),
|
||||
type=param_definition.get("type"),
|
||||
required=param_name in parameters.get("required", []),
|
||||
)
|
||||
for param_name, param_definition in properties.items()
|
||||
},
|
||||
).dict()
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported tool type {type(tool)}. Tool must be passed in as a BaseTool instance, JSON schema dict, or BaseModel type." # noqa: E501
|
||||
)
|
||||
|
||||
|
||||
class _CohereToolsAgentOutputParser(
|
||||
BaseOutputParser[Union[List[AgentAction], AgentFinish]]
|
||||
):
|
||||
"""Parses a message into agent actions/finish."""
|
||||
|
||||
def parse_result(
|
||||
self, result: List[Generation], *, partial: bool = False
|
||||
) -> Union[List[AgentAction], AgentFinish]:
|
||||
if not isinstance(result[0], ChatGeneration):
|
||||
raise ValueError(f"Expected ChatGeneration, got {type(result)}")
|
||||
if result[0].message.additional_kwargs["tool_calls"]:
|
||||
actions = []
|
||||
for tool in result[0].message.additional_kwargs["tool_calls"]:
|
||||
function = tool.get("function", {})
|
||||
actions.append(
|
||||
AgentAction(
|
||||
tool=function.get("name"),
|
||||
tool_input=function.get("arguments"),
|
||||
log=function.get("name"),
|
||||
)
|
||||
)
|
||||
return actions
|
||||
else:
|
||||
return AgentFinish(
|
||||
return_values={
|
||||
"text": result[0].message.content,
|
||||
"additional_info": result[0].message.additional_kwargs,
|
||||
},
|
||||
log="",
|
||||
)
|
||||
|
||||
def parse(self, text: str) -> Union[List[AgentAction], AgentFinish]:
|
||||
raise ValueError("Can only parse messages")
|
@ -0,0 +1,82 @@
|
||||
from typing import Any, Dict, Optional, Type, Union
|
||||
|
||||
import pytest
|
||||
from langchain_core.tools import BaseModel, BaseTool, Field
|
||||
|
||||
from langchain_cohere.cohere_agent import _format_to_cohere_tools
|
||||
|
||||
expected_test_tool_definition = {
|
||||
"description": "test_tool description",
|
||||
"name": "test_tool",
|
||||
"parameter_definitions": {
|
||||
"arg_1": {
|
||||
"description": "Arg1 description",
|
||||
"required": True,
|
||||
"type": "string",
|
||||
},
|
||||
"optional_arg_2": {
|
||||
"description": "Arg2 description",
|
||||
"required": False,
|
||||
"type": "string",
|
||||
},
|
||||
"arg_3": {
|
||||
"description": "Arg3 description",
|
||||
"required": True,
|
||||
"type": "integer",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class _TestToolSchema(BaseModel):
|
||||
arg_1: str = Field(description="Arg1 description")
|
||||
optional_arg_2: Optional[str] = Field(description="Arg2 description", default="2")
|
||||
arg_3: int = Field(description="Arg3 description")
|
||||
|
||||
|
||||
class _TestTool(BaseTool):
|
||||
name = "test_tool"
|
||||
description = "test_tool description"
|
||||
args_schema: Type[_TestToolSchema] = _TestToolSchema
|
||||
|
||||
def _run(self, *args: Any, **kwargs: Any) -> Any:
|
||||
pass
|
||||
|
||||
|
||||
class test_tool(BaseModel):
|
||||
"""test_tool description"""
|
||||
|
||||
arg_1: str = Field(description="Arg1 description")
|
||||
optional_arg_2: Optional[str] = Field(description="Arg2 description", default="2")
|
||||
arg_3: int = Field(description="Arg3 description")
|
||||
|
||||
|
||||
test_tool_as_dict = {
|
||||
"title": "test_tool",
|
||||
"description": "test_tool description",
|
||||
"properties": {
|
||||
"arg_1": {"description": "Arg1 description", "type": "string"},
|
||||
"optional_arg_2": {
|
||||
"description": "Arg2 description",
|
||||
"type": "string",
|
||||
"default": "2",
|
||||
},
|
||||
"arg_3": {"description": "Arg3 description", "type": "integer"},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"tool",
|
||||
[
|
||||
pytest.param(_TestTool(), id="tool from BaseTool"),
|
||||
pytest.param(test_tool, id="BaseModel"),
|
||||
pytest.param(test_tool_as_dict, id="JSON schema dict"),
|
||||
],
|
||||
)
|
||||
def test_format_to_cohere_tools(
|
||||
tool: Union[Dict[str, Any], BaseTool, Type[BaseModel]],
|
||||
) -> None:
|
||||
actual = _format_to_cohere_tools([tool])
|
||||
|
||||
assert [expected_test_tool_definition] == actual
|
Loading…
Reference in New Issue