cohere[patch]: Add additional kwargs support for Cohere SDK params (#19533)

* Adds support for `additional_kwargs` in `get_cohere_chat_request`
* This functionality passes in Cohere SDK specific parameters from
`BaseMessage` based classes to the API

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
pull/19585/head
Giannis 3 months ago committed by GitHub
parent 2763d8cbe5
commit 9ea2a9b0c1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -47,6 +47,7 @@ def get_role(message: BaseMessage) -> str:
def get_cohere_chat_request(
messages: List[BaseMessage],
*,
documents: Optional[List[Dict[str, str]]] = None,
connectors: Optional[List[Dict[str, str]]] = None,
**kwargs: Any,
) -> Dict[str, Any]:
@ -60,24 +61,33 @@ def get_cohere_chat_request(
Returns:
The request for the Cohere chat API.
"""
documents = (
None
if "source_documents" not in kwargs
else [
{
"snippet": doc.page_content,
"id": doc.metadata.get("id") or f"doc-{str(i)}",
}
for i, doc in enumerate(kwargs["source_documents"])
]
)
kwargs.pop("source_documents", None)
maybe_connectors = connectors if documents is None else None
additional_kwargs = messages[-1].additional_kwargs
# cohere SDK will fail loudly if both connectors and documents are provided
if (
len(additional_kwargs.get("documents", [])) > 0
and documents
and len(documents) > 0
):
raise ValueError(
"Received documents both as a keyword argument and as an prompt additional"
"keywword argument. Please choose only one option."
)
formatted_docs = [
{
"text": doc.page_content,
"id": doc.metadata.get("id") or f"doc-{str(i)}",
}
for i, doc in enumerate(additional_kwargs.get("documents", []))
] or documents
if not formatted_docs:
formatted_docs = None
# by enabling automatic prompt truncation, the probability of request failure is
# reduced with minimal impact on response quality
prompt_truncation = (
"AUTO" if documents is not None or connectors is not None else None
"AUTO" if formatted_docs is not None or connectors is not None else None
)
req = {
@ -85,8 +95,8 @@ def get_cohere_chat_request(
"chat_history": [
{"role": get_role(x), "message": x.content} for x in messages[:-1]
],
"documents": documents,
"connectors": maybe_connectors,
"documents": formatted_docs,
"connectors": connectors,
"prompt_truncation": prompt_truncation,
**kwargs,
}

@ -3,7 +3,7 @@ from langchain_cohere import ChatCohere
def test_stream() -> None:
"""Test streaming tokens from OpenAI."""
"""Test streaming tokens from ChatCohere."""
llm = ChatCohere()
for token in llm.stream("I'm Pickle Rick"):
@ -11,7 +11,7 @@ def test_stream() -> None:
async def test_astream() -> None:
"""Test streaming tokens from OpenAI."""
"""Test streaming tokens from ChatCohere."""
llm = ChatCohere()
async for token in llm.astream("I'm Pickle Rick"):

@ -0,0 +1,63 @@
"""Test ChatCohere chat model."""
from typing import Any, Dict, List
from langchain_core.documents import Document
from langchain_core.messages.human import HumanMessage
from langchain_core.prompts.chat import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import (
RunnablePassthrough,
RunnableSerializable,
)
from langchain_cohere import ChatCohere
def test_connectors() -> None:
    """Test that ChatCohere accepts the ``connectors`` parameter."""
    # Bind a single Cohere-managed connector and make a live call.
    web_search_connector = [{"id": "web-search"}]
    model = ChatCohere().bind(connectors=web_search_connector)
    response = model.invoke("Who directed dune two? reply with just the name.")
    assert isinstance(response.content, str)
def test_documents() -> None:
    """Test that ChatCohere accepts the ``documents`` parameter."""
    # Ground the model on one in-line document and verify it is echoed
    # back in the response metadata.
    grounding_docs = [{"text": "The sky is green."}]
    model = ChatCohere().bind(documents=grounding_docs)
    response = model.invoke("What color is the sky?")
    assert isinstance(response.content, str)
    assert len(response.response_metadata["documents"]) == 1
def test_documents_chain() -> None:
    """Test that documents flow through a chain into ChatCohere."""
    model = ChatCohere()

    def _fetch_documents(_: Any) -> List[Document]:
        # Retriever stand-in: always returns a single fixed document.
        return [Document(page_content="The sky is green.")]

    def _to_human_messages(payload: Dict[str, Any]) -> List[HumanMessage]:
        # Attach the retrieved documents to the message via additional_kwargs,
        # which ChatCohere forwards to the Cohere API.
        message = HumanMessage(
            content=payload["message"],
            additional_kwargs={
                "documents": payload.get("documents", None),
            },
        )
        return [message]

    prompt = ChatPromptTemplate.from_messages([MessagesPlaceholder("input_msgs")])
    chain: RunnableSerializable[Any, Any] = (
        {"message": RunnablePassthrough(), "documents": _fetch_documents}
        | RunnablePassthrough()
        | {"input_msgs": _to_human_messages}
        | prompt
        | model
    )
    response = chain.invoke("What color is the sky?")
    assert isinstance(response.content, str)
    assert len(response.response_metadata["documents"]) == 1
Loading…
Cancel
Save