diff --git a/libs/langchain/langchain/chat_models/cohere.py b/libs/langchain/langchain/chat_models/cohere.py
index 19eaffc685..4b09e06aff 100644
--- a/libs/langchain/langchain/chat_models/cohere.py
+++ b/libs/langchain/langchain/chat_models/cohere.py
@@ -166,6 +166,16 @@ class ChatCohere(BaseChatModel, BaseCohere):
             if run_manager:
                 await run_manager.on_llm_new_token(delta)
 
+    def _get_generation_info(self, response: Any) -> Dict[str, Any]:
+        """Get the generation info from cohere API response."""
+        return {
+            "documents": response.documents,
+            "citations": response.citations,
+            "search_results": response.search_results,
+            "search_queries": response.search_queries,
+            "token_count": response.token_count,
+        }
+
     def _generate(
         self,
         messages: List[BaseMessage],
@@ -185,7 +195,7 @@ class ChatCohere(BaseChatModel, BaseCohere):
         message = AIMessage(content=response.text)
         generation_info = None
         if hasattr(response, "documents"):
-            generation_info = {"documents": response.documents}
+            generation_info = self._get_generation_info(response)
         return ChatResult(
             generations=[
                 ChatGeneration(message=message, generation_info=generation_info)
@@ -211,7 +221,7 @@ class ChatCohere(BaseChatModel, BaseCohere):
         message = AIMessage(content=response.text)
         generation_info = None
         if hasattr(response, "documents"):
-            generation_info = {"documents": response.documents}
+            generation_info = self._get_generation_info(response)
         return ChatResult(
             generations=[
                 ChatGeneration(message=message, generation_info=generation_info)
diff --git a/libs/langchain/langchain/retrievers/cohere_rag_retriever.py b/libs/langchain/langchain/retrievers/cohere_rag_retriever.py
index 1fd88c093e..9d79adee69 100644
--- a/libs/langchain/langchain/retrievers/cohere_rag_retriever.py
+++ b/libs/langchain/langchain/retrievers/cohere_rag_retriever.py
@@ -15,17 +15,27 @@ if TYPE_CHECKING:
 
 
 def _get_docs(response: Any) -> List[Document]:
-    return [
+    docs = [
         Document(page_content=doc["snippet"], metadata=doc)
         for doc in response.generation_info["documents"]
     ]
+    docs.append(
+        Document(
+            page_content=response.message.content,
+            metadata={
+                "type": "model_response",
+                "citations": response.generation_info["citations"],
+                "search_results": response.generation_info["search_results"],
+                "search_queries": response.generation_info["search_queries"],
+                "token_count": response.generation_info["token_count"],
+            },
+        )
+    )
+    return docs
 
 
 class CohereRagRetriever(BaseRetriever):
-    """`ChatGPT plugin` retriever."""
-
-    top_k: int = 3
-    """Number of documents to return."""
+    """Cohere Chat API with RAG."""
 
     connectors: List[Dict] = Field(default_factory=lambda: [{"id": "web-search"}])
     """
@@ -55,7 +65,7 @@ class CohereRagRetriever(BaseRetriever):
             callbacks=run_manager.get_child(),
             **kwargs,
         ).generations[0][0]
-        return _get_docs(res)[: self.top_k]
+        return _get_docs(res)
 
     async def _aget_relevant_documents(
         self,
@@ -73,4 +83,4 @@ class CohereRagRetriever(BaseRetriever):
                 **kwargs,
             )
         ).generations[0][0]
-        return _get_docs(res)[: self.top_k]
+        return _get_docs(res)
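
For reviewers, a minimal usage sketch of the patched retriever follows (not part of the diff). It assumes a valid COHERE_API_KEY in the environment and the cohere SDK installed; the query string and the printed fields are illustrative only. It shows the behavioral change: results are no longer truncated to top_k, and the final Document now carries the model's own reply plus the citations, search_results, search_queries, and token_count surfaced by _get_generation_info.

# Usage sketch (assumption-based, not part of the patch):
from langchain.chat_models.cohere import ChatCohere
from langchain.retrievers.cohere_rag_retriever import CohereRagRetriever

# Retriever backed by Cohere chat with the default web-search connector.
rag = CohereRagRetriever(llm=ChatCohere())
docs = rag.get_relevant_documents("How do Cohere connectors work?")

# Connector documents come first; they are no longer cut off at top_k.
for doc in docs[:-1]:
    print(doc.page_content[:80])

# The last Document is the model's answer with the new RAG metadata.
answer = docs[-1]
assert answer.metadata["type"] == "model_response"
print(answer.page_content)
print(answer.metadata["citations"])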