Add message to documents (#12552)

This adds the model's response message as an extra document returned by the
RAG retriever, so users can choose to use it. Also drops the document limit.

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
pull/13135/head
billytrend-cohere 11 months ago committed by GitHub
parent 5f38770161
commit b346d4a455
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -166,6 +166,16 @@ class ChatCohere(BaseChatModel, BaseCohere):
if run_manager:
await run_manager.on_llm_new_token(delta)
def _get_generation_info(self, response: Any) -> Dict[str, Any]:
"""Get the generation info from cohere API response."""
return {
"documents": response.documents,
"citations": response.citations,
"search_results": response.search_results,
"search_queries": response.search_queries,
"token_count": response.token_count,
}
def _generate(
self,
messages: List[BaseMessage],
@ -185,7 +195,7 @@ class ChatCohere(BaseChatModel, BaseCohere):
message = AIMessage(content=response.text)
generation_info = None
if hasattr(response, "documents"):
generation_info = {"documents": response.documents}
generation_info = self._get_generation_info(response)
return ChatResult(
generations=[
ChatGeneration(message=message, generation_info=generation_info)
@ -211,7 +221,7 @@ class ChatCohere(BaseChatModel, BaseCohere):
message = AIMessage(content=response.text)
generation_info = None
if hasattr(response, "documents"):
generation_info = {"documents": response.documents}
generation_info = self._get_generation_info(response)
return ChatResult(
generations=[
ChatGeneration(message=message, generation_info=generation_info)

@ -15,17 +15,27 @@ if TYPE_CHECKING:
def _get_docs(response: Any) -> List[Document]:
return [
docs = [
Document(page_content=doc["snippet"], metadata=doc)
for doc in response.generation_info["documents"]
]
docs.append(
Document(
page_content=response.message.content,
metadata={
"type": "model_response",
"citations": response.generation_info["citations"],
"search_results": response.generation_info["search_results"],
"search_queries": response.generation_info["search_queries"],
"token_count": response.generation_info["token_count"],
},
)
)
return docs
class CohereRagRetriever(BaseRetriever):
"""`ChatGPT plugin` retriever."""
top_k: int = 3
"""Number of documents to return."""
"""Cohere Chat API with RAG."""
connectors: List[Dict] = Field(default_factory=lambda: [{"id": "web-search"}])
"""
@ -55,7 +65,7 @@ class CohereRagRetriever(BaseRetriever):
callbacks=run_manager.get_child(),
**kwargs,
).generations[0][0]
return _get_docs(res)[: self.top_k]
return _get_docs(res)
async def _aget_relevant_documents(
self,
@ -73,4 +83,4 @@ class CohereRagRetriever(BaseRetriever):
**kwargs,
)
).generations[0][0]
return _get_docs(res)[: self.top_k]
return _get_docs(res)

Loading…
Cancel
Save