From 3685f8ceac42ee3614127c1210e3d746599219e5 Mon Sep 17 00:00:00 2001
From: harry-cohere <127103098+harry-cohere@users.noreply.github.com>
Date: Thu, 28 Mar 2024 01:35:43 +0000
Subject: [PATCH] cohere[patch]: Add cohere tools agent (#19602)

**Description**: Adds a Cohere tools agent and a related notebook.
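A minimal usage sketch for reviewers. The `magic_function` tool, the prompt text, and the `AgentExecutor` wiring are illustrative only (`AgentExecutor` comes from `langchain`, not this package), and a valid `COHERE_API_KEY` is assumed:

```python
from langchain.agents import AgentExecutor
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.tools import tool

from langchain_cohere import ChatCohere, create_cohere_tools_agent


@tool
def magic_function(input: int) -> int:
    """Applies a magic function to an input."""
    return input + 2


llm = ChatCohere(temperature=0)
prompt = ChatPromptTemplate.from_template("{input}")
agent = create_cohere_tools_agent(llm, [magic_function], prompt)
# AgentExecutor re-invokes the agent with accumulated intermediate_steps
# until the output parser returns an AgentFinish.
executor = AgentExecutor(agent=agent, tools=[magic_function], verbose=True)
executor.invoke({"input": "What is the value of magic_function(3)?"})
```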
parameters.""" @@ -169,6 +194,14 @@ class ChatCohere(BaseChatModel, BaseCohere): if run_manager: run_manager.on_llm_new_token(delta, chunk=chunk) yield chunk + elif data.event_type == "stream-end": + generation_info = self._get_generation_info(data.response) + yield ChatGenerationChunk( + message=AIMessageChunk( + content="", additional_kwargs=generation_info + ), + generation_info=generation_info, + ) async def _astream( self, @@ -191,16 +224,34 @@ class ChatCohere(BaseChatModel, BaseCohere): if run_manager: await run_manager.on_llm_new_token(delta, chunk=chunk) yield chunk + elif data.event_type == "stream-end": + generation_info = self._get_generation_info(data.response) + yield ChatGenerationChunk( + message=AIMessageChunk( + content="", additional_kwargs=generation_info + ), + generation_info=generation_info, + ) - def _get_generation_info(self, response: Any) -> Dict[str, Any]: + def _get_generation_info(self, response: NonStreamedChatResponse) -> Dict[str, Any]: """Get the generation info from cohere API response.""" - return { + generation_info = { "documents": response.documents, "citations": response.citations, "search_results": response.search_results, "search_queries": response.search_queries, - "token_count": response.token_count, + "is_search_required": response.is_search_required, + "generation_id": response.generation_id, } + if response.tool_calls: + # Only populate tool_calls when 1) present on the response and + # 2) has one or more calls. + generation_info["tool_calls"] = _format_cohere_tool_calls( + response.generation_id or "", response.tool_calls + ) + if hasattr(response, "token_count"): + generation_info["token_count"] = response.token_count + return generation_info def _generate( self, @@ -218,10 +269,8 @@ class ChatCohere(BaseChatModel, BaseCohere): request = get_cohere_chat_request(messages, **self._default_params, **kwargs) response = self.client.chat(**request) - message = AIMessage(content=response.text) - generation_info = None - if hasattr(response, "documents"): - generation_info = self._get_generation_info(response) + generation_info = self._get_generation_info(response) + message = AIMessage(content=response.text, additional_kwargs=generation_info) return ChatResult( generations=[ ChatGeneration(message=message, generation_info=generation_info) @@ -244,10 +293,8 @@ class ChatCohere(BaseChatModel, BaseCohere): request = get_cohere_chat_request(messages, **self._default_params, **kwargs) response = self.client.chat(**request) - message = AIMessage(content=response.text) - generation_info = None - if hasattr(response, "documents"): - generation_info = self._get_generation_info(response) + generation_info = self._get_generation_info(response) + message = AIMessage(content=response.text, additional_kwargs=generation_info) return ChatResult( generations=[ ChatGeneration(message=message, generation_info=generation_info) @@ -257,3 +304,27 @@ class ChatCohere(BaseChatModel, BaseCohere): def get_num_tokens(self, text: str) -> int: """Calculate number of tokens.""" return len(self.client.tokenize(text).tokens) + + +def _format_cohere_tool_calls( + generation_id: str, tool_calls: Optional[List[ToolCall]] = None +) -> List[Dict]: + """ + Formats a Cohere API response into the tool call format used elsewhere in Langchain. 
+ """ + if not tool_calls: + return [] + + formatted_tool_calls = [] + for tool_call in tool_calls: + formatted_tool_calls.append( + { + "id": generation_id, + "function": { + "name": tool_call.name, + "arguments": json.dumps(tool_call.parameters), + }, + "type": "function", + } + ) + return formatted_tool_calls diff --git a/libs/partners/cohere/langchain_cohere/cohere_agent.py b/libs/partners/cohere/langchain_cohere/cohere_agent.py new file mode 100644 index 0000000000..5bf8328e8c --- /dev/null +++ b/libs/partners/cohere/langchain_cohere/cohere_agent.py @@ -0,0 +1,168 @@ +from typing import Any, Dict, List, Sequence, Tuple, Type, Union + +from cohere.types import Tool, ToolParameterDefinitionsValue +from langchain_core.agents import AgentAction, AgentFinish +from langchain_core.language_models import BaseLanguageModel +from langchain_core.output_parsers import BaseOutputParser +from langchain_core.outputs import Generation +from langchain_core.outputs.chat_generation import ChatGeneration +from langchain_core.prompts.chat import ChatPromptTemplate +from langchain_core.pydantic_v1 import BaseModel +from langchain_core.runnables import Runnable, RunnablePassthrough +from langchain_core.runnables.base import RunnableLambda +from langchain_core.tools import BaseTool +from langchain_core.utils.function_calling import convert_to_openai_function + + +def create_cohere_tools_agent( + llm: BaseLanguageModel, tools: Sequence[BaseTool], prompt: ChatPromptTemplate +) -> Runnable: + def llm_with_tools(input_: Dict) -> Runnable: + tool_results = ( + input_["tool_results"] if len(input_["tool_results"]) > 0 else None + ) + tools_ = input_["tools"] if len(input_["tools"]) > 0 else None + return RunnableLambda(lambda x: x["input"]) | llm.bind( + tools=tools_, tool_results=tool_results + ) + + agent = ( + RunnablePassthrough.assign( + # Intermediate steps are in tool results. + # Edit below to change the prompt parameters. + input=lambda x: prompt.format_messages( + input=x["input"], agent_scratchpad=[] + ), + tools=lambda x: _format_to_cohere_tools(tools), + tool_results=lambda x: _format_to_cohere_tools_messages( + x["intermediate_steps"] + ), + ) + | llm_with_tools + | _CohereToolsAgentOutputParser() + ) + return agent + + +def _format_to_cohere_tools( + tools: Sequence[Union[Dict[str, Any], BaseTool, Type[BaseModel]]], +) -> List[Dict[str, Any]]: + return [_convert_to_cohere_tool(tool) for tool in tools] + + +def _format_to_cohere_tools_messages( + intermediate_steps: Sequence[Tuple[AgentAction, str]], +) -> list: + """Convert (AgentAction, tool output) tuples into tool messages.""" + if len(intermediate_steps) == 0: + return [] + tool_results = [] + for agent_action, observation in intermediate_steps: + tool_results.append( + { + "call": { + "name": agent_action.tool, + "parameters": agent_action.tool_input, + }, + "outputs": [{"answer": observation}], + } + ) + + return tool_results + + +def _convert_to_cohere_tool( + tool: Union[Dict[str, Any], BaseTool, Type[BaseModel]], +) -> Dict[str, Any]: + """ + Convert a BaseTool instance, JSON schema dict, or BaseModel type to a Cohere tool. 
+ """ + if isinstance(tool, BaseTool): + return Tool( + name=tool.name, + description=tool.description, + parameter_definitions={ + param_name: ToolParameterDefinitionsValue( + description=param_definition.get("description"), + type=param_definition.get("type"), + required="default" not in param_definition, + ) + for param_name, param_definition in tool.args.items() + }, + ).dict() + elif isinstance(tool, dict): + if not all(k in tool for k in ("title", "description", "properties")): + raise ValueError( + "Unsupported dict type. Tool must be passed in as a BaseTool instance, JSON schema dict, or BaseModel type." # noqa: E501 + ) + return Tool( + name=tool.get("title"), + description=tool.get("description"), + parameter_definitions={ + param_name: ToolParameterDefinitionsValue( + description=param_definition.get("description"), + type=param_definition.get("type"), + required="default" not in param_definition, + ) + for param_name, param_definition in tool.get("properties", {}).items() + }, + ).dict() + elif issubclass(tool, BaseModel): + as_json_schema_function = convert_to_openai_function(tool) + parameters = as_json_schema_function.get("parameters", {}) + properties = parameters.get("properties", {}) + return Tool( + name=as_json_schema_function.get("name"), + description=as_json_schema_function.get( + # The Cohere API requires the description field. + "description", + as_json_schema_function.get("name"), + ), + parameter_definitions={ + param_name: ToolParameterDefinitionsValue( + description=param_definition.get("description"), + type=param_definition.get("type"), + required=param_name in parameters.get("required", []), + ) + for param_name, param_definition in properties.items() + }, + ).dict() + else: + raise ValueError( + f"Unsupported tool type {type(tool)}. Tool must be passed in as a BaseTool instance, JSON schema dict, or BaseModel type." 
+
+
+class _CohereToolsAgentOutputParser(
+    BaseOutputParser[Union[List[AgentAction], AgentFinish]]
+):
+    """Parses a message into agent actions/finish."""
+
+    def parse_result(
+        self, result: List[Generation], *, partial: bool = False
+    ) -> Union[List[AgentAction], AgentFinish]:
+        if not isinstance(result[0], ChatGeneration):
+            raise ValueError(f"Expected ChatGeneration, got {type(result)}")
+        if result[0].message.additional_kwargs.get("tool_calls"):
+            actions = []
+            for tool in result[0].message.additional_kwargs["tool_calls"]:
+                function = tool.get("function", {})
+                actions.append(
+                    AgentAction(
+                        tool=function.get("name"),
+                        tool_input=function.get("arguments"),
+                        log=function.get("name"),
+                    )
+                )
+            return actions
+        else:
+            return AgentFinish(
+                return_values={
+                    "text": result[0].message.content,
+                    "additional_info": result[0].message.additional_kwargs,
+                },
+                log="",
+            )
+
+    def parse(self, text: str) -> Union[List[AgentAction], AgentFinish]:
+        raise ValueError("Can only parse messages")
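The private `_convert_to_cohere_tool` helper funnels all three accepted shapes (BaseTool instance, JSON schema dict, pydantic model) into the same Cohere `Tool` payload, marking a parameter required exactly when it has no default. A quick sketch with an illustrative model; the output shape matches `expected_test_tool_definition` in the unit tests below:

```python
from langchain_core.pydantic_v1 import BaseModel, Field

from langchain_cohere.cohere_agent import _format_to_cohere_tools


class Person(BaseModel):
    """Record a person."""

    name: str = Field(description="The person's name")
    age: int = Field(description="The person's age", default=0)


# Roughly:
# [{"name": "Person",
#   "description": "Record a person.",
#   "parameter_definitions": {
#       "name": {"description": "The person's name", "type": "string",
#                "required": True},
#       "age": {"description": "The person's age", "type": "integer",
#               "required": False}}}]
print(_format_to_cohere_tools([Person]))
```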
"sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -957,4 +956,4 @@ watchmedo = ["PyYAML (>=3.10)"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "7ed2d31c084d528c64eb959df1a2ea29345a70117e9d29f322607fe247804cc5" +content-hash = "6a5887a0391a649e1a45f3e3c766a880e133367d2656a9b5a37d75ebc33adef6" diff --git a/libs/partners/cohere/pyproject.toml b/libs/partners/cohere/pyproject.toml index 8fcad47bfd..71ddcca2e1 100644 --- a/libs/partners/cohere/pyproject.toml +++ b/libs/partners/cohere/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "langchain-cohere" -version = "0.1.0rc1" +version = "0.1.0rc2" description = "An integration package connecting Cohere and LangChain" authors = [] readme = "README.md" @@ -13,7 +13,7 @@ license = "MIT" [tool.poetry.dependencies] python = ">=3.8.1,<4.0" langchain-core = "^0.1.32" -cohere = "^5.1.1" +cohere = "^5.1.4" [tool.poetry.group.test] optional = true diff --git a/libs/partners/cohere/tests/integration_tests/test_chat_models.py b/libs/partners/cohere/tests/integration_tests/test_chat_models.py index 81246c37aa..a94c9e627d 100644 --- a/libs/partners/cohere/tests/integration_tests/test_chat_models.py +++ b/libs/partners/cohere/tests/integration_tests/test_chat_models.py @@ -1,4 +1,12 @@ """Test ChatCohere chat model.""" + +import json +from typing import Any + +import pytest +from langchain_core.messages import AIMessage, AIMessageChunk +from langchain_core.pydantic_v1 import BaseModel + from langchain_cohere import ChatCohere @@ -61,3 +69,76 @@ def test_invoke() -> None: result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"])) assert isinstance(result.content, str) + + +def test_invoke_tool_calls() -> None: + llm = ChatCohere(temperature=0) + + class Person(BaseModel): + name: str + age: int + + tool_llm = llm.bind_tools([Person]) + + # where it calls the tool + result = tool_llm.invoke("Erick, 27 years old") + + assert isinstance(result, AIMessage) + additional_kwargs = result.additional_kwargs + assert "tool_calls" in additional_kwargs + assert len(additional_kwargs["tool_calls"]) == 1 + assert additional_kwargs["tool_calls"][0]["function"]["name"] == "Person" + assert json.loads(additional_kwargs["tool_calls"][0]["function"]["arguments"]) == { + "name": "Erick", + "age": 27, + } + + +def test_streaming_tool_call() -> None: + llm = ChatCohere(temperature=0) + + class Person(BaseModel): + name: str + age: int + + tool_llm = llm.bind_tools([Person]) + + # where it calls the tool + strm = tool_llm.stream("Erick, 27 years old") + + additional_kwargs = None + for chunk in strm: + assert isinstance(chunk, AIMessageChunk) + assert chunk.content == "" + additional_kwargs = chunk.additional_kwargs + + assert additional_kwargs is not None + assert "tool_calls" in additional_kwargs + assert len(additional_kwargs["tool_calls"]) == 1 + assert additional_kwargs["tool_calls"][0]["function"]["name"] == "Person" + assert json.loads(additional_kwargs["tool_calls"][0]["function"]["arguments"]) == { + "name": "Erick", + "age": 27, + } + + +@pytest.mark.xfail( + reason="Cohere models return empty output when a tool is passed in but not called." 
diff --git a/libs/partners/cohere/tests/integration_tests/test_chat_models.py b/libs/partners/cohere/tests/integration_tests/test_chat_models.py
index 81246c37aa..a94c9e627d 100644
--- a/libs/partners/cohere/tests/integration_tests/test_chat_models.py
+++ b/libs/partners/cohere/tests/integration_tests/test_chat_models.py
@@ -1,4 +1,12 @@
 """Test ChatCohere chat model."""
+
+import json
+from typing import Any
+
+import pytest
+from langchain_core.messages import AIMessage, AIMessageChunk
+from langchain_core.pydantic_v1 import BaseModel
+
 from langchain_cohere import ChatCohere
 
 
@@ -61,3 +69,76 @@ def test_invoke() -> None:
 
     result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
     assert isinstance(result.content, str)
+
+
+def test_invoke_tool_calls() -> None:
+    llm = ChatCohere(temperature=0)
+
+    class Person(BaseModel):
+        name: str
+        age: int
+
+    tool_llm = llm.bind_tools([Person])
+
+    # where it calls the tool
+    result = tool_llm.invoke("Erick, 27 years old")
+
+    assert isinstance(result, AIMessage)
+    additional_kwargs = result.additional_kwargs
+    assert "tool_calls" in additional_kwargs
+    assert len(additional_kwargs["tool_calls"]) == 1
+    assert additional_kwargs["tool_calls"][0]["function"]["name"] == "Person"
+    assert json.loads(additional_kwargs["tool_calls"][0]["function"]["arguments"]) == {
+        "name": "Erick",
+        "age": 27,
+    }
+
+
+def test_streaming_tool_call() -> None:
+    llm = ChatCohere(temperature=0)
+
+    class Person(BaseModel):
+        name: str
+        age: int
+
+    tool_llm = llm.bind_tools([Person])
+
+    # where it calls the tool
+    strm = tool_llm.stream("Erick, 27 years old")
+
+    additional_kwargs = None
+    for chunk in strm:
+        assert isinstance(chunk, AIMessageChunk)
+        assert chunk.content == ""
+        additional_kwargs = chunk.additional_kwargs
+
+    assert additional_kwargs is not None
+    assert "tool_calls" in additional_kwargs
+    assert len(additional_kwargs["tool_calls"]) == 1
+    assert additional_kwargs["tool_calls"][0]["function"]["name"] == "Person"
+    assert json.loads(additional_kwargs["tool_calls"][0]["function"]["arguments"]) == {
+        "name": "Erick",
+        "age": 27,
+    }
+
+
+@pytest.mark.xfail(
+    reason="Cohere models return empty output when a tool is passed in but not called."
+)
+def test_streaming_tool_call_no_tool_calls() -> None:
+    llm = ChatCohere(temperature=0)
+
+    class Person(BaseModel):
+        name: str
+        age: int
+
+    tool_llm = llm.bind_tools([Person])
+
+    # where it doesn't call the tool
+    strm = tool_llm.stream("What is 2+2?")
+    acc: Any = None
+    for chunk in strm:
+        assert isinstance(chunk, AIMessageChunk)
+        acc = chunk if acc is None else acc + chunk
+    assert acc.content != ""
+    assert "tool_calls" not in acc.additional_kwargs
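With the new `stream-end` handling, the final streamed chunk carries the generation info (including any tool calls) in `additional_kwargs`. A consumption sketch in the spirit of the tests above (assumes a valid `COHERE_API_KEY`):

```python
from typing import Optional

from langchain_core.messages import BaseMessageChunk
from langchain_core.pydantic_v1 import BaseModel

from langchain_cohere import ChatCohere


class Person(BaseModel):
    name: str
    age: int


tool_llm = ChatCohere(temperature=0).bind_tools([Person])

final: Optional[BaseMessageChunk] = None
for chunk in tool_llm.stream("Erick, 27 years old"):
    # Content chunks stream first; merging folds the trailing stream-end
    # chunk's additional_kwargs (documents, citations, tool_calls, ...)
    # into the accumulated message.
    final = chunk if final is None else final + chunk

assert final is not None
print(final.additional_kwargs.get("tool_calls"))
```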
"string", + }, + "arg_3": { + "description": "Arg3 description", + "required": True, + "type": "integer", + }, + }, +} + + +class _TestToolSchema(BaseModel): + arg_1: str = Field(description="Arg1 description") + optional_arg_2: Optional[str] = Field(description="Arg2 description", default="2") + arg_3: int = Field(description="Arg3 description") + + +class _TestTool(BaseTool): + name = "test_tool" + description = "test_tool description" + args_schema: Type[_TestToolSchema] = _TestToolSchema + + def _run(self, *args: Any, **kwargs: Any) -> Any: + pass + + +class test_tool(BaseModel): + """test_tool description""" + + arg_1: str = Field(description="Arg1 description") + optional_arg_2: Optional[str] = Field(description="Arg2 description", default="2") + arg_3: int = Field(description="Arg3 description") + + +test_tool_as_dict = { + "title": "test_tool", + "description": "test_tool description", + "properties": { + "arg_1": {"description": "Arg1 description", "type": "string"}, + "optional_arg_2": { + "description": "Arg2 description", + "type": "string", + "default": "2", + }, + "arg_3": {"description": "Arg3 description", "type": "integer"}, + }, +} + + +@pytest.mark.parametrize( + "tool", + [ + pytest.param(_TestTool(), id="tool from BaseTool"), + pytest.param(test_tool, id="BaseModel"), + pytest.param(test_tool_as_dict, id="JSON schema dict"), + ], +) +def test_format_to_cohere_tools( + tool: Union[Dict[str, Any], BaseTool, Type[BaseModel]], +) -> None: + actual = _format_to_cohere_tools([tool]) + + assert [expected_test_tool_definition] == actual diff --git a/libs/partners/cohere/tests/unit_tests/test_imports.py b/libs/partners/cohere/tests/unit_tests/test_imports.py index ceff62f104..0159c19e94 100644 --- a/libs/partners/cohere/tests/unit_tests/test_imports.py +++ b/libs/partners/cohere/tests/unit_tests/test_imports.py @@ -6,6 +6,7 @@ EXPECTED_ALL = [ "CohereEmbeddings", "CohereRagRetriever", "CohereRerank", + "create_cohere_tools_agent", ]