mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
cohere[patch]: Add cohere tools agent (#19602)
**Description**: Adds a cohere tools agent and related notebook. --------- Co-authored-by: BeatrixCohere <128378696+BeatrixCohere@users.noreply.github.com> Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
parent
5c41f4083e
commit
3685f8ceac
@ -1,4 +1,5 @@
|
||||
from langchain_cohere.chat_models import ChatCohere
|
||||
from langchain_cohere.cohere_agent import create_cohere_tools_agent
|
||||
from langchain_cohere.embeddings import CohereEmbeddings
|
||||
from langchain_cohere.rag_retrievers import CohereRagRetriever
|
||||
from langchain_cohere.rerank import CohereRerank
|
||||
@ -9,4 +10,5 @@ __all__ = [
|
||||
"CohereEmbeddings",
|
||||
"CohereRagRetriever",
|
||||
"CohereRerank",
|
||||
"create_cohere_tools_agent",
|
||||
]
|
||||
|
@ -1,9 +1,22 @@
|
||||
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional
|
||||
import json
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
from cohere.types import NonStreamedChatResponse, ToolCall
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models import LanguageModelInput
|
||||
from langchain_core.language_models.chat_models import (
|
||||
BaseChatModel,
|
||||
agenerate_from_stream,
|
||||
@ -18,7 +31,11 @@ from langchain_core.messages import (
|
||||
SystemMessage,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.runnables import Runnable
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from langchain_cohere.cohere_agent import _format_to_cohere_tools
|
||||
from langchain_cohere.llms import BaseCohere
|
||||
|
||||
|
||||
@ -143,6 +160,14 @@ class ChatCohere(BaseChatModel, BaseCohere):
|
||||
}
|
||||
return {k: v for k, v in base_params.items() if v is not None}
|
||||
|
||||
def bind_tools(
|
||||
self,
|
||||
tools: Sequence[Union[Dict[str, Any], BaseTool, Type[BaseModel]]],
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, BaseMessage]:
|
||||
formatted_tools = _format_to_cohere_tools(tools)
|
||||
return super().bind(tools=formatted_tools, **kwargs)
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Dict[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
@ -169,6 +194,14 @@ class ChatCohere(BaseChatModel, BaseCohere):
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(delta, chunk=chunk)
|
||||
yield chunk
|
||||
elif data.event_type == "stream-end":
|
||||
generation_info = self._get_generation_info(data.response)
|
||||
yield ChatGenerationChunk(
|
||||
message=AIMessageChunk(
|
||||
content="", additional_kwargs=generation_info
|
||||
),
|
||||
generation_info=generation_info,
|
||||
)
|
||||
|
||||
async def _astream(
|
||||
self,
|
||||
@ -191,16 +224,34 @@ class ChatCohere(BaseChatModel, BaseCohere):
|
||||
if run_manager:
|
||||
await run_manager.on_llm_new_token(delta, chunk=chunk)
|
||||
yield chunk
|
||||
elif data.event_type == "stream-end":
|
||||
generation_info = self._get_generation_info(data.response)
|
||||
yield ChatGenerationChunk(
|
||||
message=AIMessageChunk(
|
||||
content="", additional_kwargs=generation_info
|
||||
),
|
||||
generation_info=generation_info,
|
||||
)
|
||||
|
||||
def _get_generation_info(self, response: Any) -> Dict[str, Any]:
|
||||
def _get_generation_info(self, response: NonStreamedChatResponse) -> Dict[str, Any]:
|
||||
"""Get the generation info from cohere API response."""
|
||||
return {
|
||||
generation_info = {
|
||||
"documents": response.documents,
|
||||
"citations": response.citations,
|
||||
"search_results": response.search_results,
|
||||
"search_queries": response.search_queries,
|
||||
"token_count": response.token_count,
|
||||
"is_search_required": response.is_search_required,
|
||||
"generation_id": response.generation_id,
|
||||
}
|
||||
if response.tool_calls:
|
||||
# Only populate tool_calls when 1) present on the response and
|
||||
# 2) has one or more calls.
|
||||
generation_info["tool_calls"] = _format_cohere_tool_calls(
|
||||
response.generation_id or "", response.tool_calls
|
||||
)
|
||||
if hasattr(response, "token_count"):
|
||||
generation_info["token_count"] = response.token_count
|
||||
return generation_info
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
@ -218,10 +269,8 @@ class ChatCohere(BaseChatModel, BaseCohere):
|
||||
request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
|
||||
response = self.client.chat(**request)
|
||||
|
||||
message = AIMessage(content=response.text)
|
||||
generation_info = None
|
||||
if hasattr(response, "documents"):
|
||||
generation_info = self._get_generation_info(response)
|
||||
message = AIMessage(content=response.text, additional_kwargs=generation_info)
|
||||
return ChatResult(
|
||||
generations=[
|
||||
ChatGeneration(message=message, generation_info=generation_info)
|
||||
@ -244,10 +293,8 @@ class ChatCohere(BaseChatModel, BaseCohere):
|
||||
request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
|
||||
response = self.client.chat(**request)
|
||||
|
||||
message = AIMessage(content=response.text)
|
||||
generation_info = None
|
||||
if hasattr(response, "documents"):
|
||||
generation_info = self._get_generation_info(response)
|
||||
message = AIMessage(content=response.text, additional_kwargs=generation_info)
|
||||
return ChatResult(
|
||||
generations=[
|
||||
ChatGeneration(message=message, generation_info=generation_info)
|
||||
@ -257,3 +304,27 @@ class ChatCohere(BaseChatModel, BaseCohere):
|
||||
def get_num_tokens(self, text: str) -> int:
|
||||
"""Calculate number of tokens."""
|
||||
return len(self.client.tokenize(text).tokens)
|
||||
|
||||
|
||||
def _format_cohere_tool_calls(
|
||||
generation_id: str, tool_calls: Optional[List[ToolCall]] = None
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
Formats a Cohere API response into the tool call format used elsewhere in Langchain.
|
||||
"""
|
||||
if not tool_calls:
|
||||
return []
|
||||
|
||||
formatted_tool_calls = []
|
||||
for tool_call in tool_calls:
|
||||
formatted_tool_calls.append(
|
||||
{
|
||||
"id": generation_id,
|
||||
"function": {
|
||||
"name": tool_call.name,
|
||||
"arguments": json.dumps(tool_call.parameters),
|
||||
},
|
||||
"type": "function",
|
||||
}
|
||||
)
|
||||
return formatted_tool_calls
|
||||
|
168
libs/partners/cohere/langchain_cohere/cohere_agent.py
Normal file
168
libs/partners/cohere/langchain_cohere/cohere_agent.py
Normal file
@ -0,0 +1,168 @@
|
||||
from typing import Any, Dict, List, Sequence, Tuple, Type, Union
|
||||
|
||||
from cohere.types import Tool, ToolParameterDefinitionsValue
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.output_parsers import BaseOutputParser
|
||||
from langchain_core.outputs import Generation
|
||||
from langchain_core.outputs.chat_generation import ChatGeneration
|
||||
from langchain_core.prompts.chat import ChatPromptTemplate
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.runnables import Runnable, RunnablePassthrough
|
||||
from langchain_core.runnables.base import RunnableLambda
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils.function_calling import convert_to_openai_function
|
||||
|
||||
|
||||
def create_cohere_tools_agent(
|
||||
llm: BaseLanguageModel, tools: Sequence[BaseTool], prompt: ChatPromptTemplate
|
||||
) -> Runnable:
|
||||
def llm_with_tools(input_: Dict) -> Runnable:
|
||||
tool_results = (
|
||||
input_["tool_results"] if len(input_["tool_results"]) > 0 else None
|
||||
)
|
||||
tools_ = input_["tools"] if len(input_["tools"]) > 0 else None
|
||||
return RunnableLambda(lambda x: x["input"]) | llm.bind(
|
||||
tools=tools_, tool_results=tool_results
|
||||
)
|
||||
|
||||
agent = (
|
||||
RunnablePassthrough.assign(
|
||||
# Intermediate steps are in tool results.
|
||||
# Edit below to change the prompt parameters.
|
||||
input=lambda x: prompt.format_messages(
|
||||
input=x["input"], agent_scratchpad=[]
|
||||
),
|
||||
tools=lambda x: _format_to_cohere_tools(tools),
|
||||
tool_results=lambda x: _format_to_cohere_tools_messages(
|
||||
x["intermediate_steps"]
|
||||
),
|
||||
)
|
||||
| llm_with_tools
|
||||
| _CohereToolsAgentOutputParser()
|
||||
)
|
||||
return agent
|
||||
|
||||
|
||||
def _format_to_cohere_tools(
|
||||
tools: Sequence[Union[Dict[str, Any], BaseTool, Type[BaseModel]]],
|
||||
) -> List[Dict[str, Any]]:
|
||||
return [_convert_to_cohere_tool(tool) for tool in tools]
|
||||
|
||||
|
||||
def _format_to_cohere_tools_messages(
|
||||
intermediate_steps: Sequence[Tuple[AgentAction, str]],
|
||||
) -> list:
|
||||
"""Convert (AgentAction, tool output) tuples into tool messages."""
|
||||
if len(intermediate_steps) == 0:
|
||||
return []
|
||||
tool_results = []
|
||||
for agent_action, observation in intermediate_steps:
|
||||
tool_results.append(
|
||||
{
|
||||
"call": {
|
||||
"name": agent_action.tool,
|
||||
"parameters": agent_action.tool_input,
|
||||
},
|
||||
"outputs": [{"answer": observation}],
|
||||
}
|
||||
)
|
||||
|
||||
return tool_results
|
||||
|
||||
|
||||
def _convert_to_cohere_tool(
|
||||
tool: Union[Dict[str, Any], BaseTool, Type[BaseModel]],
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert a BaseTool instance, JSON schema dict, or BaseModel type to a Cohere tool.
|
||||
"""
|
||||
if isinstance(tool, BaseTool):
|
||||
return Tool(
|
||||
name=tool.name,
|
||||
description=tool.description,
|
||||
parameter_definitions={
|
||||
param_name: ToolParameterDefinitionsValue(
|
||||
description=param_definition.get("description"),
|
||||
type=param_definition.get("type"),
|
||||
required="default" not in param_definition,
|
||||
)
|
||||
for param_name, param_definition in tool.args.items()
|
||||
},
|
||||
).dict()
|
||||
elif isinstance(tool, dict):
|
||||
if not all(k in tool for k in ("title", "description", "properties")):
|
||||
raise ValueError(
|
||||
"Unsupported dict type. Tool must be passed in as a BaseTool instance, JSON schema dict, or BaseModel type." # noqa: E501
|
||||
)
|
||||
return Tool(
|
||||
name=tool.get("title"),
|
||||
description=tool.get("description"),
|
||||
parameter_definitions={
|
||||
param_name: ToolParameterDefinitionsValue(
|
||||
description=param_definition.get("description"),
|
||||
type=param_definition.get("type"),
|
||||
required="default" not in param_definition,
|
||||
)
|
||||
for param_name, param_definition in tool.get("properties", {}).items()
|
||||
},
|
||||
).dict()
|
||||
elif issubclass(tool, BaseModel):
|
||||
as_json_schema_function = convert_to_openai_function(tool)
|
||||
parameters = as_json_schema_function.get("parameters", {})
|
||||
properties = parameters.get("properties", {})
|
||||
return Tool(
|
||||
name=as_json_schema_function.get("name"),
|
||||
description=as_json_schema_function.get(
|
||||
# The Cohere API requires the description field.
|
||||
"description",
|
||||
as_json_schema_function.get("name"),
|
||||
),
|
||||
parameter_definitions={
|
||||
param_name: ToolParameterDefinitionsValue(
|
||||
description=param_definition.get("description"),
|
||||
type=param_definition.get("type"),
|
||||
required=param_name in parameters.get("required", []),
|
||||
)
|
||||
for param_name, param_definition in properties.items()
|
||||
},
|
||||
).dict()
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported tool type {type(tool)}. Tool must be passed in as a BaseTool instance, JSON schema dict, or BaseModel type." # noqa: E501
|
||||
)
|
||||
|
||||
|
||||
class _CohereToolsAgentOutputParser(
|
||||
BaseOutputParser[Union[List[AgentAction], AgentFinish]]
|
||||
):
|
||||
"""Parses a message into agent actions/finish."""
|
||||
|
||||
def parse_result(
|
||||
self, result: List[Generation], *, partial: bool = False
|
||||
) -> Union[List[AgentAction], AgentFinish]:
|
||||
if not isinstance(result[0], ChatGeneration):
|
||||
raise ValueError(f"Expected ChatGeneration, got {type(result)}")
|
||||
if result[0].message.additional_kwargs["tool_calls"]:
|
||||
actions = []
|
||||
for tool in result[0].message.additional_kwargs["tool_calls"]:
|
||||
function = tool.get("function", {})
|
||||
actions.append(
|
||||
AgentAction(
|
||||
tool=function.get("name"),
|
||||
tool_input=function.get("arguments"),
|
||||
log=function.get("name"),
|
||||
)
|
||||
)
|
||||
return actions
|
||||
else:
|
||||
return AgentFinish(
|
||||
return_values={
|
||||
"text": result[0].message.content,
|
||||
"additional_info": result[0].message.additional_kwargs,
|
||||
},
|
||||
log="",
|
||||
)
|
||||
|
||||
def parse(self, text: str) -> Union[List[AgentAction], AgentFinish]:
|
||||
raise ValueError("Can only parse messages")
|
@ -20,6 +20,7 @@ def _get_docs(response: Any) -> List[Document]:
|
||||
docs = (
|
||||
[]
|
||||
if "documents" not in response.generation_info
|
||||
or len(response.generation_info["documents"]) == 0
|
||||
else [
|
||||
Document(page_content=doc["snippet"], metadata=doc)
|
||||
for doc in response.generation_info["documents"]
|
||||
|
9
libs/partners/cohere/poetry.lock
generated
9
libs/partners/cohere/poetry.lock
generated
@ -165,13 +165,13 @@ types = ["chardet (>=5.1.0)", "mypy", "pytest", "pytest-cov", "pytest-dependency
|
||||
|
||||
[[package]]
|
||||
name = "cohere"
|
||||
version = "5.1.2"
|
||||
version = "5.1.4"
|
||||
description = ""
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.8"
|
||||
files = [
|
||||
{file = "cohere-5.1.2-py3-none-any.whl", hash = "sha256:7782e32cba671fc04203c3b56a9ce1b70e9459d7c983e8576b04d394fbe809f5"},
|
||||
{file = "cohere-5.1.2.tar.gz", hash = "sha256:21af5ed6edcf939062c41240040316084cd7e753cf3207f661f68abb4bbbe846"},
|
||||
{file = "cohere-5.1.4-py3-none-any.whl", hash = "sha256:b88c44dfa44301f55f509db120582a6127c2e391c6c43a4dc58767f4df056a9d"},
|
||||
{file = "cohere-5.1.4.tar.gz", hash = "sha256:81b45fe37df2d62aaf57094402cb62b5fed285c25667dab96023f2ad2591ff35"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@ -742,7 +742,6 @@ files = [
|
||||
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
|
||||
@ -957,4 +956,4 @@ watchmedo = ["PyYAML (>=3.10)"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.8.1,<4.0"
|
||||
content-hash = "7ed2d31c084d528c64eb959df1a2ea29345a70117e9d29f322607fe247804cc5"
|
||||
content-hash = "6a5887a0391a649e1a45f3e3c766a880e133367d2656a9b5a37d75ebc33adef6"
|
||||
|
@ -1,6 +1,6 @@
|
||||
[tool.poetry]
|
||||
name = "langchain-cohere"
|
||||
version = "0.1.0rc1"
|
||||
version = "0.1.0rc2"
|
||||
description = "An integration package connecting Cohere and LangChain"
|
||||
authors = []
|
||||
readme = "README.md"
|
||||
@ -13,7 +13,7 @@ license = "MIT"
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.8.1,<4.0"
|
||||
langchain-core = "^0.1.32"
|
||||
cohere = "^5.1.1"
|
||||
cohere = "^5.1.4"
|
||||
|
||||
[tool.poetry.group.test]
|
||||
optional = true
|
||||
|
@ -1,4 +1,12 @@
|
||||
"""Test ChatCohere chat model."""
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage, AIMessageChunk
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
|
||||
from langchain_cohere import ChatCohere
|
||||
|
||||
|
||||
@ -61,3 +69,76 @@ def test_invoke() -> None:
|
||||
|
||||
result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
|
||||
assert isinstance(result.content, str)
|
||||
|
||||
|
||||
def test_invoke_tool_calls() -> None:
|
||||
llm = ChatCohere(temperature=0)
|
||||
|
||||
class Person(BaseModel):
|
||||
name: str
|
||||
age: int
|
||||
|
||||
tool_llm = llm.bind_tools([Person])
|
||||
|
||||
# where it calls the tool
|
||||
result = tool_llm.invoke("Erick, 27 years old")
|
||||
|
||||
assert isinstance(result, AIMessage)
|
||||
additional_kwargs = result.additional_kwargs
|
||||
assert "tool_calls" in additional_kwargs
|
||||
assert len(additional_kwargs["tool_calls"]) == 1
|
||||
assert additional_kwargs["tool_calls"][0]["function"]["name"] == "Person"
|
||||
assert json.loads(additional_kwargs["tool_calls"][0]["function"]["arguments"]) == {
|
||||
"name": "Erick",
|
||||
"age": 27,
|
||||
}
|
||||
|
||||
|
||||
def test_streaming_tool_call() -> None:
|
||||
llm = ChatCohere(temperature=0)
|
||||
|
||||
class Person(BaseModel):
|
||||
name: str
|
||||
age: int
|
||||
|
||||
tool_llm = llm.bind_tools([Person])
|
||||
|
||||
# where it calls the tool
|
||||
strm = tool_llm.stream("Erick, 27 years old")
|
||||
|
||||
additional_kwargs = None
|
||||
for chunk in strm:
|
||||
assert isinstance(chunk, AIMessageChunk)
|
||||
assert chunk.content == ""
|
||||
additional_kwargs = chunk.additional_kwargs
|
||||
|
||||
assert additional_kwargs is not None
|
||||
assert "tool_calls" in additional_kwargs
|
||||
assert len(additional_kwargs["tool_calls"]) == 1
|
||||
assert additional_kwargs["tool_calls"][0]["function"]["name"] == "Person"
|
||||
assert json.loads(additional_kwargs["tool_calls"][0]["function"]["arguments"]) == {
|
||||
"name": "Erick",
|
||||
"age": 27,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Cohere models return empty output when a tool is passed in but not called."
|
||||
)
|
||||
def test_streaming_tool_call_no_tool_calls() -> None:
|
||||
llm = ChatCohere(temperature=0)
|
||||
|
||||
class Person(BaseModel):
|
||||
name: str
|
||||
age: int
|
||||
|
||||
tool_llm = llm.bind_tools([Person])
|
||||
|
||||
# where it doesn't call the tool
|
||||
strm = tool_llm.stream("What is 2+2?")
|
||||
acc: Any = None
|
||||
for chunk in strm:
|
||||
assert isinstance(chunk, AIMessageChunk)
|
||||
acc = chunk if acc is None else acc + chunk
|
||||
assert acc.content != ""
|
||||
assert "tool_calls" not in acc.additional_kwargs
|
||||
|
@ -2,6 +2,7 @@
|
||||
import typing
|
||||
|
||||
import pytest
|
||||
from cohere.types import NonStreamedChatResponse, ToolCall
|
||||
|
||||
from langchain_cohere.chat_models import ChatCohere
|
||||
|
||||
@ -28,3 +29,87 @@ def test_initialization() -> None:
|
||||
def test_default_params(chat_cohere: ChatCohere, expected: typing.Dict) -> None:
|
||||
actual = chat_cohere._default_params
|
||||
assert expected == actual
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"response, expected",
|
||||
[
|
||||
pytest.param(
|
||||
NonStreamedChatResponse(
|
||||
generation_id="foo",
|
||||
text="",
|
||||
tool_calls=[
|
||||
ToolCall(name="tool1", parameters={"arg1": 1, "arg2": "2"}),
|
||||
ToolCall(name="tool2", parameters={"arg3": 3, "arg4": "4"}),
|
||||
],
|
||||
),
|
||||
{
|
||||
"documents": None,
|
||||
"citations": None,
|
||||
"search_results": None,
|
||||
"search_queries": None,
|
||||
"is_search_required": None,
|
||||
"generation_id": "foo",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "foo",
|
||||
"function": {
|
||||
"name": "tool1",
|
||||
"arguments": '{"arg1": 1, "arg2": "2"}',
|
||||
},
|
||||
"type": "function",
|
||||
},
|
||||
{
|
||||
"id": "foo",
|
||||
"function": {
|
||||
"name": "tool2",
|
||||
"arguments": '{"arg3": 3, "arg4": "4"}',
|
||||
},
|
||||
"type": "function",
|
||||
},
|
||||
],
|
||||
},
|
||||
id="tools should be called",
|
||||
),
|
||||
pytest.param(
|
||||
NonStreamedChatResponse(
|
||||
generation_id="foo",
|
||||
text="",
|
||||
tool_calls=[],
|
||||
),
|
||||
{
|
||||
"documents": None,
|
||||
"citations": None,
|
||||
"search_results": None,
|
||||
"search_queries": None,
|
||||
"is_search_required": None,
|
||||
"generation_id": "foo",
|
||||
},
|
||||
id="no tools should be called",
|
||||
),
|
||||
pytest.param(
|
||||
NonStreamedChatResponse(
|
||||
generation_id="foo",
|
||||
text="bar",
|
||||
tool_calls=[],
|
||||
),
|
||||
{
|
||||
"documents": None,
|
||||
"citations": None,
|
||||
"search_results": None,
|
||||
"search_queries": None,
|
||||
"is_search_required": None,
|
||||
"generation_id": "foo",
|
||||
},
|
||||
id="chat response without tools/documents/citations/tools etc",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_get_generation_info(
|
||||
response: typing.Any, expected: typing.Dict[str, typing.Any]
|
||||
) -> None:
|
||||
chat_cohere = ChatCohere(cohere_api_key="test")
|
||||
|
||||
actual = chat_cohere._get_generation_info(response)
|
||||
|
||||
assert expected == actual
|
||||
|
82
libs/partners/cohere/tests/unit_tests/test_cohere_agent.py
Normal file
82
libs/partners/cohere/tests/unit_tests/test_cohere_agent.py
Normal file
@ -0,0 +1,82 @@
|
||||
from typing import Any, Dict, Optional, Type, Union
|
||||
|
||||
import pytest
|
||||
from langchain_core.tools import BaseModel, BaseTool, Field
|
||||
|
||||
from langchain_cohere.cohere_agent import _format_to_cohere_tools
|
||||
|
||||
expected_test_tool_definition = {
|
||||
"description": "test_tool description",
|
||||
"name": "test_tool",
|
||||
"parameter_definitions": {
|
||||
"arg_1": {
|
||||
"description": "Arg1 description",
|
||||
"required": True,
|
||||
"type": "string",
|
||||
},
|
||||
"optional_arg_2": {
|
||||
"description": "Arg2 description",
|
||||
"required": False,
|
||||
"type": "string",
|
||||
},
|
||||
"arg_3": {
|
||||
"description": "Arg3 description",
|
||||
"required": True,
|
||||
"type": "integer",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class _TestToolSchema(BaseModel):
|
||||
arg_1: str = Field(description="Arg1 description")
|
||||
optional_arg_2: Optional[str] = Field(description="Arg2 description", default="2")
|
||||
arg_3: int = Field(description="Arg3 description")
|
||||
|
||||
|
||||
class _TestTool(BaseTool):
|
||||
name = "test_tool"
|
||||
description = "test_tool description"
|
||||
args_schema: Type[_TestToolSchema] = _TestToolSchema
|
||||
|
||||
def _run(self, *args: Any, **kwargs: Any) -> Any:
|
||||
pass
|
||||
|
||||
|
||||
class test_tool(BaseModel):
|
||||
"""test_tool description"""
|
||||
|
||||
arg_1: str = Field(description="Arg1 description")
|
||||
optional_arg_2: Optional[str] = Field(description="Arg2 description", default="2")
|
||||
arg_3: int = Field(description="Arg3 description")
|
||||
|
||||
|
||||
test_tool_as_dict = {
|
||||
"title": "test_tool",
|
||||
"description": "test_tool description",
|
||||
"properties": {
|
||||
"arg_1": {"description": "Arg1 description", "type": "string"},
|
||||
"optional_arg_2": {
|
||||
"description": "Arg2 description",
|
||||
"type": "string",
|
||||
"default": "2",
|
||||
},
|
||||
"arg_3": {"description": "Arg3 description", "type": "integer"},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"tool",
|
||||
[
|
||||
pytest.param(_TestTool(), id="tool from BaseTool"),
|
||||
pytest.param(test_tool, id="BaseModel"),
|
||||
pytest.param(test_tool_as_dict, id="JSON schema dict"),
|
||||
],
|
||||
)
|
||||
def test_format_to_cohere_tools(
|
||||
tool: Union[Dict[str, Any], BaseTool, Type[BaseModel]],
|
||||
) -> None:
|
||||
actual = _format_to_cohere_tools([tool])
|
||||
|
||||
assert [expected_test_tool_definition] == actual
|
@ -6,6 +6,7 @@ EXPECTED_ALL = [
|
||||
"CohereEmbeddings",
|
||||
"CohereRagRetriever",
|
||||
"CohereRerank",
|
||||
"create_cohere_tools_agent",
|
||||
]
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user