cohere[patch]: Add cohere tools agent (#19602)

**Description**: Adds a cohere tools agent and related notebook.

---------

Co-authored-by: BeatrixCohere <128378696+BeatrixCohere@users.noreply.github.com>
Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
harry-cohere 2024-03-28 01:35:43 +00:00 committed by GitHub
parent 5c41f4083e
commit 3685f8ceac
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 509 additions and 19 deletions

View File

@ -1,4 +1,5 @@
from langchain_cohere.chat_models import ChatCohere from langchain_cohere.chat_models import ChatCohere
from langchain_cohere.cohere_agent import create_cohere_tools_agent
from langchain_cohere.embeddings import CohereEmbeddings from langchain_cohere.embeddings import CohereEmbeddings
from langchain_cohere.rag_retrievers import CohereRagRetriever from langchain_cohere.rag_retrievers import CohereRagRetriever
from langchain_cohere.rerank import CohereRerank from langchain_cohere.rerank import CohereRerank
@ -9,4 +10,5 @@ __all__ = [
"CohereEmbeddings", "CohereEmbeddings",
"CohereRagRetriever", "CohereRagRetriever",
"CohereRerank", "CohereRerank",
"create_cohere_tools_agent",
] ]

View File

@ -1,9 +1,22 @@
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional import json
from typing import (
Any,
AsyncIterator,
Dict,
Iterator,
List,
Optional,
Sequence,
Type,
Union,
)
from cohere.types import NonStreamedChatResponse, ToolCall
from langchain_core.callbacks import ( from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun, CallbackManagerForLLMRun,
) )
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import ( from langchain_core.language_models.chat_models import (
BaseChatModel, BaseChatModel,
agenerate_from_stream, agenerate_from_stream,
@ -18,7 +31,11 @@ from langchain_core.messages import (
SystemMessage, SystemMessage,
) )
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from langchain_cohere.cohere_agent import _format_to_cohere_tools
from langchain_cohere.llms import BaseCohere from langchain_cohere.llms import BaseCohere
@ -143,6 +160,14 @@ class ChatCohere(BaseChatModel, BaseCohere):
} }
return {k: v for k, v in base_params.items() if v is not None} return {k: v for k, v in base_params.items() if v is not None}
def bind_tools(
self,
tools: Sequence[Union[Dict[str, Any], BaseTool, Type[BaseModel]]],
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
formatted_tools = _format_to_cohere_tools(tools)
return super().bind(tools=formatted_tools, **kwargs)
@property @property
def _identifying_params(self) -> Dict[str, Any]: def _identifying_params(self) -> Dict[str, Any]:
"""Get the identifying parameters.""" """Get the identifying parameters."""
@ -169,6 +194,14 @@ class ChatCohere(BaseChatModel, BaseCohere):
if run_manager: if run_manager:
run_manager.on_llm_new_token(delta, chunk=chunk) run_manager.on_llm_new_token(delta, chunk=chunk)
yield chunk yield chunk
elif data.event_type == "stream-end":
generation_info = self._get_generation_info(data.response)
yield ChatGenerationChunk(
message=AIMessageChunk(
content="", additional_kwargs=generation_info
),
generation_info=generation_info,
)
async def _astream( async def _astream(
self, self,
@ -191,16 +224,34 @@ class ChatCohere(BaseChatModel, BaseCohere):
if run_manager: if run_manager:
await run_manager.on_llm_new_token(delta, chunk=chunk) await run_manager.on_llm_new_token(delta, chunk=chunk)
yield chunk yield chunk
elif data.event_type == "stream-end":
generation_info = self._get_generation_info(data.response)
yield ChatGenerationChunk(
message=AIMessageChunk(
content="", additional_kwargs=generation_info
),
generation_info=generation_info,
)
def _get_generation_info(self, response: Any) -> Dict[str, Any]: def _get_generation_info(self, response: NonStreamedChatResponse) -> Dict[str, Any]:
"""Get the generation info from cohere API response.""" """Get the generation info from cohere API response."""
return { generation_info = {
"documents": response.documents, "documents": response.documents,
"citations": response.citations, "citations": response.citations,
"search_results": response.search_results, "search_results": response.search_results,
"search_queries": response.search_queries, "search_queries": response.search_queries,
"token_count": response.token_count, "is_search_required": response.is_search_required,
"generation_id": response.generation_id,
} }
if response.tool_calls:
# Only populate tool_calls when 1) present on the response and
# 2) has one or more calls.
generation_info["tool_calls"] = _format_cohere_tool_calls(
response.generation_id or "", response.tool_calls
)
if hasattr(response, "token_count"):
generation_info["token_count"] = response.token_count
return generation_info
def _generate( def _generate(
self, self,
@ -218,10 +269,8 @@ class ChatCohere(BaseChatModel, BaseCohere):
request = get_cohere_chat_request(messages, **self._default_params, **kwargs) request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
response = self.client.chat(**request) response = self.client.chat(**request)
message = AIMessage(content=response.text)
generation_info = None
if hasattr(response, "documents"):
generation_info = self._get_generation_info(response) generation_info = self._get_generation_info(response)
message = AIMessage(content=response.text, additional_kwargs=generation_info)
return ChatResult( return ChatResult(
generations=[ generations=[
ChatGeneration(message=message, generation_info=generation_info) ChatGeneration(message=message, generation_info=generation_info)
@ -244,10 +293,8 @@ class ChatCohere(BaseChatModel, BaseCohere):
request = get_cohere_chat_request(messages, **self._default_params, **kwargs) request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
response = self.client.chat(**request) response = self.client.chat(**request)
message = AIMessage(content=response.text)
generation_info = None
if hasattr(response, "documents"):
generation_info = self._get_generation_info(response) generation_info = self._get_generation_info(response)
message = AIMessage(content=response.text, additional_kwargs=generation_info)
return ChatResult( return ChatResult(
generations=[ generations=[
ChatGeneration(message=message, generation_info=generation_info) ChatGeneration(message=message, generation_info=generation_info)
@ -257,3 +304,27 @@ class ChatCohere(BaseChatModel, BaseCohere):
def get_num_tokens(self, text: str) -> int: def get_num_tokens(self, text: str) -> int:
"""Calculate number of tokens.""" """Calculate number of tokens."""
return len(self.client.tokenize(text).tokens) return len(self.client.tokenize(text).tokens)
def _format_cohere_tool_calls(
generation_id: str, tool_calls: Optional[List[ToolCall]] = None
) -> List[Dict]:
"""
Formats a Cohere API response into the tool call format used elsewhere in Langchain.
"""
if not tool_calls:
return []
formatted_tool_calls = []
for tool_call in tool_calls:
formatted_tool_calls.append(
{
"id": generation_id,
"function": {
"name": tool_call.name,
"arguments": json.dumps(tool_call.parameters),
},
"type": "function",
}
)
return formatted_tool_calls

View File

@ -0,0 +1,168 @@
from typing import Any, Dict, List, Sequence, Tuple, Type, Union
from cohere.types import Tool, ToolParameterDefinitionsValue
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.outputs import Generation
from langchain_core.outputs.chat_generation import ChatGeneration
from langchain_core.prompts.chat import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables import Runnable, RunnablePassthrough
from langchain_core.runnables.base import RunnableLambda
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import convert_to_openai_function
def create_cohere_tools_agent(
llm: BaseLanguageModel, tools: Sequence[BaseTool], prompt: ChatPromptTemplate
) -> Runnable:
def llm_with_tools(input_: Dict) -> Runnable:
tool_results = (
input_["tool_results"] if len(input_["tool_results"]) > 0 else None
)
tools_ = input_["tools"] if len(input_["tools"]) > 0 else None
return RunnableLambda(lambda x: x["input"]) | llm.bind(
tools=tools_, tool_results=tool_results
)
agent = (
RunnablePassthrough.assign(
# Intermediate steps are in tool results.
# Edit below to change the prompt parameters.
input=lambda x: prompt.format_messages(
input=x["input"], agent_scratchpad=[]
),
tools=lambda x: _format_to_cohere_tools(tools),
tool_results=lambda x: _format_to_cohere_tools_messages(
x["intermediate_steps"]
),
)
| llm_with_tools
| _CohereToolsAgentOutputParser()
)
return agent
def _format_to_cohere_tools(
tools: Sequence[Union[Dict[str, Any], BaseTool, Type[BaseModel]]],
) -> List[Dict[str, Any]]:
return [_convert_to_cohere_tool(tool) for tool in tools]
def _format_to_cohere_tools_messages(
intermediate_steps: Sequence[Tuple[AgentAction, str]],
) -> list:
"""Convert (AgentAction, tool output) tuples into tool messages."""
if len(intermediate_steps) == 0:
return []
tool_results = []
for agent_action, observation in intermediate_steps:
tool_results.append(
{
"call": {
"name": agent_action.tool,
"parameters": agent_action.tool_input,
},
"outputs": [{"answer": observation}],
}
)
return tool_results
def _convert_to_cohere_tool(
tool: Union[Dict[str, Any], BaseTool, Type[BaseModel]],
) -> Dict[str, Any]:
"""
Convert a BaseTool instance, JSON schema dict, or BaseModel type to a Cohere tool.
"""
if isinstance(tool, BaseTool):
return Tool(
name=tool.name,
description=tool.description,
parameter_definitions={
param_name: ToolParameterDefinitionsValue(
description=param_definition.get("description"),
type=param_definition.get("type"),
required="default" not in param_definition,
)
for param_name, param_definition in tool.args.items()
},
).dict()
elif isinstance(tool, dict):
if not all(k in tool for k in ("title", "description", "properties")):
raise ValueError(
"Unsupported dict type. Tool must be passed in as a BaseTool instance, JSON schema dict, or BaseModel type." # noqa: E501
)
return Tool(
name=tool.get("title"),
description=tool.get("description"),
parameter_definitions={
param_name: ToolParameterDefinitionsValue(
description=param_definition.get("description"),
type=param_definition.get("type"),
required="default" not in param_definition,
)
for param_name, param_definition in tool.get("properties", {}).items()
},
).dict()
elif issubclass(tool, BaseModel):
as_json_schema_function = convert_to_openai_function(tool)
parameters = as_json_schema_function.get("parameters", {})
properties = parameters.get("properties", {})
return Tool(
name=as_json_schema_function.get("name"),
description=as_json_schema_function.get(
# The Cohere API requires the description field.
"description",
as_json_schema_function.get("name"),
),
parameter_definitions={
param_name: ToolParameterDefinitionsValue(
description=param_definition.get("description"),
type=param_definition.get("type"),
required=param_name in parameters.get("required", []),
)
for param_name, param_definition in properties.items()
},
).dict()
else:
raise ValueError(
f"Unsupported tool type {type(tool)}. Tool must be passed in as a BaseTool instance, JSON schema dict, or BaseModel type." # noqa: E501
)
class _CohereToolsAgentOutputParser(
BaseOutputParser[Union[List[AgentAction], AgentFinish]]
):
"""Parses a message into agent actions/finish."""
def parse_result(
self, result: List[Generation], *, partial: bool = False
) -> Union[List[AgentAction], AgentFinish]:
if not isinstance(result[0], ChatGeneration):
raise ValueError(f"Expected ChatGeneration, got {type(result)}")
if result[0].message.additional_kwargs["tool_calls"]:
actions = []
for tool in result[0].message.additional_kwargs["tool_calls"]:
function = tool.get("function", {})
actions.append(
AgentAction(
tool=function.get("name"),
tool_input=function.get("arguments"),
log=function.get("name"),
)
)
return actions
else:
return AgentFinish(
return_values={
"text": result[0].message.content,
"additional_info": result[0].message.additional_kwargs,
},
log="",
)
def parse(self, text: str) -> Union[List[AgentAction], AgentFinish]:
raise ValueError("Can only parse messages")

View File

@ -20,6 +20,7 @@ def _get_docs(response: Any) -> List[Document]:
docs = ( docs = (
[] []
if "documents" not in response.generation_info if "documents" not in response.generation_info
or len(response.generation_info["documents"]) == 0
else [ else [
Document(page_content=doc["snippet"], metadata=doc) Document(page_content=doc["snippet"], metadata=doc)
for doc in response.generation_info["documents"] for doc in response.generation_info["documents"]

View File

@ -165,13 +165,13 @@ types = ["chardet (>=5.1.0)", "mypy", "pytest", "pytest-cov", "pytest-dependency
[[package]] [[package]]
name = "cohere" name = "cohere"
version = "5.1.2" version = "5.1.4"
description = "" description = ""
optional = false optional = false
python-versions = "<4.0,>=3.8" python-versions = "<4.0,>=3.8"
files = [ files = [
{file = "cohere-5.1.2-py3-none-any.whl", hash = "sha256:7782e32cba671fc04203c3b56a9ce1b70e9459d7c983e8576b04d394fbe809f5"}, {file = "cohere-5.1.4-py3-none-any.whl", hash = "sha256:b88c44dfa44301f55f509db120582a6127c2e391c6c43a4dc58767f4df056a9d"},
{file = "cohere-5.1.2.tar.gz", hash = "sha256:21af5ed6edcf939062c41240040316084cd7e753cf3207f661f68abb4bbbe846"}, {file = "cohere-5.1.4.tar.gz", hash = "sha256:81b45fe37df2d62aaf57094402cb62b5fed285c25667dab96023f2ad2591ff35"},
] ]
[package.dependencies] [package.dependencies]
@ -742,7 +742,6 @@ files = [
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
@ -957,4 +956,4 @@ watchmedo = ["PyYAML (>=3.10)"]
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = ">=3.8.1,<4.0" python-versions = ">=3.8.1,<4.0"
content-hash = "7ed2d31c084d528c64eb959df1a2ea29345a70117e9d29f322607fe247804cc5" content-hash = "6a5887a0391a649e1a45f3e3c766a880e133367d2656a9b5a37d75ebc33adef6"

View File

@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "langchain-cohere" name = "langchain-cohere"
version = "0.1.0rc1" version = "0.1.0rc2"
description = "An integration package connecting Cohere and LangChain" description = "An integration package connecting Cohere and LangChain"
authors = [] authors = []
readme = "README.md" readme = "README.md"
@ -13,7 +13,7 @@ license = "MIT"
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = ">=3.8.1,<4.0" python = ">=3.8.1,<4.0"
langchain-core = "^0.1.32" langchain-core = "^0.1.32"
cohere = "^5.1.1" cohere = "^5.1.4"
[tool.poetry.group.test] [tool.poetry.group.test]
optional = true optional = true

View File

@ -1,4 +1,12 @@
"""Test ChatCohere chat model.""" """Test ChatCohere chat model."""
import json
from typing import Any
import pytest
from langchain_core.messages import AIMessage, AIMessageChunk
from langchain_core.pydantic_v1 import BaseModel
from langchain_cohere import ChatCohere from langchain_cohere import ChatCohere
@ -61,3 +69,76 @@ def test_invoke() -> None:
result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"])) result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
assert isinstance(result.content, str) assert isinstance(result.content, str)
def test_invoke_tool_calls() -> None:
llm = ChatCohere(temperature=0)
class Person(BaseModel):
name: str
age: int
tool_llm = llm.bind_tools([Person])
# where it calls the tool
result = tool_llm.invoke("Erick, 27 years old")
assert isinstance(result, AIMessage)
additional_kwargs = result.additional_kwargs
assert "tool_calls" in additional_kwargs
assert len(additional_kwargs["tool_calls"]) == 1
assert additional_kwargs["tool_calls"][0]["function"]["name"] == "Person"
assert json.loads(additional_kwargs["tool_calls"][0]["function"]["arguments"]) == {
"name": "Erick",
"age": 27,
}
def test_streaming_tool_call() -> None:
llm = ChatCohere(temperature=0)
class Person(BaseModel):
name: str
age: int
tool_llm = llm.bind_tools([Person])
# where it calls the tool
strm = tool_llm.stream("Erick, 27 years old")
additional_kwargs = None
for chunk in strm:
assert isinstance(chunk, AIMessageChunk)
assert chunk.content == ""
additional_kwargs = chunk.additional_kwargs
assert additional_kwargs is not None
assert "tool_calls" in additional_kwargs
assert len(additional_kwargs["tool_calls"]) == 1
assert additional_kwargs["tool_calls"][0]["function"]["name"] == "Person"
assert json.loads(additional_kwargs["tool_calls"][0]["function"]["arguments"]) == {
"name": "Erick",
"age": 27,
}
@pytest.mark.xfail(
reason="Cohere models return empty output when a tool is passed in but not called."
)
def test_streaming_tool_call_no_tool_calls() -> None:
llm = ChatCohere(temperature=0)
class Person(BaseModel):
name: str
age: int
tool_llm = llm.bind_tools([Person])
# where it doesn't call the tool
strm = tool_llm.stream("What is 2+2?")
acc: Any = None
for chunk in strm:
assert isinstance(chunk, AIMessageChunk)
acc = chunk if acc is None else acc + chunk
assert acc.content != ""
assert "tool_calls" not in acc.additional_kwargs

View File

@ -2,6 +2,7 @@
import typing import typing
import pytest import pytest
from cohere.types import NonStreamedChatResponse, ToolCall
from langchain_cohere.chat_models import ChatCohere from langchain_cohere.chat_models import ChatCohere
@ -28,3 +29,87 @@ def test_initialization() -> None:
def test_default_params(chat_cohere: ChatCohere, expected: typing.Dict) -> None: def test_default_params(chat_cohere: ChatCohere, expected: typing.Dict) -> None:
actual = chat_cohere._default_params actual = chat_cohere._default_params
assert expected == actual assert expected == actual
@pytest.mark.parametrize(
"response, expected",
[
pytest.param(
NonStreamedChatResponse(
generation_id="foo",
text="",
tool_calls=[
ToolCall(name="tool1", parameters={"arg1": 1, "arg2": "2"}),
ToolCall(name="tool2", parameters={"arg3": 3, "arg4": "4"}),
],
),
{
"documents": None,
"citations": None,
"search_results": None,
"search_queries": None,
"is_search_required": None,
"generation_id": "foo",
"tool_calls": [
{
"id": "foo",
"function": {
"name": "tool1",
"arguments": '{"arg1": 1, "arg2": "2"}',
},
"type": "function",
},
{
"id": "foo",
"function": {
"name": "tool2",
"arguments": '{"arg3": 3, "arg4": "4"}',
},
"type": "function",
},
],
},
id="tools should be called",
),
pytest.param(
NonStreamedChatResponse(
generation_id="foo",
text="",
tool_calls=[],
),
{
"documents": None,
"citations": None,
"search_results": None,
"search_queries": None,
"is_search_required": None,
"generation_id": "foo",
},
id="no tools should be called",
),
pytest.param(
NonStreamedChatResponse(
generation_id="foo",
text="bar",
tool_calls=[],
),
{
"documents": None,
"citations": None,
"search_results": None,
"search_queries": None,
"is_search_required": None,
"generation_id": "foo",
},
id="chat response without tools/documents/citations/tools etc",
),
],
)
def test_get_generation_info(
response: typing.Any, expected: typing.Dict[str, typing.Any]
) -> None:
chat_cohere = ChatCohere(cohere_api_key="test")
actual = chat_cohere._get_generation_info(response)
assert expected == actual

View File

@ -0,0 +1,82 @@
from typing import Any, Dict, Optional, Type, Union
import pytest
from langchain_core.tools import BaseModel, BaseTool, Field
from langchain_cohere.cohere_agent import _format_to_cohere_tools
expected_test_tool_definition = {
"description": "test_tool description",
"name": "test_tool",
"parameter_definitions": {
"arg_1": {
"description": "Arg1 description",
"required": True,
"type": "string",
},
"optional_arg_2": {
"description": "Arg2 description",
"required": False,
"type": "string",
},
"arg_3": {
"description": "Arg3 description",
"required": True,
"type": "integer",
},
},
}
class _TestToolSchema(BaseModel):
arg_1: str = Field(description="Arg1 description")
optional_arg_2: Optional[str] = Field(description="Arg2 description", default="2")
arg_3: int = Field(description="Arg3 description")
class _TestTool(BaseTool):
name = "test_tool"
description = "test_tool description"
args_schema: Type[_TestToolSchema] = _TestToolSchema
def _run(self, *args: Any, **kwargs: Any) -> Any:
pass
class test_tool(BaseModel):
"""test_tool description"""
arg_1: str = Field(description="Arg1 description")
optional_arg_2: Optional[str] = Field(description="Arg2 description", default="2")
arg_3: int = Field(description="Arg3 description")
test_tool_as_dict = {
"title": "test_tool",
"description": "test_tool description",
"properties": {
"arg_1": {"description": "Arg1 description", "type": "string"},
"optional_arg_2": {
"description": "Arg2 description",
"type": "string",
"default": "2",
},
"arg_3": {"description": "Arg3 description", "type": "integer"},
},
}
@pytest.mark.parametrize(
"tool",
[
pytest.param(_TestTool(), id="tool from BaseTool"),
pytest.param(test_tool, id="BaseModel"),
pytest.param(test_tool_as_dict, id="JSON schema dict"),
],
)
def test_format_to_cohere_tools(
tool: Union[Dict[str, Any], BaseTool, Type[BaseModel]],
) -> None:
actual = _format_to_cohere_tools([tool])
assert [expected_test_tool_definition] == actual

View File

@ -6,6 +6,7 @@ EXPECTED_ALL = [
"CohereEmbeddings", "CohereEmbeddings",
"CohereRagRetriever", "CohereRagRetriever",
"CohereRerank", "CohereRerank",
"create_cohere_tools_agent",
] ]