From d1a2e194c376f241116bf8e520f1a9bb297cdf3a Mon Sep 17 00:00:00 2001 From: BeatrixCohere <128378696+BeatrixCohere@users.noreply.github.com> Date: Thu, 28 Mar 2024 17:19:38 +0000 Subject: [PATCH] cohere[patch]: misc fixes to tool use agent and cohere chat (#19705) Bug fixes in this PR: * allows for other params such as "message" not just the input param to the prompt for the cohere tools agent * fixes to documents kwarg from messages * fixes to tool_calls API call --------- Co-authored-by: Harry M <127103098+harry-cohere@users.noreply.github.com> --- libs/partners/cohere/docs/cohere_agent.ipynb | 237 ++++++++++++++++++ .../cohere/langchain_cohere/chat_models.py | 29 +-- .../cohere/langchain_cohere/cohere_agent.py | 45 +++- .../tests/unit_tests/test_cohere_agent.py | 62 ++++- 4 files changed, 342 insertions(+), 31 deletions(-) create mode 100644 libs/partners/cohere/docs/cohere_agent.ipynb diff --git a/libs/partners/cohere/docs/cohere_agent.ipynb b/libs/partners/cohere/docs/cohere_agent.ipynb new file mode 100644 index 0000000000..c2e8c78f74 --- /dev/null +++ b/libs/partners/cohere/docs/cohere_agent.ipynb @@ -0,0 +1,237 @@ +{ + "cells": [ + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "---\n", + "sidebar_position: 0\n", + "---" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Cohere Tools\n", + "\n", + "The following notebook goes over how to use the Cohere tools agent:" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Prerequisites for this notebook:" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: langchain in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (0.1.13)\n", + "Requirement already satisfied: langchain-cohere in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (0.1.0rc2)\n", + "Requirement already satisfied: 
PyYAML>=5.3 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain) (6.0.1)\n", + "Requirement already satisfied: SQLAlchemy<3,>=1.4 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain) (2.0.27)\n", + "Requirement already satisfied: aiohttp<4.0.0,>=3.8.3 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain) (3.9.3)\n", + "Requirement already satisfied: dataclasses-json<0.7,>=0.5.7 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain) (0.6.4)\n", + "Requirement already satisfied: jsonpatch<2.0,>=1.33 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain) (1.33)\n", + "Requirement already satisfied: langchain-community<0.1,>=0.0.29 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain) (0.0.29)\n", + "Requirement already satisfied: langchain-core<0.2.0,>=0.1.33 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain) (0.1.35)\n", + "Requirement already satisfied: langchain-text-splitters<0.1,>=0.0.1 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain) (0.0.1)\n", + "Requirement already satisfied: langsmith<0.2.0,>=0.1.17 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain) (0.1.31)\n", + "Requirement already satisfied: numpy<2,>=1 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain) (1.24.4)\n", + "Requirement already satisfied: pydantic<3,>=1 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain) (2.6.4)\n", + "Requirement already satisfied: requests<3,>=2 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain) (2.31.0)\n", + "Requirement already satisfied: tenacity<9.0.0,>=8.1.0 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain) 
(8.2.3)\n", + "Requirement already satisfied: cohere<6.0.0,>=5.1.4 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain-cohere) (5.1.4)\n", + "Requirement already satisfied: aiosignal>=1.1.2 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (1.3.1)\n", + "Requirement already satisfied: attrs>=17.3.0 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (23.2.0)\n", + "Requirement already satisfied: frozenlist>=1.1.1 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (1.4.1)\n", + "Requirement already satisfied: multidict<7.0,>=4.5 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (6.0.5)\n", + "Requirement already satisfied: yarl<2.0,>=1.0 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (1.9.4)\n", + "Requirement already satisfied: httpx>=0.21.2 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from cohere<6.0.0,>=5.1.4->langchain-cohere) (0.27.0)\n", + "Requirement already satisfied: typing_extensions>=4.0.0 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from cohere<6.0.0,>=5.1.4->langchain-cohere) (4.10.0)\n", + "Requirement already satisfied: marshmallow<4.0.0,>=3.18.0 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from dataclasses-json<0.7,>=0.5.7->langchain) (3.20.2)\n", + "Requirement already satisfied: typing-inspect<1,>=0.4.0 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from dataclasses-json<0.7,>=0.5.7->langchain) (0.9.0)\n", + "Requirement already satisfied: jsonpointer>=1.9 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from jsonpatch<2.0,>=1.33->langchain) (2.4)\n", + "Requirement already satisfied: 
packaging<24.0,>=23.2 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain-core<0.2.0,>=0.1.33->langchain) (23.2)\n", + "Requirement already satisfied: orjson<4.0.0,>=3.9.14 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langsmith<0.2.0,>=0.1.17->langchain) (3.9.15)\n", + "Requirement already satisfied: annotated-types>=0.4.0 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from pydantic<3,>=1->langchain) (0.6.0)\n", + "Requirement already satisfied: pydantic-core==2.16.3 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from pydantic<3,>=1->langchain) (2.16.3)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from requests<3,>=2->langchain) (3.3.2)\n", + "Requirement already satisfied: idna<4,>=2.5 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from requests<3,>=2->langchain) (3.6)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from requests<3,>=2->langchain) (2.2.1)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from requests<3,>=2->langchain) (2024.2.2)\n", + "Requirement already satisfied: anyio in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from httpx>=0.21.2->cohere<6.0.0,>=5.1.4->langchain-cohere) (4.3.0)\n", + "Requirement already satisfied: httpcore==1.* in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from httpx>=0.21.2->cohere<6.0.0,>=5.1.4->langchain-cohere) (1.0.4)\n", + "Requirement already satisfied: sniffio in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from httpx>=0.21.2->cohere<6.0.0,>=5.1.4->langchain-cohere) (1.3.1)\n", + "Requirement already satisfied: h11<0.15,>=0.13 in 
/Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from httpcore==1.*->httpx>=0.21.2->cohere<6.0.0,>=5.1.4->langchain-cohere) (0.14.0)\n", + "Requirement already satisfied: mypy-extensions>=0.3.0 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from typing-inspect<1,>=0.4.0->dataclasses-json<0.7,>=0.5.7->langchain) (1.0.0)\n", + "Requirement already satisfied: wikipedia in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (1.4.0)\n", + "Requirement already satisfied: beautifulsoup4 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from wikipedia) (4.12.3)\n", + "Requirement already satisfied: requests<3.0.0,>=2.0.0 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from wikipedia) (2.31.0)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from requests<3.0.0,>=2.0.0->wikipedia) (3.3.2)\n", + "Requirement already satisfied: idna<4,>=2.5 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from requests<3.0.0,>=2.0.0->wikipedia) (3.6)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from requests<3.0.0,>=2.0.0->wikipedia) (2.2.1)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from requests<3.0.0,>=2.0.0->wikipedia) (2024.2.2)\n", + "Requirement already satisfied: soupsieve>1.2 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from beautifulsoup4->wikipedia) (2.5)\n" + ] + } + ], + "source": [ + "# install package\n", + "!pip install langchain langchain-cohere\n", + "!pip install wikipedia" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.agents import AgentExecutor\n", + "from langchain.retrievers import 
WikipediaRetriever\n", + "from langchain.tools.retriever import create_retriever_tool\n", + "from langchain_cohere import create_cohere_tools_agent\n", + "from langchain_cohere.chat_models import ChatCohere\n", + "from langchain_core.prompts import ChatPromptTemplate" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next we create the prompt template and cohere model" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "# Create the prompt\n", + "prompt = ChatPromptTemplate.from_template(\n", + " \"Write all output in capital letters. {input}\"\n", + ")\n", + "\n", + "# Create the Cohere chat model\n", + "chat = ChatCohere(cohere_api_key=\"API_KEY\", model=\"command-r\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this example we use a Wikipedia retrieval tool " + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "retriever = WikipediaRetriever()\n", + "retriever_tool = create_retriever_tool(\n", + " retriever,\n", + " \"wikipedia\",\n", + " \"Search for information on Wikipedia\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, create the cohere tool agent and call with the input" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", + "\u001b[32;1m\u001b[1;3mwikipedia\u001b[0m\u001b[36;1m\u001b[1;3m\u001b[0m\u001b[32;1m\u001b[1;3m\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "{'input': 'Who founded Cohere?',\n", + " 'text': 'COHERE WAS FOUNDED BY AIDAN GOMEZ, IVAN ZAPATA, AND ALON GELLA.',\n", + " 'additional_info': {'documents': [{'answer': '',\n", + " 'id': 'wikipedia:0:0',\n", + " 'tool_name': 
'wikipedia'}],\n", + " 'citations': [ChatCitation(start=22, end=63, text='AIDAN GOMEZ, IVAN ZAPATA, AND ALON GELLA.', document_ids=['wikipedia:0:0'])],\n", + " 'search_results': None,\n", + " 'search_queries': None,\n", + " 'is_search_required': None,\n", + " 'generation_id': '3b7e96be-8aad-4fa0-9ae3-7a38e800c289',\n", + " 'token_count': {'prompt_tokens': 740,\n", + " 'response_tokens': 27,\n", + " 'total_tokens': 767,\n", + " 'billed_tokens': 48}}}" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "agent = create_cohere_tools_agent(\n", + " llm=chat,\n", + " tools=[retriever_tool],\n", + " prompt=prompt,\n", + ")\n", + "agent_executor = AgentExecutor(agent=agent, tools=[retriever_tool], verbose=True)\n", + "agent_executor.invoke({\"input\": \"Who founded Cohere?\"})" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.7" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/libs/partners/cohere/langchain_cohere/chat_models.py b/libs/partners/cohere/langchain_cohere/chat_models.py index 8babfc3e16..004973fb93 100644 --- a/libs/partners/cohere/langchain_cohere/chat_models.py +++ b/libs/partners/cohere/langchain_cohere/chat_models.py @@ -81,25 +81,22 @@ def get_cohere_chat_request( additional_kwargs = messages[-1].additional_kwargs # cohere SDK will fail loudly if both connectors and documents are provided - if ( - len(additional_kwargs.get("documents", [])) > 0 - and documents - and len(documents) > 0 - ): + if additional_kwargs.get("documents", []) and 
documents and len(documents) > 0: raise ValueError( - "Received documents both as a keyword argument and as an prompt additional" - "keywword argument. Please choose only one option." + "Received documents both as a keyword argument and as an prompt additional keyword argument. Please choose only one option." # noqa: E501 ) - formatted_docs = [ - { - "text": doc.page_content, - "id": doc.metadata.get("id") or f"doc-{str(i)}", - } - for i, doc in enumerate(additional_kwargs.get("documents", [])) - ] or documents - if not formatted_docs: - formatted_docs = None + formatted_docs: Optional[List[Dict[str, Any]]] = None + if additional_kwargs.get("documents"): + formatted_docs = [ + { + "text": doc.page_content, + "id": doc.metadata.get("id") or f"doc-{str(i)}", + } + for i, doc in enumerate(additional_kwargs.get("documents", [])) + ] + elif documents: + formatted_docs = documents # by enabling automatic prompt truncation, the probability of request failure is # reduced with minimal impact on response quality diff --git a/libs/partners/cohere/langchain_cohere/cohere_agent.py b/libs/partners/cohere/langchain_cohere/cohere_agent.py index 5bf8328e8c..b9c60fb743 100644 --- a/libs/partners/cohere/langchain_cohere/cohere_agent.py +++ b/libs/partners/cohere/langchain_cohere/cohere_agent.py @@ -1,6 +1,12 @@ +import json from typing import Any, Dict, List, Sequence, Tuple, Type, Union -from cohere.types import Tool, ToolParameterDefinitionsValue +from cohere.types import ( + ChatRequestToolResultsItem, + Tool, + ToolCall, + ToolParameterDefinitionsValue, +) from langchain_core.agents import AgentAction, AgentFinish from langchain_core.language_models import BaseLanguageModel from langchain_core.output_parsers import BaseOutputParser @@ -30,9 +36,7 @@ def create_cohere_tools_agent( RunnablePassthrough.assign( # Intermediate steps are in tool results. # Edit below to change the prompt parameters. 
- input=lambda x: prompt.format_messages( - input=x["input"], agent_scratchpad=[] - ), + input=lambda x: prompt.format_messages(**x, agent_scratchpad=[]), tools=lambda x: _format_to_cohere_tools(tools), tool_results=lambda x: _format_to_cohere_tools_messages( x["intermediate_steps"] @@ -52,20 +56,35 @@ def _format_to_cohere_tools( def _format_to_cohere_tools_messages( intermediate_steps: Sequence[Tuple[AgentAction, str]], -) -> list: +) -> List[Dict[str, Any]]: """Convert (AgentAction, tool output) tuples into tool messages.""" if len(intermediate_steps) == 0: return [] tool_results = [] for agent_action, observation in intermediate_steps: + # agent_action.tool_input can be a dict, serialised dict, or string. + # Cohere API only accepts a dict. + tool_call_parameters: Dict[str, Any] + if isinstance(agent_action.tool_input, dict): + # tool_input is a dict, use as-is. + tool_call_parameters = agent_action.tool_input + else: + try: + # tool_input is serialised dict. + tool_call_parameters = json.loads(agent_action.tool_input) + if not isinstance(tool_call_parameters, dict): + raise ValueError() + except ValueError: + # tool_input is a string, last ditch attempt at having something useful. 
+ tool_call_parameters = {"input": agent_action.tool_input} tool_results.append( - { - "call": { - "name": agent_action.tool, - "parameters": agent_action.tool_input, - }, - "outputs": [{"answer": observation}], - } + ChatRequestToolResultsItem( + call=ToolCall( + name=agent_action.tool, + parameters=tool_call_parameters, + ), + outputs=[{"answer": observation}], + ).dict() ) return tool_results @@ -143,7 +162,7 @@ class _CohereToolsAgentOutputParser( ) -> Union[List[AgentAction], AgentFinish]: if not isinstance(result[0], ChatGeneration): raise ValueError(f"Expected ChatGeneration, got {type(result)}") - if result[0].message.additional_kwargs["tool_calls"]: + if "tool_calls" in result[0].message.additional_kwargs: actions = [] for tool in result[0].message.additional_kwargs["tool_calls"]: function = tool.get("function", {}) diff --git a/libs/partners/cohere/tests/unit_tests/test_cohere_agent.py b/libs/partners/cohere/tests/unit_tests/test_cohere_agent.py index 9dc082a55e..5ef7e42d52 100644 --- a/libs/partners/cohere/tests/unit_tests/test_cohere_agent.py +++ b/libs/partners/cohere/tests/unit_tests/test_cohere_agent.py @@ -1,9 +1,14 @@ -from typing import Any, Dict, Optional, Type, Union +import json +from typing import Any, Dict, List, Optional, Tuple, Type, Union import pytest +from langchain_core.agents import AgentAction from langchain_core.tools import BaseModel, BaseTool, Field -from langchain_cohere.cohere_agent import _format_to_cohere_tools +from langchain_cohere.cohere_agent import ( + _format_to_cohere_tools, + _format_to_cohere_tools_messages, +) expected_test_tool_definition = { "description": "test_tool description", @@ -80,3 +85,56 @@ def test_format_to_cohere_tools( actual = _format_to_cohere_tools([tool]) assert [expected_test_tool_definition] == actual + + +@pytest.mark.parametrize( + "intermediate_step,expected", + [ + pytest.param( + ( + AgentAction(tool="tool_name", tool_input={"arg1": "value1"}, log=""), + "result", + ), + { + "call": {"name": 
"tool_name", "parameters": {"arg1": "value1"}}, + "outputs": [{"answer": "result"}], + }, + id="tool_input as dict", + ), + pytest.param( + ( + AgentAction( + tool="tool_name", tool_input=json.dumps({"arg1": "value1"}), log="" + ), + "result", + ), + { + "call": {"name": "tool_name", "parameters": {"arg1": "value1"}}, + "outputs": [{"answer": "result"}], + }, + id="tool_input as serialized dict", + ), + pytest.param( + (AgentAction(tool="tool_name", tool_input="foo", log=""), "result"), + { + "call": {"name": "tool_name", "parameters": {"input": "foo"}}, + "outputs": [{"answer": "result"}], + }, + id="tool_input as string", + ), + pytest.param( + (AgentAction(tool="tool_name", tool_input="['foo']", log=""), "result"), + { + "call": {"name": "tool_name", "parameters": {"input": "['foo']"}}, + "outputs": [{"answer": "result"}], + }, + id="tool_input unrelated JSON", + ), + ], +) +def test_format_to_cohere_tools_messages( + intermediate_step: Tuple[AgentAction, str], expected: List[Dict[str, Any]] +) -> None: + actual = _format_to_cohere_tools_messages(intermediate_steps=[intermediate_step]) + + assert [expected] == actual