cohere[patch]: misc fixs tool use agent and cohere chat (#19705)

Bug fixes in this PR:
* allows for other params such as "message" not just the input param to
the prompt for the cohere tools agent
* fixes to documents kwarg from messages
* fixes to tool_calls API call

---------

Co-authored-by: Harry M <127103098+harry-cohere@users.noreply.github.com>
This commit is contained in:
BeatrixCohere 2024-03-28 17:19:38 +00:00 committed by GitHub
parent b35e68c41f
commit d1a2e194c3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 342 additions and 31 deletions

View File

@ -0,0 +1,237 @@
{
"cells": [
{
"cell_type": "raw",
"metadata": {},
"source": [
"---\n",
"sidebar_position: 0\n",
"---"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Cohere Tools\n",
"\n",
"The following notebook goes over how to use the Cohere tools agent:"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Prerequisites for this notebook:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: langchain in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (0.1.13)\n",
"Requirement already satisfied: langchain-cohere in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (0.1.0rc2)\n",
"Requirement already satisfied: PyYAML>=5.3 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain) (6.0.1)\n",
"Requirement already satisfied: SQLAlchemy<3,>=1.4 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain) (2.0.27)\n",
"Requirement already satisfied: aiohttp<4.0.0,>=3.8.3 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain) (3.9.3)\n",
"Requirement already satisfied: dataclasses-json<0.7,>=0.5.7 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain) (0.6.4)\n",
"Requirement already satisfied: jsonpatch<2.0,>=1.33 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain) (1.33)\n",
"Requirement already satisfied: langchain-community<0.1,>=0.0.29 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain) (0.0.29)\n",
"Requirement already satisfied: langchain-core<0.2.0,>=0.1.33 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain) (0.1.35)\n",
"Requirement already satisfied: langchain-text-splitters<0.1,>=0.0.1 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain) (0.0.1)\n",
"Requirement already satisfied: langsmith<0.2.0,>=0.1.17 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain) (0.1.31)\n",
"Requirement already satisfied: numpy<2,>=1 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain) (1.24.4)\n",
"Requirement already satisfied: pydantic<3,>=1 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain) (2.6.4)\n",
"Requirement already satisfied: requests<3,>=2 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain) (2.31.0)\n",
"Requirement already satisfied: tenacity<9.0.0,>=8.1.0 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain) (8.2.3)\n",
"Requirement already satisfied: cohere<6.0.0,>=5.1.4 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain-cohere) (5.1.4)\n",
"Requirement already satisfied: aiosignal>=1.1.2 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (1.3.1)\n",
"Requirement already satisfied: attrs>=17.3.0 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (23.2.0)\n",
"Requirement already satisfied: frozenlist>=1.1.1 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (1.4.1)\n",
"Requirement already satisfied: multidict<7.0,>=4.5 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (6.0.5)\n",
"Requirement already satisfied: yarl<2.0,>=1.0 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (1.9.4)\n",
"Requirement already satisfied: httpx>=0.21.2 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from cohere<6.0.0,>=5.1.4->langchain-cohere) (0.27.0)\n",
"Requirement already satisfied: typing_extensions>=4.0.0 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from cohere<6.0.0,>=5.1.4->langchain-cohere) (4.10.0)\n",
"Requirement already satisfied: marshmallow<4.0.0,>=3.18.0 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from dataclasses-json<0.7,>=0.5.7->langchain) (3.20.2)\n",
"Requirement already satisfied: typing-inspect<1,>=0.4.0 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from dataclasses-json<0.7,>=0.5.7->langchain) (0.9.0)\n",
"Requirement already satisfied: jsonpointer>=1.9 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from jsonpatch<2.0,>=1.33->langchain) (2.4)\n",
"Requirement already satisfied: packaging<24.0,>=23.2 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain-core<0.2.0,>=0.1.33->langchain) (23.2)\n",
"Requirement already satisfied: orjson<4.0.0,>=3.9.14 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langsmith<0.2.0,>=0.1.17->langchain) (3.9.15)\n",
"Requirement already satisfied: annotated-types>=0.4.0 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from pydantic<3,>=1->langchain) (0.6.0)\n",
"Requirement already satisfied: pydantic-core==2.16.3 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from pydantic<3,>=1->langchain) (2.16.3)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from requests<3,>=2->langchain) (3.3.2)\n",
"Requirement already satisfied: idna<4,>=2.5 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from requests<3,>=2->langchain) (3.6)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from requests<3,>=2->langchain) (2.2.1)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from requests<3,>=2->langchain) (2024.2.2)\n",
"Requirement already satisfied: anyio in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from httpx>=0.21.2->cohere<6.0.0,>=5.1.4->langchain-cohere) (4.3.0)\n",
"Requirement already satisfied: httpcore==1.* in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from httpx>=0.21.2->cohere<6.0.0,>=5.1.4->langchain-cohere) (1.0.4)\n",
"Requirement already satisfied: sniffio in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from httpx>=0.21.2->cohere<6.0.0,>=5.1.4->langchain-cohere) (1.3.1)\n",
"Requirement already satisfied: h11<0.15,>=0.13 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from httpcore==1.*->httpx>=0.21.2->cohere<6.0.0,>=5.1.4->langchain-cohere) (0.14.0)\n",
"Requirement already satisfied: mypy-extensions>=0.3.0 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from typing-inspect<1,>=0.4.0->dataclasses-json<0.7,>=0.5.7->langchain) (1.0.0)\n",
"Requirement already satisfied: wikipedia in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (1.4.0)\n",
"Requirement already satisfied: beautifulsoup4 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from wikipedia) (4.12.3)\n",
"Requirement already satisfied: requests<3.0.0,>=2.0.0 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from wikipedia) (2.31.0)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from requests<3.0.0,>=2.0.0->wikipedia) (3.3.2)\n",
"Requirement already satisfied: idna<4,>=2.5 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from requests<3.0.0,>=2.0.0->wikipedia) (3.6)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from requests<3.0.0,>=2.0.0->wikipedia) (2.2.1)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from requests<3.0.0,>=2.0.0->wikipedia) (2024.2.2)\n",
"Requirement already satisfied: soupsieve>1.2 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from beautifulsoup4->wikipedia) (2.5)\n"
]
}
],
"source": [
"# install package\n",
"!pip install langchain langchain-cohere\n",
"!pip install wikipedia"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from langchain.agents import AgentExecutor\n",
"from langchain.retrievers import WikipediaRetriever\n",
"from langchain.tools.retriever import create_retriever_tool\n",
"from langchain_cohere import create_cohere_tools_agent\n",
"from langchain_cohere.chat_models import ChatCohere\n",
"from langchain_core.prompts import ChatPromptTemplate"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next we create the prompt template and cohere model"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Create the prompt\n",
"prompt = ChatPromptTemplate.from_template(\n",
" \"Write all output in capital letters. {input}\"\n",
")\n",
"\n",
"# Create the Cohere chat model\n",
"chat = ChatCohere(cohere_api_key=\"API_KEY\", model=\"command-r\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this example we use a Wikipedia retrieval tool "
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"retriever = WikipediaRetriever()\n",
"retriever_tool = create_retriever_tool(\n",
" retriever,\n",
" \"wikipedia\",\n",
" \"Search for information on Wikipedia\",\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, create the cohere tool agent and call with the input"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mwikipedia\u001b[0m\u001b[36;1m\u001b[1;3m\u001b[0m\u001b[32;1m\u001b[1;3m\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"{'input': 'Who founded Cohere?',\n",
" 'text': 'COHERE WAS FOUNDED BY AIDAN GOMEZ, IVAN ZAPATA, AND ALON GELLA.',\n",
" 'additional_info': {'documents': [{'answer': '',\n",
" 'id': 'wikipedia:0:0',\n",
" 'tool_name': 'wikipedia'}],\n",
" 'citations': [ChatCitation(start=22, end=63, text='AIDAN GOMEZ, IVAN ZAPATA, AND ALON GELLA.', document_ids=['wikipedia:0:0'])],\n",
" 'search_results': None,\n",
" 'search_queries': None,\n",
" 'is_search_required': None,\n",
" 'generation_id': '3b7e96be-8aad-4fa0-9ae3-7a38e800c289',\n",
" 'token_count': {'prompt_tokens': 740,\n",
" 'response_tokens': 27,\n",
" 'total_tokens': 767,\n",
" 'billed_tokens': 48}}}"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"agent = create_cohere_tools_agent(\n",
" llm=chat,\n",
" tools=[retriever_tool],\n",
" prompt=prompt,\n",
")\n",
"agent_executor = AgentExecutor(agent=agent, tools=[retriever_tool], verbose=True)\n",
"agent_executor.invoke({\"input\": \"Who founded Cohere?\"})"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

View File

@ -81,25 +81,22 @@ def get_cohere_chat_request(
additional_kwargs = messages[-1].additional_kwargs additional_kwargs = messages[-1].additional_kwargs
# cohere SDK will fail loudly if both connectors and documents are provided # cohere SDK will fail loudly if both connectors and documents are provided
if ( if additional_kwargs.get("documents", []) and documents and len(documents) > 0:
len(additional_kwargs.get("documents", [])) > 0
and documents
and len(documents) > 0
):
raise ValueError( raise ValueError(
"Received documents both as a keyword argument and as an prompt additional" "Received documents both as a keyword argument and as an prompt additional keyword argument. Please choose only one option." # noqa: E501
"keywword argument. Please choose only one option."
) )
formatted_docs = [ formatted_docs: Optional[List[Dict[str, Any]]] = None
{ if additional_kwargs.get("documents"):
"text": doc.page_content, formatted_docs = [
"id": doc.metadata.get("id") or f"doc-{str(i)}", {
} "text": doc.page_content,
for i, doc in enumerate(additional_kwargs.get("documents", [])) "id": doc.metadata.get("id") or f"doc-{str(i)}",
] or documents }
if not formatted_docs: for i, doc in enumerate(additional_kwargs.get("documents", []))
formatted_docs = None ]
elif documents:
formatted_docs = documents
# by enabling automatic prompt truncation, the probability of request failure is # by enabling automatic prompt truncation, the probability of request failure is
# reduced with minimal impact on response quality # reduced with minimal impact on response quality

View File

@ -1,6 +1,12 @@
import json
from typing import Any, Dict, List, Sequence, Tuple, Type, Union from typing import Any, Dict, List, Sequence, Tuple, Type, Union
from cohere.types import Tool, ToolParameterDefinitionsValue from cohere.types import (
ChatRequestToolResultsItem,
Tool,
ToolCall,
ToolParameterDefinitionsValue,
)
from langchain_core.agents import AgentAction, AgentFinish from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.language_models import BaseLanguageModel from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import BaseOutputParser from langchain_core.output_parsers import BaseOutputParser
@ -30,9 +36,7 @@ def create_cohere_tools_agent(
RunnablePassthrough.assign( RunnablePassthrough.assign(
# Intermediate steps are in tool results. # Intermediate steps are in tool results.
# Edit below to change the prompt parameters. # Edit below to change the prompt parameters.
input=lambda x: prompt.format_messages( input=lambda x: prompt.format_messages(**x, agent_scratchpad=[]),
input=x["input"], agent_scratchpad=[]
),
tools=lambda x: _format_to_cohere_tools(tools), tools=lambda x: _format_to_cohere_tools(tools),
tool_results=lambda x: _format_to_cohere_tools_messages( tool_results=lambda x: _format_to_cohere_tools_messages(
x["intermediate_steps"] x["intermediate_steps"]
@ -52,20 +56,35 @@ def _format_to_cohere_tools(
def _format_to_cohere_tools_messages( def _format_to_cohere_tools_messages(
intermediate_steps: Sequence[Tuple[AgentAction, str]], intermediate_steps: Sequence[Tuple[AgentAction, str]],
) -> list: ) -> List[Dict[str, Any]]:
"""Convert (AgentAction, tool output) tuples into tool messages.""" """Convert (AgentAction, tool output) tuples into tool messages."""
if len(intermediate_steps) == 0: if len(intermediate_steps) == 0:
return [] return []
tool_results = [] tool_results = []
for agent_action, observation in intermediate_steps: for agent_action, observation in intermediate_steps:
# agent_action.tool_input can be a dict, serialised dict, or string.
# Cohere API only accepts a dict.
tool_call_parameters: Dict[str, Any]
if isinstance(agent_action.tool_input, dict):
# tool_input is a dict, use as-is.
tool_call_parameters = agent_action.tool_input
else:
try:
# tool_input is serialised dict.
tool_call_parameters = json.loads(agent_action.tool_input)
if not isinstance(tool_call_parameters, dict):
raise ValueError()
except ValueError:
# tool_input is a string, last ditch attempt at having something useful.
tool_call_parameters = {"input": agent_action.tool_input}
tool_results.append( tool_results.append(
{ ChatRequestToolResultsItem(
"call": { call=ToolCall(
"name": agent_action.tool, name=agent_action.tool,
"parameters": agent_action.tool_input, parameters=tool_call_parameters,
}, ),
"outputs": [{"answer": observation}], outputs=[{"answer": observation}],
} ).dict()
) )
return tool_results return tool_results
@ -143,7 +162,7 @@ class _CohereToolsAgentOutputParser(
) -> Union[List[AgentAction], AgentFinish]: ) -> Union[List[AgentAction], AgentFinish]:
if not isinstance(result[0], ChatGeneration): if not isinstance(result[0], ChatGeneration):
raise ValueError(f"Expected ChatGeneration, got {type(result)}") raise ValueError(f"Expected ChatGeneration, got {type(result)}")
if result[0].message.additional_kwargs["tool_calls"]: if "tool_calls" in result[0].message.additional_kwargs:
actions = [] actions = []
for tool in result[0].message.additional_kwargs["tool_calls"]: for tool in result[0].message.additional_kwargs["tool_calls"]:
function = tool.get("function", {}) function = tool.get("function", {})

View File

@ -1,9 +1,14 @@
from typing import Any, Dict, Optional, Type, Union import json
from typing import Any, Dict, List, Optional, Tuple, Type, Union
import pytest import pytest
from langchain_core.agents import AgentAction
from langchain_core.tools import BaseModel, BaseTool, Field from langchain_core.tools import BaseModel, BaseTool, Field
from langchain_cohere.cohere_agent import _format_to_cohere_tools from langchain_cohere.cohere_agent import (
_format_to_cohere_tools,
_format_to_cohere_tools_messages,
)
expected_test_tool_definition = { expected_test_tool_definition = {
"description": "test_tool description", "description": "test_tool description",
@ -80,3 +85,56 @@ def test_format_to_cohere_tools(
actual = _format_to_cohere_tools([tool]) actual = _format_to_cohere_tools([tool])
assert [expected_test_tool_definition] == actual assert [expected_test_tool_definition] == actual
@pytest.mark.parametrize(
"intermediate_step,expected",
[
pytest.param(
(
AgentAction(tool="tool_name", tool_input={"arg1": "value1"}, log=""),
"result",
),
{
"call": {"name": "tool_name", "parameters": {"arg1": "value1"}},
"outputs": [{"answer": "result"}],
},
id="tool_input as dict",
),
pytest.param(
(
AgentAction(
tool="tool_name", tool_input=json.dumps({"arg1": "value1"}), log=""
),
"result",
),
{
"call": {"name": "tool_name", "parameters": {"arg1": "value1"}},
"outputs": [{"answer": "result"}],
},
id="tool_input as serialized dict",
),
pytest.param(
(AgentAction(tool="tool_name", tool_input="foo", log=""), "result"),
{
"call": {"name": "tool_name", "parameters": {"input": "foo"}},
"outputs": [{"answer": "result"}],
},
id="tool_input as string",
),
pytest.param(
(AgentAction(tool="tool_name", tool_input="['foo']", log=""), "result"),
{
"call": {"name": "tool_name", "parameters": {"input": "['foo']"}},
"outputs": [{"answer": "result"}],
},
id="tool_input unrelated JSON",
),
],
)
def test_format_to_cohere_tools_messages(
intermediate_step: Tuple[AgentAction, str], expected: List[Dict[str, Any]]
) -> None:
actual = _format_to_cohere_tools_messages(intermediate_steps=[intermediate_step])
assert [expected] == actual