cohere[patch]: Fix retriever (#19771)

* Replace `source_documents` with `documents`
* Pass `documents` as a named arg vs keyword
* Make `parsed_docs` more robust
* Fix edge case of doc page_content being `None`
pull/19818/head
Giannis 4 months ago committed by GitHub
parent b6ebddbacc
commit 8cf1d75d08
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -17,6 +17,7 @@ from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.documents import Document
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import (
BaseChatModel,
@ -73,7 +74,7 @@ def get_role(message: BaseMessage) -> str:
def get_cohere_chat_request(
messages: List[BaseMessage],
*,
documents: Optional[List[Dict[str, str]]] = None,
documents: Optional[List[Document]] = None,
connectors: Optional[List[Dict[str, str]]] = None,
**kwargs: Any,
) -> Dict[str, Any]:
@ -95,17 +96,25 @@ def get_cohere_chat_request(
"Received documents both as a keyword argument and as an prompt additional keyword argument. Please choose only one option." # noqa: E501
)
parsed_docs: Optional[List[Document]] = None
if "documents" in additional_kwargs:
parsed_docs = (
additional_kwargs["documents"]
if len(additional_kwargs["documents"]) > 0
else None
)
elif documents is not None and len(documents) > 0:
parsed_docs = documents
formatted_docs: Optional[List[Dict[str, Any]]] = None
if additional_kwargs.get("documents"):
if parsed_docs is not None:
formatted_docs = [
{
"text": doc.page_content,
"id": doc.metadata.get("id") or f"doc-{str(i)}",
}
for i, doc in enumerate(additional_kwargs.get("documents", []))
for i, doc in enumerate(parsed_docs)
]
elif documents:
formatted_docs = documents
# by enabling automatic prompt truncation, the probability of request failure is
# reduced with minimal impact on response quality

@ -1,6 +1,6 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Dict, List
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from langchain_core.callbacks import (
AsyncCallbackManagerForRetrieverRun,
@ -17,15 +17,16 @@ if TYPE_CHECKING:
def _get_docs(response: Any) -> List[Document]:
docs = (
[]
if "documents" not in response.generation_info
or len(response.generation_info["documents"]) == 0
else [
Document(page_content=doc["snippet"], metadata=doc)
for doc in response.generation_info["documents"]
]
)
docs = []
if (
"documents" in response.generation_info
and len(response.generation_info["documents"]) > 0
):
for doc in response.generation_info["documents"]:
content = doc.get("snippet", None) or doc.get("text", None)
if content is not None:
docs.append(Document(page_content=content, metadata=doc))
docs.append(
Document(
page_content=response.message.content,
@ -63,12 +64,18 @@ class CohereRagRetriever(BaseRetriever):
"""Allow arbitrary types."""
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
documents: Optional[List[Dict[str, str]]] = None,
**kwargs: Any,
) -> List[Document]:
messages: List[List[BaseMessage]] = [[HumanMessage(content=query)]]
res = self.llm.generate(
messages,
connectors=self.connectors,
connectors=self.connectors if documents is None else None,
documents=documents,
callbacks=run_manager.get_child(),
**kwargs,
).generations[0][0]
@ -79,13 +86,15 @@ class CohereRagRetriever(BaseRetriever):
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
documents: Optional[List[Dict[str, str]]] = None,
**kwargs: Any,
) -> List[Document]:
messages: List[List[BaseMessage]] = [[HumanMessage(content=query)]]
res = (
await self.llm.agenerate(
messages,
connectors=self.connectors,
connectors=self.connectors if documents is None else None,
documents=documents,
callbacks=run_manager.get_child(),
**kwargs,
)

Loading…
Cancel
Save