mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
cohere[patch]: misc fixs tool use agent and cohere chat (#19705)
Bug fixes in this PR: * allows for other params such as "message" not just the input param to the prompt for the cohere tools agent * fixes to documents kwarg from messages * fixes to tool_calls API call --------- Co-authored-by: Harry M <127103098+harry-cohere@users.noreply.github.com>
This commit is contained in:
parent
b35e68c41f
commit
d1a2e194c3
237
libs/partners/cohere/docs/cohere_agent.ipynb
Normal file
237
libs/partners/cohere/docs/cohere_agent.ipynb
Normal file
@ -0,0 +1,237 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "raw",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"---\n",
|
||||||
|
"sidebar_position: 0\n",
|
||||||
|
"---"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# Cohere Tools\n",
|
||||||
|
"\n",
|
||||||
|
"The following notebook goes over how to use the Cohere tools agent:"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Prerequisites for this notebook:"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Requirement already satisfied: langchain in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (0.1.13)\n",
|
||||||
|
"Requirement already satisfied: langchain-cohere in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (0.1.0rc2)\n",
|
||||||
|
"Requirement already satisfied: PyYAML>=5.3 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain) (6.0.1)\n",
|
||||||
|
"Requirement already satisfied: SQLAlchemy<3,>=1.4 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain) (2.0.27)\n",
|
||||||
|
"Requirement already satisfied: aiohttp<4.0.0,>=3.8.3 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain) (3.9.3)\n",
|
||||||
|
"Requirement already satisfied: dataclasses-json<0.7,>=0.5.7 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain) (0.6.4)\n",
|
||||||
|
"Requirement already satisfied: jsonpatch<2.0,>=1.33 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain) (1.33)\n",
|
||||||
|
"Requirement already satisfied: langchain-community<0.1,>=0.0.29 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain) (0.0.29)\n",
|
||||||
|
"Requirement already satisfied: langchain-core<0.2.0,>=0.1.33 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain) (0.1.35)\n",
|
||||||
|
"Requirement already satisfied: langchain-text-splitters<0.1,>=0.0.1 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain) (0.0.1)\n",
|
||||||
|
"Requirement already satisfied: langsmith<0.2.0,>=0.1.17 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain) (0.1.31)\n",
|
||||||
|
"Requirement already satisfied: numpy<2,>=1 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain) (1.24.4)\n",
|
||||||
|
"Requirement already satisfied: pydantic<3,>=1 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain) (2.6.4)\n",
|
||||||
|
"Requirement already satisfied: requests<3,>=2 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain) (2.31.0)\n",
|
||||||
|
"Requirement already satisfied: tenacity<9.0.0,>=8.1.0 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain) (8.2.3)\n",
|
||||||
|
"Requirement already satisfied: cohere<6.0.0,>=5.1.4 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain-cohere) (5.1.4)\n",
|
||||||
|
"Requirement already satisfied: aiosignal>=1.1.2 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (1.3.1)\n",
|
||||||
|
"Requirement already satisfied: attrs>=17.3.0 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (23.2.0)\n",
|
||||||
|
"Requirement already satisfied: frozenlist>=1.1.1 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (1.4.1)\n",
|
||||||
|
"Requirement already satisfied: multidict<7.0,>=4.5 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (6.0.5)\n",
|
||||||
|
"Requirement already satisfied: yarl<2.0,>=1.0 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (1.9.4)\n",
|
||||||
|
"Requirement already satisfied: httpx>=0.21.2 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from cohere<6.0.0,>=5.1.4->langchain-cohere) (0.27.0)\n",
|
||||||
|
"Requirement already satisfied: typing_extensions>=4.0.0 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from cohere<6.0.0,>=5.1.4->langchain-cohere) (4.10.0)\n",
|
||||||
|
"Requirement already satisfied: marshmallow<4.0.0,>=3.18.0 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from dataclasses-json<0.7,>=0.5.7->langchain) (3.20.2)\n",
|
||||||
|
"Requirement already satisfied: typing-inspect<1,>=0.4.0 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from dataclasses-json<0.7,>=0.5.7->langchain) (0.9.0)\n",
|
||||||
|
"Requirement already satisfied: jsonpointer>=1.9 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from jsonpatch<2.0,>=1.33->langchain) (2.4)\n",
|
||||||
|
"Requirement already satisfied: packaging<24.0,>=23.2 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langchain-core<0.2.0,>=0.1.33->langchain) (23.2)\n",
|
||||||
|
"Requirement already satisfied: orjson<4.0.0,>=3.9.14 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from langsmith<0.2.0,>=0.1.17->langchain) (3.9.15)\n",
|
||||||
|
"Requirement already satisfied: annotated-types>=0.4.0 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from pydantic<3,>=1->langchain) (0.6.0)\n",
|
||||||
|
"Requirement already satisfied: pydantic-core==2.16.3 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from pydantic<3,>=1->langchain) (2.16.3)\n",
|
||||||
|
"Requirement already satisfied: charset-normalizer<4,>=2 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from requests<3,>=2->langchain) (3.3.2)\n",
|
||||||
|
"Requirement already satisfied: idna<4,>=2.5 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from requests<3,>=2->langchain) (3.6)\n",
|
||||||
|
"Requirement already satisfied: urllib3<3,>=1.21.1 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from requests<3,>=2->langchain) (2.2.1)\n",
|
||||||
|
"Requirement already satisfied: certifi>=2017.4.17 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from requests<3,>=2->langchain) (2024.2.2)\n",
|
||||||
|
"Requirement already satisfied: anyio in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from httpx>=0.21.2->cohere<6.0.0,>=5.1.4->langchain-cohere) (4.3.0)\n",
|
||||||
|
"Requirement already satisfied: httpcore==1.* in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from httpx>=0.21.2->cohere<6.0.0,>=5.1.4->langchain-cohere) (1.0.4)\n",
|
||||||
|
"Requirement already satisfied: sniffio in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from httpx>=0.21.2->cohere<6.0.0,>=5.1.4->langchain-cohere) (1.3.1)\n",
|
||||||
|
"Requirement already satisfied: h11<0.15,>=0.13 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from httpcore==1.*->httpx>=0.21.2->cohere<6.0.0,>=5.1.4->langchain-cohere) (0.14.0)\n",
|
||||||
|
"Requirement already satisfied: mypy-extensions>=0.3.0 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from typing-inspect<1,>=0.4.0->dataclasses-json<0.7,>=0.5.7->langchain) (1.0.0)\n",
|
||||||
|
"Requirement already satisfied: wikipedia in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (1.4.0)\n",
|
||||||
|
"Requirement already satisfied: beautifulsoup4 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from wikipedia) (4.12.3)\n",
|
||||||
|
"Requirement already satisfied: requests<3.0.0,>=2.0.0 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from wikipedia) (2.31.0)\n",
|
||||||
|
"Requirement already satisfied: charset-normalizer<4,>=2 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from requests<3.0.0,>=2.0.0->wikipedia) (3.3.2)\n",
|
||||||
|
"Requirement already satisfied: idna<4,>=2.5 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from requests<3.0.0,>=2.0.0->wikipedia) (3.6)\n",
|
||||||
|
"Requirement already satisfied: urllib3<3,>=1.21.1 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from requests<3.0.0,>=2.0.0->wikipedia) (2.2.1)\n",
|
||||||
|
"Requirement already satisfied: certifi>=2017.4.17 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from requests<3.0.0,>=2.0.0->wikipedia) (2024.2.2)\n",
|
||||||
|
"Requirement already satisfied: soupsieve>1.2 in /Users/beatrix/Repos/langchain-1/.venv/lib/python3.11/site-packages (from beautifulsoup4->wikipedia) (2.5)\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"# install package\n",
|
||||||
|
"!pip install langchain langchain-cohere\n",
|
||||||
|
"!pip install wikipedia"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from langchain.agents import AgentExecutor\n",
|
||||||
|
"from langchain.retrievers import WikipediaRetriever\n",
|
||||||
|
"from langchain.tools.retriever import create_retriever_tool\n",
|
||||||
|
"from langchain_cohere import create_cohere_tools_agent\n",
|
||||||
|
"from langchain_cohere.chat_models import ChatCohere\n",
|
||||||
|
"from langchain_core.prompts import ChatPromptTemplate"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Next we create the prompt template and cohere model"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Create the prompt\n",
|
||||||
|
"prompt = ChatPromptTemplate.from_template(\n",
|
||||||
|
" \"Write all output in capital letters. {input}\"\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"# Create the Cohere chat model\n",
|
||||||
|
"chat = ChatCohere(cohere_api_key=\"API_KEY\", model=\"command-r\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"In this example we use a Wikipedia retrieval tool "
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 4,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"retriever = WikipediaRetriever()\n",
|
||||||
|
"retriever_tool = create_retriever_tool(\n",
|
||||||
|
" retriever,\n",
|
||||||
|
" \"wikipedia\",\n",
|
||||||
|
" \"Search for information on Wikipedia\",\n",
|
||||||
|
")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Next, create the cohere tool agent and call with the input"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 5,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
||||||
|
"\u001b[32;1m\u001b[1;3mwikipedia\u001b[0m\u001b[36;1m\u001b[1;3m\u001b[0m\u001b[32;1m\u001b[1;3m\u001b[0m\n",
|
||||||
|
"\n",
|
||||||
|
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"{'input': 'Who founded Cohere?',\n",
|
||||||
|
" 'text': 'COHERE WAS FOUNDED BY AIDAN GOMEZ, IVAN ZAPATA, AND ALON GELLA.',\n",
|
||||||
|
" 'additional_info': {'documents': [{'answer': '',\n",
|
||||||
|
" 'id': 'wikipedia:0:0',\n",
|
||||||
|
" 'tool_name': 'wikipedia'}],\n",
|
||||||
|
" 'citations': [ChatCitation(start=22, end=63, text='AIDAN GOMEZ, IVAN ZAPATA, AND ALON GELLA.', document_ids=['wikipedia:0:0'])],\n",
|
||||||
|
" 'search_results': None,\n",
|
||||||
|
" 'search_queries': None,\n",
|
||||||
|
" 'is_search_required': None,\n",
|
||||||
|
" 'generation_id': '3b7e96be-8aad-4fa0-9ae3-7a38e800c289',\n",
|
||||||
|
" 'token_count': {'prompt_tokens': 740,\n",
|
||||||
|
" 'response_tokens': 27,\n",
|
||||||
|
" 'total_tokens': 767,\n",
|
||||||
|
" 'billed_tokens': 48}}}"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 5,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"agent = create_cohere_tools_agent(\n",
|
||||||
|
" llm=chat,\n",
|
||||||
|
" tools=[retriever_tool],\n",
|
||||||
|
" prompt=prompt,\n",
|
||||||
|
")\n",
|
||||||
|
"agent_executor = AgentExecutor(agent=agent, tools=[retriever_tool], verbose=True)\n",
|
||||||
|
"agent_executor.invoke({\"input\": \"Who founded Cohere?\"})"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": []
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3 (ipykernel)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.11.7"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 4
|
||||||
|
}
|
@ -81,25 +81,22 @@ def get_cohere_chat_request(
|
|||||||
additional_kwargs = messages[-1].additional_kwargs
|
additional_kwargs = messages[-1].additional_kwargs
|
||||||
|
|
||||||
# cohere SDK will fail loudly if both connectors and documents are provided
|
# cohere SDK will fail loudly if both connectors and documents are provided
|
||||||
if (
|
if additional_kwargs.get("documents", []) and documents and len(documents) > 0:
|
||||||
len(additional_kwargs.get("documents", [])) > 0
|
|
||||||
and documents
|
|
||||||
and len(documents) > 0
|
|
||||||
):
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Received documents both as a keyword argument and as an prompt additional"
|
"Received documents both as a keyword argument and as an prompt additional keyword argument. Please choose only one option." # noqa: E501
|
||||||
"keywword argument. Please choose only one option."
|
|
||||||
)
|
)
|
||||||
|
|
||||||
formatted_docs = [
|
formatted_docs: Optional[List[Dict[str, Any]]] = None
|
||||||
{
|
if additional_kwargs.get("documents"):
|
||||||
"text": doc.page_content,
|
formatted_docs = [
|
||||||
"id": doc.metadata.get("id") or f"doc-{str(i)}",
|
{
|
||||||
}
|
"text": doc.page_content,
|
||||||
for i, doc in enumerate(additional_kwargs.get("documents", []))
|
"id": doc.metadata.get("id") or f"doc-{str(i)}",
|
||||||
] or documents
|
}
|
||||||
if not formatted_docs:
|
for i, doc in enumerate(additional_kwargs.get("documents", []))
|
||||||
formatted_docs = None
|
]
|
||||||
|
elif documents:
|
||||||
|
formatted_docs = documents
|
||||||
|
|
||||||
# by enabling automatic prompt truncation, the probability of request failure is
|
# by enabling automatic prompt truncation, the probability of request failure is
|
||||||
# reduced with minimal impact on response quality
|
# reduced with minimal impact on response quality
|
||||||
|
@ -1,6 +1,12 @@
|
|||||||
|
import json
|
||||||
from typing import Any, Dict, List, Sequence, Tuple, Type, Union
|
from typing import Any, Dict, List, Sequence, Tuple, Type, Union
|
||||||
|
|
||||||
from cohere.types import Tool, ToolParameterDefinitionsValue
|
from cohere.types import (
|
||||||
|
ChatRequestToolResultsItem,
|
||||||
|
Tool,
|
||||||
|
ToolCall,
|
||||||
|
ToolParameterDefinitionsValue,
|
||||||
|
)
|
||||||
from langchain_core.agents import AgentAction, AgentFinish
|
from langchain_core.agents import AgentAction, AgentFinish
|
||||||
from langchain_core.language_models import BaseLanguageModel
|
from langchain_core.language_models import BaseLanguageModel
|
||||||
from langchain_core.output_parsers import BaseOutputParser
|
from langchain_core.output_parsers import BaseOutputParser
|
||||||
@ -30,9 +36,7 @@ def create_cohere_tools_agent(
|
|||||||
RunnablePassthrough.assign(
|
RunnablePassthrough.assign(
|
||||||
# Intermediate steps are in tool results.
|
# Intermediate steps are in tool results.
|
||||||
# Edit below to change the prompt parameters.
|
# Edit below to change the prompt parameters.
|
||||||
input=lambda x: prompt.format_messages(
|
input=lambda x: prompt.format_messages(**x, agent_scratchpad=[]),
|
||||||
input=x["input"], agent_scratchpad=[]
|
|
||||||
),
|
|
||||||
tools=lambda x: _format_to_cohere_tools(tools),
|
tools=lambda x: _format_to_cohere_tools(tools),
|
||||||
tool_results=lambda x: _format_to_cohere_tools_messages(
|
tool_results=lambda x: _format_to_cohere_tools_messages(
|
||||||
x["intermediate_steps"]
|
x["intermediate_steps"]
|
||||||
@ -52,20 +56,35 @@ def _format_to_cohere_tools(
|
|||||||
|
|
||||||
def _format_to_cohere_tools_messages(
|
def _format_to_cohere_tools_messages(
|
||||||
intermediate_steps: Sequence[Tuple[AgentAction, str]],
|
intermediate_steps: Sequence[Tuple[AgentAction, str]],
|
||||||
) -> list:
|
) -> List[Dict[str, Any]]:
|
||||||
"""Convert (AgentAction, tool output) tuples into tool messages."""
|
"""Convert (AgentAction, tool output) tuples into tool messages."""
|
||||||
if len(intermediate_steps) == 0:
|
if len(intermediate_steps) == 0:
|
||||||
return []
|
return []
|
||||||
tool_results = []
|
tool_results = []
|
||||||
for agent_action, observation in intermediate_steps:
|
for agent_action, observation in intermediate_steps:
|
||||||
|
# agent_action.tool_input can be a dict, serialised dict, or string.
|
||||||
|
# Cohere API only accepts a dict.
|
||||||
|
tool_call_parameters: Dict[str, Any]
|
||||||
|
if isinstance(agent_action.tool_input, dict):
|
||||||
|
# tool_input is a dict, use as-is.
|
||||||
|
tool_call_parameters = agent_action.tool_input
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
# tool_input is serialised dict.
|
||||||
|
tool_call_parameters = json.loads(agent_action.tool_input)
|
||||||
|
if not isinstance(tool_call_parameters, dict):
|
||||||
|
raise ValueError()
|
||||||
|
except ValueError:
|
||||||
|
# tool_input is a string, last ditch attempt at having something useful.
|
||||||
|
tool_call_parameters = {"input": agent_action.tool_input}
|
||||||
tool_results.append(
|
tool_results.append(
|
||||||
{
|
ChatRequestToolResultsItem(
|
||||||
"call": {
|
call=ToolCall(
|
||||||
"name": agent_action.tool,
|
name=agent_action.tool,
|
||||||
"parameters": agent_action.tool_input,
|
parameters=tool_call_parameters,
|
||||||
},
|
),
|
||||||
"outputs": [{"answer": observation}],
|
outputs=[{"answer": observation}],
|
||||||
}
|
).dict()
|
||||||
)
|
)
|
||||||
|
|
||||||
return tool_results
|
return tool_results
|
||||||
@ -143,7 +162,7 @@ class _CohereToolsAgentOutputParser(
|
|||||||
) -> Union[List[AgentAction], AgentFinish]:
|
) -> Union[List[AgentAction], AgentFinish]:
|
||||||
if not isinstance(result[0], ChatGeneration):
|
if not isinstance(result[0], ChatGeneration):
|
||||||
raise ValueError(f"Expected ChatGeneration, got {type(result)}")
|
raise ValueError(f"Expected ChatGeneration, got {type(result)}")
|
||||||
if result[0].message.additional_kwargs["tool_calls"]:
|
if "tool_calls" in result[0].message.additional_kwargs:
|
||||||
actions = []
|
actions = []
|
||||||
for tool in result[0].message.additional_kwargs["tool_calls"]:
|
for tool in result[0].message.additional_kwargs["tool_calls"]:
|
||||||
function = tool.get("function", {})
|
function = tool.get("function", {})
|
||||||
|
@ -1,9 +1,14 @@
|
|||||||
from typing import Any, Dict, Optional, Type, Union
|
import json
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from langchain_core.agents import AgentAction
|
||||||
from langchain_core.tools import BaseModel, BaseTool, Field
|
from langchain_core.tools import BaseModel, BaseTool, Field
|
||||||
|
|
||||||
from langchain_cohere.cohere_agent import _format_to_cohere_tools
|
from langchain_cohere.cohere_agent import (
|
||||||
|
_format_to_cohere_tools,
|
||||||
|
_format_to_cohere_tools_messages,
|
||||||
|
)
|
||||||
|
|
||||||
expected_test_tool_definition = {
|
expected_test_tool_definition = {
|
||||||
"description": "test_tool description",
|
"description": "test_tool description",
|
||||||
@ -80,3 +85,56 @@ def test_format_to_cohere_tools(
|
|||||||
actual = _format_to_cohere_tools([tool])
|
actual = _format_to_cohere_tools([tool])
|
||||||
|
|
||||||
assert [expected_test_tool_definition] == actual
|
assert [expected_test_tool_definition] == actual
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"intermediate_step,expected",
|
||||||
|
[
|
||||||
|
pytest.param(
|
||||||
|
(
|
||||||
|
AgentAction(tool="tool_name", tool_input={"arg1": "value1"}, log=""),
|
||||||
|
"result",
|
||||||
|
),
|
||||||
|
{
|
||||||
|
"call": {"name": "tool_name", "parameters": {"arg1": "value1"}},
|
||||||
|
"outputs": [{"answer": "result"}],
|
||||||
|
},
|
||||||
|
id="tool_input as dict",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
(
|
||||||
|
AgentAction(
|
||||||
|
tool="tool_name", tool_input=json.dumps({"arg1": "value1"}), log=""
|
||||||
|
),
|
||||||
|
"result",
|
||||||
|
),
|
||||||
|
{
|
||||||
|
"call": {"name": "tool_name", "parameters": {"arg1": "value1"}},
|
||||||
|
"outputs": [{"answer": "result"}],
|
||||||
|
},
|
||||||
|
id="tool_input as serialized dict",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
(AgentAction(tool="tool_name", tool_input="foo", log=""), "result"),
|
||||||
|
{
|
||||||
|
"call": {"name": "tool_name", "parameters": {"input": "foo"}},
|
||||||
|
"outputs": [{"answer": "result"}],
|
||||||
|
},
|
||||||
|
id="tool_input as string",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
(AgentAction(tool="tool_name", tool_input="['foo']", log=""), "result"),
|
||||||
|
{
|
||||||
|
"call": {"name": "tool_name", "parameters": {"input": "['foo']"}},
|
||||||
|
"outputs": [{"answer": "result"}],
|
||||||
|
},
|
||||||
|
id="tool_input unrelated JSON",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_format_to_cohere_tools_messages(
|
||||||
|
intermediate_step: Tuple[AgentAction, str], expected: List[Dict[str, Any]]
|
||||||
|
) -> None:
|
||||||
|
actual = _format_to_cohere_tools_messages(intermediate_steps=[intermediate_step])
|
||||||
|
|
||||||
|
assert [expected] == actual
|
||||||
|
Loading…
Reference in New Issue
Block a user