From 8cf1d75d08badf5ac7f6c3806f46bad7a365d904 Mon Sep 17 00:00:00 2001
From: Giannis <145396613+giannis2two@users.noreply.github.com>
Date: Sun, 31 Mar 2024 17:47:03 -0400
Subject: [PATCH] cohere[patch]: Fix retriever (#19771)

* Replace `source_documents` with `documents`
* Pass `documents` as a named arg vs keyword
* Make `parsed_docs` more robust
* Fix edge case of doc page_content being `None`
---
 .../cohere/langchain_cohere/chat_models.py    | 19 +++++++---
 .../cohere/langchain_cohere/rag_retrievers.py | 35 ++++++++++++-------
 2 files changed, 36 insertions(+), 18 deletions(-)

diff --git a/libs/partners/cohere/langchain_cohere/chat_models.py b/libs/partners/cohere/langchain_cohere/chat_models.py
index 51a66cfcdf..cfc2df4fe3 100644
--- a/libs/partners/cohere/langchain_cohere/chat_models.py
+++ b/libs/partners/cohere/langchain_cohere/chat_models.py
@@ -17,6 +17,7 @@ from langchain_core.callbacks import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
+from langchain_core.documents import Document
 from langchain_core.language_models import LanguageModelInput
 from langchain_core.language_models.chat_models import (
     BaseChatModel,
@@ -73,7 +74,7 @@ def get_role(message: BaseMessage) -> str:
 def get_cohere_chat_request(
     messages: List[BaseMessage],
     *,
-    documents: Optional[List[Dict[str, str]]] = None,
+    documents: Optional[List[Document]] = None,
     connectors: Optional[List[Dict[str, str]]] = None,
     **kwargs: Any,
 ) -> Dict[str, Any]:
@@ -95,17 +96,25 @@
             "Received documents both as a keyword argument and as an prompt additional keyword argument. Please choose only one option."  # noqa: E501
         )
 
+    parsed_docs: Optional[List[Document]] = None
+    if "documents" in additional_kwargs:
+        parsed_docs = (
+            additional_kwargs["documents"]
+            if len(additional_kwargs["documents"]) > 0
+            else None
+        )
+    elif documents is not None and len(documents) > 0:
+        parsed_docs = documents
+
     formatted_docs: Optional[List[Dict[str, Any]]] = None
-    if additional_kwargs.get("documents"):
+    if parsed_docs is not None:
         formatted_docs = [
             {
                 "text": doc.page_content,
                 "id": doc.metadata.get("id") or f"doc-{str(i)}",
             }
-            for i, doc in enumerate(additional_kwargs.get("documents", []))
+            for i, doc in enumerate(parsed_docs)
         ]
-    elif documents:
-        formatted_docs = documents
 
     # by enabling automatic prompt truncation, the probability of request failure is
     # reduced with minimal impact on response quality
diff --git a/libs/partners/cohere/langchain_cohere/rag_retrievers.py b/libs/partners/cohere/langchain_cohere/rag_retrievers.py
index 0d19459620..99d7aa17d5 100644
--- a/libs/partners/cohere/langchain_cohere/rag_retrievers.py
+++ b/libs/partners/cohere/langchain_cohere/rag_retrievers.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Any, Dict, List
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
 from langchain_core.callbacks import (
     AsyncCallbackManagerForRetrieverRun,
@@ -17,15 +17,16 @@ if TYPE_CHECKING:
 
 
 def _get_docs(response: Any) -> List[Document]:
-    docs = (
-        []
-        if "documents" not in response.generation_info
-        or len(response.generation_info["documents"]) == 0
-        else [
-            Document(page_content=doc["snippet"], metadata=doc)
-            for doc in response.generation_info["documents"]
-        ]
-    )
+    docs = []
+    if (
+        "documents" in response.generation_info
+        and len(response.generation_info["documents"]) > 0
+    ):
+        for doc in response.generation_info["documents"]:
+            content = doc.get("snippet", None) or doc.get("text", None)
+            if content is not None:
+                docs.append(Document(page_content=content, metadata=doc))
+
     docs.append(
         Document(
             page_content=response.message.content,
@@ -63,12 +64,18 @@ class CohereRagRetriever(BaseRetriever):
         """Allow arbitrary types."""
 
     def _get_relevant_documents(
-        self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
+        self,
+        query: str,
+        *,
+        run_manager: CallbackManagerForRetrieverRun,
+        documents: Optional[List[Dict[str, str]]] = None,
+        **kwargs: Any,
     ) -> List[Document]:
         messages: List[List[BaseMessage]] = [[HumanMessage(content=query)]]
         res = self.llm.generate(
             messages,
-            connectors=self.connectors,
+            connectors=self.connectors if documents is None else None,
+            documents=documents,
             callbacks=run_manager.get_child(),
             **kwargs,
         ).generations[0][0]
@@ -79,13 +86,15 @@ class CohereRagRetriever(BaseRetriever):
         query: str,
         *,
         run_manager: AsyncCallbackManagerForRetrieverRun,
+        documents: Optional[List[Dict[str, str]]] = None,
         **kwargs: Any,
     ) -> List[Document]:
         messages: List[List[BaseMessage]] = [[HumanMessage(content=query)]]
         res = (
             await self.llm.agenerate(
                 messages,
-                connectors=self.connectors,
+                connectors=self.connectors if documents is None else None,
+                documents=documents,
                 callbacks=run_manager.get_child(),
                 **kwargs,
             )
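
Not part of the patch itself — a minimal usage sketch of `CohereRagRetriever` after this change. It assumes `langchain-cohere` is installed and `COHERE_API_KEY` is exported; the model name, connector id, query, and the inline `Document` text are illustrative placeholders. The sketch passes `Document` objects for the `documents` keyword because the patched `get_cohere_chat_request` reads `page_content` and `metadata` from each entry, and it relies on the retriever forwarding extra keyword arguments from `get_relevant_documents` to `_get_relevant_documents`.

```python
# Usage sketch (assumptions noted above): exercises both sides of the patch —
# connector-backed retrieval and caller-supplied documents.
from langchain_cohere import ChatCohere, CohereRagRetriever
from langchain_core.documents import Document

rag = CohereRagRetriever(
    llm=ChatCohere(model="command-r"),  # illustrative model name
    connectors=[{"id": "web-search"}],  # used only when no documents are passed
)

# Connector path: `connectors` is sent because `documents` is None.
docs = rag.get_relevant_documents("What is LangChain?")

# Documents path: the retriever forwards these as the named `documents`
# argument and suppresses `connectors`; get_cohere_chat_request formats each
# Document into {"text": ..., "id": ...} for the Cohere chat API.
docs = rag.get_relevant_documents(
    "What is LangChain?",
    documents=[Document(page_content="LangChain is a framework for LLM apps.")],
)

# Per the patched _get_docs, every entry but the last is a retrieved document
# (its page_content is the non-None "snippet" or "text" field), and the final
# entry is the model's generated answer.
for doc in docs[:-1]:
    print(doc.page_content)
print(docs[-1].page_content)
```

The same `documents` keyword applies on the async path via `_aget_relevant_documents`.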