cohere[patch]: Fix retriever (#19771)

* Replace `source_documents` with `documents`
* Pass `documents` as an explicit named argument instead of through `**kwargs`
* Make `parsed_docs` more robust (an empty `documents` list is treated as no documents)
* Fix edge case of doc page_content being `None`
pull/19818/head
Giannis 5 months ago committed by GitHub
parent b6ebddbacc
commit 8cf1d75d08
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -17,6 +17,7 @@ from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun, CallbackManagerForLLMRun,
) )
from langchain_core.documents import Document
from langchain_core.language_models import LanguageModelInput from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import ( from langchain_core.language_models.chat_models import (
BaseChatModel, BaseChatModel,
@ -73,7 +74,7 @@ def get_role(message: BaseMessage) -> str:
def get_cohere_chat_request( def get_cohere_chat_request(
messages: List[BaseMessage], messages: List[BaseMessage],
*, *,
documents: Optional[List[Dict[str, str]]] = None, documents: Optional[List[Document]] = None,
connectors: Optional[List[Dict[str, str]]] = None, connectors: Optional[List[Dict[str, str]]] = None,
**kwargs: Any, **kwargs: Any,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
@ -95,17 +96,25 @@ def get_cohere_chat_request(
"Received documents both as a keyword argument and as an prompt additional keyword argument. Please choose only one option." # noqa: E501 "Received documents both as a keyword argument and as an prompt additional keyword argument. Please choose only one option." # noqa: E501
) )
parsed_docs: Optional[List[Document]] = None
if "documents" in additional_kwargs:
parsed_docs = (
additional_kwargs["documents"]
if len(additional_kwargs["documents"]) > 0
else None
)
elif documents is not None and len(documents) > 0:
parsed_docs = documents
formatted_docs: Optional[List[Dict[str, Any]]] = None formatted_docs: Optional[List[Dict[str, Any]]] = None
if additional_kwargs.get("documents"): if parsed_docs is not None:
formatted_docs = [ formatted_docs = [
{ {
"text": doc.page_content, "text": doc.page_content,
"id": doc.metadata.get("id") or f"doc-{str(i)}", "id": doc.metadata.get("id") or f"doc-{str(i)}",
} }
for i, doc in enumerate(additional_kwargs.get("documents", [])) for i, doc in enumerate(parsed_docs)
] ]
elif documents:
formatted_docs = documents
# by enabling automatic prompt truncation, the probability of request failure is # by enabling automatic prompt truncation, the probability of request failure is
# reduced with minimal impact on response quality # reduced with minimal impact on response quality

@ -1,6 +1,6 @@
from __future__ import annotations from __future__ import annotations
from typing import TYPE_CHECKING, Any, Dict, List from typing import TYPE_CHECKING, Any, Dict, List, Optional
from langchain_core.callbacks import ( from langchain_core.callbacks import (
AsyncCallbackManagerForRetrieverRun, AsyncCallbackManagerForRetrieverRun,
@ -17,15 +17,16 @@ if TYPE_CHECKING:
def _get_docs(response: Any) -> List[Document]: def _get_docs(response: Any) -> List[Document]:
docs = ( docs = []
[] if (
if "documents" not in response.generation_info "documents" in response.generation_info
or len(response.generation_info["documents"]) == 0 and len(response.generation_info["documents"]) > 0
else [ ):
Document(page_content=doc["snippet"], metadata=doc) for doc in response.generation_info["documents"]:
for doc in response.generation_info["documents"] content = doc.get("snippet", None) or doc.get("text", None)
] if content is not None:
) docs.append(Document(page_content=content, metadata=doc))
docs.append( docs.append(
Document( Document(
page_content=response.message.content, page_content=response.message.content,
@ -63,12 +64,18 @@ class CohereRagRetriever(BaseRetriever):
"""Allow arbitrary types.""" """Allow arbitrary types."""
def _get_relevant_documents( def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
documents: Optional[List[Dict[str, str]]] = None,
**kwargs: Any,
) -> List[Document]: ) -> List[Document]:
messages: List[List[BaseMessage]] = [[HumanMessage(content=query)]] messages: List[List[BaseMessage]] = [[HumanMessage(content=query)]]
res = self.llm.generate( res = self.llm.generate(
messages, messages,
connectors=self.connectors, connectors=self.connectors if documents is None else None,
documents=documents,
callbacks=run_manager.get_child(), callbacks=run_manager.get_child(),
**kwargs, **kwargs,
).generations[0][0] ).generations[0][0]
@ -79,13 +86,15 @@ class CohereRagRetriever(BaseRetriever):
query: str, query: str,
*, *,
run_manager: AsyncCallbackManagerForRetrieverRun, run_manager: AsyncCallbackManagerForRetrieverRun,
documents: Optional[List[Dict[str, str]]] = None,
**kwargs: Any, **kwargs: Any,
) -> List[Document]: ) -> List[Document]:
messages: List[List[BaseMessage]] = [[HumanMessage(content=query)]] messages: List[List[BaseMessage]] = [[HumanMessage(content=query)]]
res = ( res = (
await self.llm.agenerate( await self.llm.agenerate(
messages, messages,
connectors=self.connectors, connectors=self.connectors if documents is None else None,
documents=documents,
callbacks=run_manager.get_child(), callbacks=run_manager.get_child(),
**kwargs, **kwargs,
) )

Loading…
Cancel
Save