mirror of
https://github.com/hwchase17/langchain
synced 2024-11-08 07:10:35 +00:00
cohere[patch]: Add cohere tools agent (#19602)
**Description**: Adds a cohere tools agent and related notebook. --------- Co-authored-by: BeatrixCohere <128378696+BeatrixCohere@users.noreply.github.com> Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
parent
5c41f4083e
commit
3685f8ceac
@ -1,4 +1,5 @@
|
|||||||
from langchain_cohere.chat_models import ChatCohere
|
from langchain_cohere.chat_models import ChatCohere
|
||||||
|
from langchain_cohere.cohere_agent import create_cohere_tools_agent
|
||||||
from langchain_cohere.embeddings import CohereEmbeddings
|
from langchain_cohere.embeddings import CohereEmbeddings
|
||||||
from langchain_cohere.rag_retrievers import CohereRagRetriever
|
from langchain_cohere.rag_retrievers import CohereRagRetriever
|
||||||
from langchain_cohere.rerank import CohereRerank
|
from langchain_cohere.rerank import CohereRerank
|
||||||
@ -9,4 +10,5 @@ __all__ = [
|
|||||||
"CohereEmbeddings",
|
"CohereEmbeddings",
|
||||||
"CohereRagRetriever",
|
"CohereRagRetriever",
|
||||||
"CohereRerank",
|
"CohereRerank",
|
||||||
|
"create_cohere_tools_agent",
|
||||||
]
|
]
|
||||||
|
@ -1,9 +1,22 @@
|
|||||||
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional
|
import json
|
||||||
|
from typing import (
|
||||||
|
Any,
|
||||||
|
AsyncIterator,
|
||||||
|
Dict,
|
||||||
|
Iterator,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
Type,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
|
from cohere.types import NonStreamedChatResponse, ToolCall
|
||||||
from langchain_core.callbacks import (
|
from langchain_core.callbacks import (
|
||||||
AsyncCallbackManagerForLLMRun,
|
AsyncCallbackManagerForLLMRun,
|
||||||
CallbackManagerForLLMRun,
|
CallbackManagerForLLMRun,
|
||||||
)
|
)
|
||||||
|
from langchain_core.language_models import LanguageModelInput
|
||||||
from langchain_core.language_models.chat_models import (
|
from langchain_core.language_models.chat_models import (
|
||||||
BaseChatModel,
|
BaseChatModel,
|
||||||
agenerate_from_stream,
|
agenerate_from_stream,
|
||||||
@ -18,7 +31,11 @@ from langchain_core.messages import (
|
|||||||
SystemMessage,
|
SystemMessage,
|
||||||
)
|
)
|
||||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||||
|
from langchain_core.pydantic_v1 import BaseModel
|
||||||
|
from langchain_core.runnables import Runnable
|
||||||
|
from langchain_core.tools import BaseTool
|
||||||
|
|
||||||
|
from langchain_cohere.cohere_agent import _format_to_cohere_tools
|
||||||
from langchain_cohere.llms import BaseCohere
|
from langchain_cohere.llms import BaseCohere
|
||||||
|
|
||||||
|
|
||||||
@ -143,6 +160,14 @@ class ChatCohere(BaseChatModel, BaseCohere):
|
|||||||
}
|
}
|
||||||
return {k: v for k, v in base_params.items() if v is not None}
|
return {k: v for k, v in base_params.items() if v is not None}
|
||||||
|
|
||||||
|
def bind_tools(
|
||||||
|
self,
|
||||||
|
tools: Sequence[Union[Dict[str, Any], BaseTool, Type[BaseModel]]],
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Runnable[LanguageModelInput, BaseMessage]:
|
||||||
|
formatted_tools = _format_to_cohere_tools(tools)
|
||||||
|
return super().bind(tools=formatted_tools, **kwargs)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _identifying_params(self) -> Dict[str, Any]:
|
def _identifying_params(self) -> Dict[str, Any]:
|
||||||
"""Get the identifying parameters."""
|
"""Get the identifying parameters."""
|
||||||
@ -169,6 +194,14 @@ class ChatCohere(BaseChatModel, BaseCohere):
|
|||||||
if run_manager:
|
if run_manager:
|
||||||
run_manager.on_llm_new_token(delta, chunk=chunk)
|
run_manager.on_llm_new_token(delta, chunk=chunk)
|
||||||
yield chunk
|
yield chunk
|
||||||
|
elif data.event_type == "stream-end":
|
||||||
|
generation_info = self._get_generation_info(data.response)
|
||||||
|
yield ChatGenerationChunk(
|
||||||
|
message=AIMessageChunk(
|
||||||
|
content="", additional_kwargs=generation_info
|
||||||
|
),
|
||||||
|
generation_info=generation_info,
|
||||||
|
)
|
||||||
|
|
||||||
async def _astream(
|
async def _astream(
|
||||||
self,
|
self,
|
||||||
@ -191,16 +224,34 @@ class ChatCohere(BaseChatModel, BaseCohere):
|
|||||||
if run_manager:
|
if run_manager:
|
||||||
await run_manager.on_llm_new_token(delta, chunk=chunk)
|
await run_manager.on_llm_new_token(delta, chunk=chunk)
|
||||||
yield chunk
|
yield chunk
|
||||||
|
elif data.event_type == "stream-end":
|
||||||
|
generation_info = self._get_generation_info(data.response)
|
||||||
|
yield ChatGenerationChunk(
|
||||||
|
message=AIMessageChunk(
|
||||||
|
content="", additional_kwargs=generation_info
|
||||||
|
),
|
||||||
|
generation_info=generation_info,
|
||||||
|
)
|
||||||
|
|
||||||
def _get_generation_info(self, response: Any) -> Dict[str, Any]:
|
def _get_generation_info(self, response: NonStreamedChatResponse) -> Dict[str, Any]:
|
||||||
"""Get the generation info from cohere API response."""
|
"""Get the generation info from cohere API response."""
|
||||||
return {
|
generation_info = {
|
||||||
"documents": response.documents,
|
"documents": response.documents,
|
||||||
"citations": response.citations,
|
"citations": response.citations,
|
||||||
"search_results": response.search_results,
|
"search_results": response.search_results,
|
||||||
"search_queries": response.search_queries,
|
"search_queries": response.search_queries,
|
||||||
"token_count": response.token_count,
|
"is_search_required": response.is_search_required,
|
||||||
|
"generation_id": response.generation_id,
|
||||||
}
|
}
|
||||||
|
if response.tool_calls:
|
||||||
|
# Only populate tool_calls when 1) present on the response and
|
||||||
|
# 2) has one or more calls.
|
||||||
|
generation_info["tool_calls"] = _format_cohere_tool_calls(
|
||||||
|
response.generation_id or "", response.tool_calls
|
||||||
|
)
|
||||||
|
if hasattr(response, "token_count"):
|
||||||
|
generation_info["token_count"] = response.token_count
|
||||||
|
return generation_info
|
||||||
|
|
||||||
def _generate(
|
def _generate(
|
||||||
self,
|
self,
|
||||||
@ -218,10 +269,8 @@ class ChatCohere(BaseChatModel, BaseCohere):
|
|||||||
request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
|
request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
|
||||||
response = self.client.chat(**request)
|
response = self.client.chat(**request)
|
||||||
|
|
||||||
message = AIMessage(content=response.text)
|
|
||||||
generation_info = None
|
|
||||||
if hasattr(response, "documents"):
|
|
||||||
generation_info = self._get_generation_info(response)
|
generation_info = self._get_generation_info(response)
|
||||||
|
message = AIMessage(content=response.text, additional_kwargs=generation_info)
|
||||||
return ChatResult(
|
return ChatResult(
|
||||||
generations=[
|
generations=[
|
||||||
ChatGeneration(message=message, generation_info=generation_info)
|
ChatGeneration(message=message, generation_info=generation_info)
|
||||||
@ -244,10 +293,8 @@ class ChatCohere(BaseChatModel, BaseCohere):
|
|||||||
request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
|
request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
|
||||||
response = self.client.chat(**request)
|
response = self.client.chat(**request)
|
||||||
|
|
||||||
message = AIMessage(content=response.text)
|
|
||||||
generation_info = None
|
|
||||||
if hasattr(response, "documents"):
|
|
||||||
generation_info = self._get_generation_info(response)
|
generation_info = self._get_generation_info(response)
|
||||||
|
message = AIMessage(content=response.text, additional_kwargs=generation_info)
|
||||||
return ChatResult(
|
return ChatResult(
|
||||||
generations=[
|
generations=[
|
||||||
ChatGeneration(message=message, generation_info=generation_info)
|
ChatGeneration(message=message, generation_info=generation_info)
|
||||||
@ -257,3 +304,27 @@ class ChatCohere(BaseChatModel, BaseCohere):
|
|||||||
def get_num_tokens(self, text: str) -> int:
|
def get_num_tokens(self, text: str) -> int:
|
||||||
"""Calculate number of tokens."""
|
"""Calculate number of tokens."""
|
||||||
return len(self.client.tokenize(text).tokens)
|
return len(self.client.tokenize(text).tokens)
|
||||||
|
|
||||||
|
|
||||||
|
def _format_cohere_tool_calls(
|
||||||
|
generation_id: str, tool_calls: Optional[List[ToolCall]] = None
|
||||||
|
) -> List[Dict]:
|
||||||
|
"""
|
||||||
|
Formats a Cohere API response into the tool call format used elsewhere in Langchain.
|
||||||
|
"""
|
||||||
|
if not tool_calls:
|
||||||
|
return []
|
||||||
|
|
||||||
|
formatted_tool_calls = []
|
||||||
|
for tool_call in tool_calls:
|
||||||
|
formatted_tool_calls.append(
|
||||||
|
{
|
||||||
|
"id": generation_id,
|
||||||
|
"function": {
|
||||||
|
"name": tool_call.name,
|
||||||
|
"arguments": json.dumps(tool_call.parameters),
|
||||||
|
},
|
||||||
|
"type": "function",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return formatted_tool_calls
|
||||||
|
168
libs/partners/cohere/langchain_cohere/cohere_agent.py
Normal file
168
libs/partners/cohere/langchain_cohere/cohere_agent.py
Normal file
@ -0,0 +1,168 @@
|
|||||||
|
from typing import Any, Dict, List, Sequence, Tuple, Type, Union
|
||||||
|
|
||||||
|
from cohere.types import Tool, ToolParameterDefinitionsValue
|
||||||
|
from langchain_core.agents import AgentAction, AgentFinish
|
||||||
|
from langchain_core.language_models import BaseLanguageModel
|
||||||
|
from langchain_core.output_parsers import BaseOutputParser
|
||||||
|
from langchain_core.outputs import Generation
|
||||||
|
from langchain_core.outputs.chat_generation import ChatGeneration
|
||||||
|
from langchain_core.prompts.chat import ChatPromptTemplate
|
||||||
|
from langchain_core.pydantic_v1 import BaseModel
|
||||||
|
from langchain_core.runnables import Runnable, RunnablePassthrough
|
||||||
|
from langchain_core.runnables.base import RunnableLambda
|
||||||
|
from langchain_core.tools import BaseTool
|
||||||
|
from langchain_core.utils.function_calling import convert_to_openai_function
|
||||||
|
|
||||||
|
|
||||||
|
def create_cohere_tools_agent(
|
||||||
|
llm: BaseLanguageModel, tools: Sequence[BaseTool], prompt: ChatPromptTemplate
|
||||||
|
) -> Runnable:
|
||||||
|
def llm_with_tools(input_: Dict) -> Runnable:
|
||||||
|
tool_results = (
|
||||||
|
input_["tool_results"] if len(input_["tool_results"]) > 0 else None
|
||||||
|
)
|
||||||
|
tools_ = input_["tools"] if len(input_["tools"]) > 0 else None
|
||||||
|
return RunnableLambda(lambda x: x["input"]) | llm.bind(
|
||||||
|
tools=tools_, tool_results=tool_results
|
||||||
|
)
|
||||||
|
|
||||||
|
agent = (
|
||||||
|
RunnablePassthrough.assign(
|
||||||
|
# Intermediate steps are in tool results.
|
||||||
|
# Edit below to change the prompt parameters.
|
||||||
|
input=lambda x: prompt.format_messages(
|
||||||
|
input=x["input"], agent_scratchpad=[]
|
||||||
|
),
|
||||||
|
tools=lambda x: _format_to_cohere_tools(tools),
|
||||||
|
tool_results=lambda x: _format_to_cohere_tools_messages(
|
||||||
|
x["intermediate_steps"]
|
||||||
|
),
|
||||||
|
)
|
||||||
|
| llm_with_tools
|
||||||
|
| _CohereToolsAgentOutputParser()
|
||||||
|
)
|
||||||
|
return agent
|
||||||
|
|
||||||
|
|
||||||
|
def _format_to_cohere_tools(
|
||||||
|
tools: Sequence[Union[Dict[str, Any], BaseTool, Type[BaseModel]]],
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
return [_convert_to_cohere_tool(tool) for tool in tools]
|
||||||
|
|
||||||
|
|
||||||
|
def _format_to_cohere_tools_messages(
|
||||||
|
intermediate_steps: Sequence[Tuple[AgentAction, str]],
|
||||||
|
) -> list:
|
||||||
|
"""Convert (AgentAction, tool output) tuples into tool messages."""
|
||||||
|
if len(intermediate_steps) == 0:
|
||||||
|
return []
|
||||||
|
tool_results = []
|
||||||
|
for agent_action, observation in intermediate_steps:
|
||||||
|
tool_results.append(
|
||||||
|
{
|
||||||
|
"call": {
|
||||||
|
"name": agent_action.tool,
|
||||||
|
"parameters": agent_action.tool_input,
|
||||||
|
},
|
||||||
|
"outputs": [{"answer": observation}],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return tool_results
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_to_cohere_tool(
|
||||||
|
tool: Union[Dict[str, Any], BaseTool, Type[BaseModel]],
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Convert a BaseTool instance, JSON schema dict, or BaseModel type to a Cohere tool.
|
||||||
|
"""
|
||||||
|
if isinstance(tool, BaseTool):
|
||||||
|
return Tool(
|
||||||
|
name=tool.name,
|
||||||
|
description=tool.description,
|
||||||
|
parameter_definitions={
|
||||||
|
param_name: ToolParameterDefinitionsValue(
|
||||||
|
description=param_definition.get("description"),
|
||||||
|
type=param_definition.get("type"),
|
||||||
|
required="default" not in param_definition,
|
||||||
|
)
|
||||||
|
for param_name, param_definition in tool.args.items()
|
||||||
|
},
|
||||||
|
).dict()
|
||||||
|
elif isinstance(tool, dict):
|
||||||
|
if not all(k in tool for k in ("title", "description", "properties")):
|
||||||
|
raise ValueError(
|
||||||
|
"Unsupported dict type. Tool must be passed in as a BaseTool instance, JSON schema dict, or BaseModel type." # noqa: E501
|
||||||
|
)
|
||||||
|
return Tool(
|
||||||
|
name=tool.get("title"),
|
||||||
|
description=tool.get("description"),
|
||||||
|
parameter_definitions={
|
||||||
|
param_name: ToolParameterDefinitionsValue(
|
||||||
|
description=param_definition.get("description"),
|
||||||
|
type=param_definition.get("type"),
|
||||||
|
required="default" not in param_definition,
|
||||||
|
)
|
||||||
|
for param_name, param_definition in tool.get("properties", {}).items()
|
||||||
|
},
|
||||||
|
).dict()
|
||||||
|
elif issubclass(tool, BaseModel):
|
||||||
|
as_json_schema_function = convert_to_openai_function(tool)
|
||||||
|
parameters = as_json_schema_function.get("parameters", {})
|
||||||
|
properties = parameters.get("properties", {})
|
||||||
|
return Tool(
|
||||||
|
name=as_json_schema_function.get("name"),
|
||||||
|
description=as_json_schema_function.get(
|
||||||
|
# The Cohere API requires the description field.
|
||||||
|
"description",
|
||||||
|
as_json_schema_function.get("name"),
|
||||||
|
),
|
||||||
|
parameter_definitions={
|
||||||
|
param_name: ToolParameterDefinitionsValue(
|
||||||
|
description=param_definition.get("description"),
|
||||||
|
type=param_definition.get("type"),
|
||||||
|
required=param_name in parameters.get("required", []),
|
||||||
|
)
|
||||||
|
for param_name, param_definition in properties.items()
|
||||||
|
},
|
||||||
|
).dict()
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported tool type {type(tool)}. Tool must be passed in as a BaseTool instance, JSON schema dict, or BaseModel type." # noqa: E501
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class _CohereToolsAgentOutputParser(
|
||||||
|
BaseOutputParser[Union[List[AgentAction], AgentFinish]]
|
||||||
|
):
|
||||||
|
"""Parses a message into agent actions/finish."""
|
||||||
|
|
||||||
|
def parse_result(
|
||||||
|
self, result: List[Generation], *, partial: bool = False
|
||||||
|
) -> Union[List[AgentAction], AgentFinish]:
|
||||||
|
if not isinstance(result[0], ChatGeneration):
|
||||||
|
raise ValueError(f"Expected ChatGeneration, got {type(result)}")
|
||||||
|
if result[0].message.additional_kwargs["tool_calls"]:
|
||||||
|
actions = []
|
||||||
|
for tool in result[0].message.additional_kwargs["tool_calls"]:
|
||||||
|
function = tool.get("function", {})
|
||||||
|
actions.append(
|
||||||
|
AgentAction(
|
||||||
|
tool=function.get("name"),
|
||||||
|
tool_input=function.get("arguments"),
|
||||||
|
log=function.get("name"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return actions
|
||||||
|
else:
|
||||||
|
return AgentFinish(
|
||||||
|
return_values={
|
||||||
|
"text": result[0].message.content,
|
||||||
|
"additional_info": result[0].message.additional_kwargs,
|
||||||
|
},
|
||||||
|
log="",
|
||||||
|
)
|
||||||
|
|
||||||
|
def parse(self, text: str) -> Union[List[AgentAction], AgentFinish]:
|
||||||
|
raise ValueError("Can only parse messages")
|
@ -20,6 +20,7 @@ def _get_docs(response: Any) -> List[Document]:
|
|||||||
docs = (
|
docs = (
|
||||||
[]
|
[]
|
||||||
if "documents" not in response.generation_info
|
if "documents" not in response.generation_info
|
||||||
|
or len(response.generation_info["documents"]) == 0
|
||||||
else [
|
else [
|
||||||
Document(page_content=doc["snippet"], metadata=doc)
|
Document(page_content=doc["snippet"], metadata=doc)
|
||||||
for doc in response.generation_info["documents"]
|
for doc in response.generation_info["documents"]
|
||||||
|
9
libs/partners/cohere/poetry.lock
generated
9
libs/partners/cohere/poetry.lock
generated
@ -165,13 +165,13 @@ types = ["chardet (>=5.1.0)", "mypy", "pytest", "pytest-cov", "pytest-dependency
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "cohere"
|
name = "cohere"
|
||||||
version = "5.1.2"
|
version = "5.1.4"
|
||||||
description = ""
|
description = ""
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = "<4.0,>=3.8"
|
python-versions = "<4.0,>=3.8"
|
||||||
files = [
|
files = [
|
||||||
{file = "cohere-5.1.2-py3-none-any.whl", hash = "sha256:7782e32cba671fc04203c3b56a9ce1b70e9459d7c983e8576b04d394fbe809f5"},
|
{file = "cohere-5.1.4-py3-none-any.whl", hash = "sha256:b88c44dfa44301f55f509db120582a6127c2e391c6c43a4dc58767f4df056a9d"},
|
||||||
{file = "cohere-5.1.2.tar.gz", hash = "sha256:21af5ed6edcf939062c41240040316084cd7e753cf3207f661f68abb4bbbe846"},
|
{file = "cohere-5.1.4.tar.gz", hash = "sha256:81b45fe37df2d62aaf57094402cb62b5fed285c25667dab96023f2ad2591ff35"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
@ -742,7 +742,6 @@ files = [
|
|||||||
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
|
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
|
||||||
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
|
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
|
||||||
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
|
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
|
||||||
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
|
|
||||||
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
|
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
|
||||||
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
|
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
|
||||||
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
|
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
|
||||||
@ -957,4 +956,4 @@ watchmedo = ["PyYAML (>=3.10)"]
|
|||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = ">=3.8.1,<4.0"
|
python-versions = ">=3.8.1,<4.0"
|
||||||
content-hash = "7ed2d31c084d528c64eb959df1a2ea29345a70117e9d29f322607fe247804cc5"
|
content-hash = "6a5887a0391a649e1a45f3e3c766a880e133367d2656a9b5a37d75ebc33adef6"
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "langchain-cohere"
|
name = "langchain-cohere"
|
||||||
version = "0.1.0rc1"
|
version = "0.1.0rc2"
|
||||||
description = "An integration package connecting Cohere and LangChain"
|
description = "An integration package connecting Cohere and LangChain"
|
||||||
authors = []
|
authors = []
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
@ -13,7 +13,7 @@ license = "MIT"
|
|||||||
[tool.poetry.dependencies]
|
[tool.poetry.dependencies]
|
||||||
python = ">=3.8.1,<4.0"
|
python = ">=3.8.1,<4.0"
|
||||||
langchain-core = "^0.1.32"
|
langchain-core = "^0.1.32"
|
||||||
cohere = "^5.1.1"
|
cohere = "^5.1.4"
|
||||||
|
|
||||||
[tool.poetry.group.test]
|
[tool.poetry.group.test]
|
||||||
optional = true
|
optional = true
|
||||||
|
@ -1,4 +1,12 @@
|
|||||||
"""Test ChatCohere chat model."""
|
"""Test ChatCohere chat model."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from langchain_core.messages import AIMessage, AIMessageChunk
|
||||||
|
from langchain_core.pydantic_v1 import BaseModel
|
||||||
|
|
||||||
from langchain_cohere import ChatCohere
|
from langchain_cohere import ChatCohere
|
||||||
|
|
||||||
|
|
||||||
@ -61,3 +69,76 @@ def test_invoke() -> None:
|
|||||||
|
|
||||||
result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
|
result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
|
||||||
assert isinstance(result.content, str)
|
assert isinstance(result.content, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_invoke_tool_calls() -> None:
|
||||||
|
llm = ChatCohere(temperature=0)
|
||||||
|
|
||||||
|
class Person(BaseModel):
|
||||||
|
name: str
|
||||||
|
age: int
|
||||||
|
|
||||||
|
tool_llm = llm.bind_tools([Person])
|
||||||
|
|
||||||
|
# where it calls the tool
|
||||||
|
result = tool_llm.invoke("Erick, 27 years old")
|
||||||
|
|
||||||
|
assert isinstance(result, AIMessage)
|
||||||
|
additional_kwargs = result.additional_kwargs
|
||||||
|
assert "tool_calls" in additional_kwargs
|
||||||
|
assert len(additional_kwargs["tool_calls"]) == 1
|
||||||
|
assert additional_kwargs["tool_calls"][0]["function"]["name"] == "Person"
|
||||||
|
assert json.loads(additional_kwargs["tool_calls"][0]["function"]["arguments"]) == {
|
||||||
|
"name": "Erick",
|
||||||
|
"age": 27,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_streaming_tool_call() -> None:
|
||||||
|
llm = ChatCohere(temperature=0)
|
||||||
|
|
||||||
|
class Person(BaseModel):
|
||||||
|
name: str
|
||||||
|
age: int
|
||||||
|
|
||||||
|
tool_llm = llm.bind_tools([Person])
|
||||||
|
|
||||||
|
# where it calls the tool
|
||||||
|
strm = tool_llm.stream("Erick, 27 years old")
|
||||||
|
|
||||||
|
additional_kwargs = None
|
||||||
|
for chunk in strm:
|
||||||
|
assert isinstance(chunk, AIMessageChunk)
|
||||||
|
assert chunk.content == ""
|
||||||
|
additional_kwargs = chunk.additional_kwargs
|
||||||
|
|
||||||
|
assert additional_kwargs is not None
|
||||||
|
assert "tool_calls" in additional_kwargs
|
||||||
|
assert len(additional_kwargs["tool_calls"]) == 1
|
||||||
|
assert additional_kwargs["tool_calls"][0]["function"]["name"] == "Person"
|
||||||
|
assert json.loads(additional_kwargs["tool_calls"][0]["function"]["arguments"]) == {
|
||||||
|
"name": "Erick",
|
||||||
|
"age": 27,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.xfail(
|
||||||
|
reason="Cohere models return empty output when a tool is passed in but not called."
|
||||||
|
)
|
||||||
|
def test_streaming_tool_call_no_tool_calls() -> None:
|
||||||
|
llm = ChatCohere(temperature=0)
|
||||||
|
|
||||||
|
class Person(BaseModel):
|
||||||
|
name: str
|
||||||
|
age: int
|
||||||
|
|
||||||
|
tool_llm = llm.bind_tools([Person])
|
||||||
|
|
||||||
|
# where it doesn't call the tool
|
||||||
|
strm = tool_llm.stream("What is 2+2?")
|
||||||
|
acc: Any = None
|
||||||
|
for chunk in strm:
|
||||||
|
assert isinstance(chunk, AIMessageChunk)
|
||||||
|
acc = chunk if acc is None else acc + chunk
|
||||||
|
assert acc.content != ""
|
||||||
|
assert "tool_calls" not in acc.additional_kwargs
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
import typing
|
import typing
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from cohere.types import NonStreamedChatResponse, ToolCall
|
||||||
|
|
||||||
from langchain_cohere.chat_models import ChatCohere
|
from langchain_cohere.chat_models import ChatCohere
|
||||||
|
|
||||||
@ -28,3 +29,87 @@ def test_initialization() -> None:
|
|||||||
def test_default_params(chat_cohere: ChatCohere, expected: typing.Dict) -> None:
|
def test_default_params(chat_cohere: ChatCohere, expected: typing.Dict) -> None:
|
||||||
actual = chat_cohere._default_params
|
actual = chat_cohere._default_params
|
||||||
assert expected == actual
|
assert expected == actual
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"response, expected",
|
||||||
|
[
|
||||||
|
pytest.param(
|
||||||
|
NonStreamedChatResponse(
|
||||||
|
generation_id="foo",
|
||||||
|
text="",
|
||||||
|
tool_calls=[
|
||||||
|
ToolCall(name="tool1", parameters={"arg1": 1, "arg2": "2"}),
|
||||||
|
ToolCall(name="tool2", parameters={"arg3": 3, "arg4": "4"}),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
{
|
||||||
|
"documents": None,
|
||||||
|
"citations": None,
|
||||||
|
"search_results": None,
|
||||||
|
"search_queries": None,
|
||||||
|
"is_search_required": None,
|
||||||
|
"generation_id": "foo",
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"id": "foo",
|
||||||
|
"function": {
|
||||||
|
"name": "tool1",
|
||||||
|
"arguments": '{"arg1": 1, "arg2": "2"}',
|
||||||
|
},
|
||||||
|
"type": "function",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "foo",
|
||||||
|
"function": {
|
||||||
|
"name": "tool2",
|
||||||
|
"arguments": '{"arg3": 3, "arg4": "4"}',
|
||||||
|
},
|
||||||
|
"type": "function",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
id="tools should be called",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
NonStreamedChatResponse(
|
||||||
|
generation_id="foo",
|
||||||
|
text="",
|
||||||
|
tool_calls=[],
|
||||||
|
),
|
||||||
|
{
|
||||||
|
"documents": None,
|
||||||
|
"citations": None,
|
||||||
|
"search_results": None,
|
||||||
|
"search_queries": None,
|
||||||
|
"is_search_required": None,
|
||||||
|
"generation_id": "foo",
|
||||||
|
},
|
||||||
|
id="no tools should be called",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
NonStreamedChatResponse(
|
||||||
|
generation_id="foo",
|
||||||
|
text="bar",
|
||||||
|
tool_calls=[],
|
||||||
|
),
|
||||||
|
{
|
||||||
|
"documents": None,
|
||||||
|
"citations": None,
|
||||||
|
"search_results": None,
|
||||||
|
"search_queries": None,
|
||||||
|
"is_search_required": None,
|
||||||
|
"generation_id": "foo",
|
||||||
|
},
|
||||||
|
id="chat response without tools/documents/citations/tools etc",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_get_generation_info(
|
||||||
|
response: typing.Any, expected: typing.Dict[str, typing.Any]
|
||||||
|
) -> None:
|
||||||
|
chat_cohere = ChatCohere(cohere_api_key="test")
|
||||||
|
|
||||||
|
actual = chat_cohere._get_generation_info(response)
|
||||||
|
|
||||||
|
assert expected == actual
|
||||||
|
82
libs/partners/cohere/tests/unit_tests/test_cohere_agent.py
Normal file
82
libs/partners/cohere/tests/unit_tests/test_cohere_agent.py
Normal file
@ -0,0 +1,82 @@
|
|||||||
|
from typing import Any, Dict, Optional, Type, Union
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from langchain_core.tools import BaseModel, BaseTool, Field
|
||||||
|
|
||||||
|
from langchain_cohere.cohere_agent import _format_to_cohere_tools
|
||||||
|
|
||||||
|
expected_test_tool_definition = {
|
||||||
|
"description": "test_tool description",
|
||||||
|
"name": "test_tool",
|
||||||
|
"parameter_definitions": {
|
||||||
|
"arg_1": {
|
||||||
|
"description": "Arg1 description",
|
||||||
|
"required": True,
|
||||||
|
"type": "string",
|
||||||
|
},
|
||||||
|
"optional_arg_2": {
|
||||||
|
"description": "Arg2 description",
|
||||||
|
"required": False,
|
||||||
|
"type": "string",
|
||||||
|
},
|
||||||
|
"arg_3": {
|
||||||
|
"description": "Arg3 description",
|
||||||
|
"required": True,
|
||||||
|
"type": "integer",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class _TestToolSchema(BaseModel):
|
||||||
|
arg_1: str = Field(description="Arg1 description")
|
||||||
|
optional_arg_2: Optional[str] = Field(description="Arg2 description", default="2")
|
||||||
|
arg_3: int = Field(description="Arg3 description")
|
||||||
|
|
||||||
|
|
||||||
|
class _TestTool(BaseTool):
|
||||||
|
name = "test_tool"
|
||||||
|
description = "test_tool description"
|
||||||
|
args_schema: Type[_TestToolSchema] = _TestToolSchema
|
||||||
|
|
||||||
|
def _run(self, *args: Any, **kwargs: Any) -> Any:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class test_tool(BaseModel):
|
||||||
|
"""test_tool description"""
|
||||||
|
|
||||||
|
arg_1: str = Field(description="Arg1 description")
|
||||||
|
optional_arg_2: Optional[str] = Field(description="Arg2 description", default="2")
|
||||||
|
arg_3: int = Field(description="Arg3 description")
|
||||||
|
|
||||||
|
|
||||||
|
test_tool_as_dict = {
|
||||||
|
"title": "test_tool",
|
||||||
|
"description": "test_tool description",
|
||||||
|
"properties": {
|
||||||
|
"arg_1": {"description": "Arg1 description", "type": "string"},
|
||||||
|
"optional_arg_2": {
|
||||||
|
"description": "Arg2 description",
|
||||||
|
"type": "string",
|
||||||
|
"default": "2",
|
||||||
|
},
|
||||||
|
"arg_3": {"description": "Arg3 description", "type": "integer"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"tool",
|
||||||
|
[
|
||||||
|
pytest.param(_TestTool(), id="tool from BaseTool"),
|
||||||
|
pytest.param(test_tool, id="BaseModel"),
|
||||||
|
pytest.param(test_tool_as_dict, id="JSON schema dict"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_format_to_cohere_tools(
|
||||||
|
tool: Union[Dict[str, Any], BaseTool, Type[BaseModel]],
|
||||||
|
) -> None:
|
||||||
|
actual = _format_to_cohere_tools([tool])
|
||||||
|
|
||||||
|
assert [expected_test_tool_definition] == actual
|
@ -6,6 +6,7 @@ EXPECTED_ALL = [
|
|||||||
"CohereEmbeddings",
|
"CohereEmbeddings",
|
||||||
"CohereRagRetriever",
|
"CohereRagRetriever",
|
||||||
"CohereRerank",
|
"CohereRerank",
|
||||||
|
"create_cohere_tools_agent",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user