diff --git a/libs/partners/cohere/langchain_cohere/chat_models.py b/libs/partners/cohere/langchain_cohere/chat_models.py index 25a537abe1..f60f5636dd 100644 --- a/libs/partners/cohere/langchain_cohere/chat_models.py +++ b/libs/partners/cohere/langchain_cohere/chat_models.py @@ -47,6 +47,7 @@ def get_role(message: BaseMessage) -> str: def get_cohere_chat_request( messages: List[BaseMessage], *, + documents: Optional[List[Dict[str, str]]] = None, connectors: Optional[List[Dict[str, str]]] = None, **kwargs: Any, ) -> Dict[str, Any]: @@ -60,24 +61,33 @@ def get_cohere_chat_request( Returns: The request for the Cohere chat API. """ - documents = ( - None - if "source_documents" not in kwargs - else [ - { - "snippet": doc.page_content, - "id": doc.metadata.get("id") or f"doc-{str(i)}", - } - for i, doc in enumerate(kwargs["source_documents"]) - ] - ) - kwargs.pop("source_documents", None) - maybe_connectors = connectors if documents is None else None + additional_kwargs = messages[-1].additional_kwargs + + # cohere SDK will fail loudly if both connectors and documents are provided + if ( + len(additional_kwargs.get("documents", [])) > 0 + and documents + and len(documents) > 0 + ): + raise ValueError( + "Received documents both as a keyword argument and as a prompt additional " + "keyword argument. Please choose only one option." 
+ ) + + formatted_docs = [ + { + "text": doc.page_content, + "id": doc.metadata.get("id") or f"doc-{str(i)}", + } + for i, doc in enumerate(additional_kwargs.get("documents", [])) + ] or documents + if not formatted_docs: + formatted_docs = None # by enabling automatic prompt truncation, the probability of request failure is # reduced with minimal impact on response quality prompt_truncation = ( - "AUTO" if documents is not None or connectors is not None else None + "AUTO" if formatted_docs is not None or connectors is not None else None ) req = { @@ -85,8 +95,8 @@ def get_cohere_chat_request( "chat_history": [ {"role": get_role(x), "message": x.content} for x in messages[:-1] ], - "documents": documents, - "connectors": maybe_connectors, + "documents": formatted_docs, + "connectors": connectors, "prompt_truncation": prompt_truncation, **kwargs, } diff --git a/libs/partners/cohere/tests/integration_tests/test_chat_models.py b/libs/partners/cohere/tests/integration_tests/test_chat_models.py index 5e249a1616..81246c37aa 100644 --- a/libs/partners/cohere/tests/integration_tests/test_chat_models.py +++ b/libs/partners/cohere/tests/integration_tests/test_chat_models.py @@ -3,7 +3,7 @@ from langchain_cohere import ChatCohere def test_stream() -> None: - """Test streaming tokens from OpenAI.""" + """Test streaming tokens from ChatCohere.""" llm = ChatCohere() for token in llm.stream("I'm Pickle Rick"): @@ -11,7 +11,7 @@ def test_stream() -> None: async def test_astream() -> None: - """Test streaming tokens from OpenAI.""" + """Test streaming tokens from ChatCohere.""" llm = ChatCohere() async for token in llm.astream("I'm Pickle Rick"): diff --git a/libs/partners/cohere/tests/integration_tests/test_rag.py b/libs/partners/cohere/tests/integration_tests/test_rag.py new file mode 100644 index 0000000000..310598b522 --- /dev/null +++ b/libs/partners/cohere/tests/integration_tests/test_rag.py @@ -0,0 +1,63 @@ +"""Test ChatCohere chat model.""" + +from typing import Any, Dict, 
List + +from langchain_core.documents import Document +from langchain_core.messages.human import HumanMessage +from langchain_core.prompts.chat import ChatPromptTemplate, MessagesPlaceholder +from langchain_core.runnables import ( + RunnablePassthrough, + RunnableSerializable, +) + +from langchain_cohere import ChatCohere + + +def test_connectors() -> None: + """Test connectors parameter support from ChatCohere.""" + llm = ChatCohere().bind(connectors=[{"id": "web-search"}]) + + result = llm.invoke("Who directed dune two? reply with just the name.") + assert isinstance(result.content, str) + + +def test_documents() -> None: + """Test documents parameter support from ChatCohere.""" + docs = [{"text": "The sky is green."}] + llm = ChatCohere().bind(documents=docs) + prompt = "What color is the sky?" + + result = llm.invoke(prompt) + assert isinstance(result.content, str) + assert len(result.response_metadata["documents"]) == 1 + + +def test_documents_chain() -> None: + """Test documents parameter support from ChatCohere.""" + llm = ChatCohere() + + def get_documents(_: Any) -> List[Document]: + return [Document(page_content="The sky is green.")] + + def format_input_msgs(input: Dict[str, Any]) -> List[HumanMessage]: + return [ + HumanMessage( + content=input["message"], + additional_kwargs={ + "documents": input.get("documents", None), + }, + ) + ] + + prompt = ChatPromptTemplate.from_messages([MessagesPlaceholder("input_msgs")]) + chain: RunnableSerializable[Any, Any] = ( + {"message": RunnablePassthrough(), "documents": get_documents} + | RunnablePassthrough() + | {"input_msgs": format_input_msgs} + | prompt + | llm + ) + + result = chain.invoke("What color is the sky?") + assert isinstance(result.content, str) + assert len(result.response_metadata["documents"]) == 1