2023-12-11 21:53:30 +00:00
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
from typing import TYPE_CHECKING, Any, Dict, List
|
|
|
|
|
|
|
|
from langchain_core.callbacks import (
|
|
|
|
AsyncCallbackManagerForRetrieverRun,
|
|
|
|
CallbackManagerForRetrieverRun,
|
|
|
|
)
|
|
|
|
from langchain_core.documents import Document
|
|
|
|
from langchain_core.language_models.chat_models import BaseChatModel
|
|
|
|
from langchain_core.messages import HumanMessage
|
|
|
|
from langchain_core.pydantic_v1 import Field
|
|
|
|
from langchain_core.retrievers import BaseRetriever
|
|
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
|
|
from langchain_core.messages import BaseMessage
|
|
|
|
|
|
|
|
|
|
|
|
def _get_docs(response: Any) -> List[Document]:
    """Convert a chat generation into a list of documents.

    One ``Document`` is produced per entry of ``generation_info["documents"]``
    (the entry's ``"snippet"`` becomes the page content and the whole entry
    becomes the metadata).  A final ``Document`` holding the model's own reply
    is always appended; its metadata has ``"type": "model_response"`` plus the
    citations, search results, search queries and token count taken from
    ``generation_info``.

    Args:
        response: Object exposing ``generation_info`` (a mapping, or ``None``)
            and ``message.content``.

    Returns:
        The retrieved documents followed by the model-response document.

    Raises:
        KeyError: If a retrieved document entry has no ``"snippet"`` key.
    """
    # Normalise a missing/None generation_info to an empty mapping so the
    # lookups below cannot fail with TypeError.
    info = response.generation_info or {}

    # "documents" may be absent, or present but None/empty when nothing was
    # retrieved; all of these yield no retrieved documents.
    docs = [
        Document(page_content=doc["snippet"], metadata=doc)
        for doc in info.get("documents") or []
    ]

    # The model reply is always returned, even when some of the bookkeeping
    # fields are missing: .get() yields None for those instead of raising.
    docs.append(
        Document(
            page_content=response.message.content,
            metadata={
                "type": "model_response",
                "citations": info.get("citations"),
                "search_results": info.get("search_results"),
                "search_queries": info.get("search_queries"),
                "token_count": info.get("token_count"),
            },
        )
    )
    return docs
|
|
|
|
|
|
|
|
|
|
|
|
class CohereRagRetriever(BaseRetriever):
    """Retriever backed by the Cohere Chat API with RAG.

    The query is sent to ``llm`` as a single human message together with the
    configured ``connectors``; the first generation of the result is then
    converted into documents by ``_get_docs``.
    """

    connectors: List[Dict] = Field(default_factory=lambda: [{"id": "web-search"}])
    """
    Connectors forwarded to the chat model. When specified, the model's reply
    is enriched with information found by querying each connector (RAG), and
    that information is returned as langchain documents.

    Currently only accepts {"id": "web-search"}.
    """

    llm: BaseChatModel
    """Cohere ChatModel to use."""

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True
        """Allow arbitrary types."""

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
        **kwargs: Any,
    ) -> List[Document]:
        """Synchronously query the chat model and return its documents.

        Args:
            query: Text sent to the model as one human message.
            run_manager: Callback manager; its child is passed as callbacks.
            **kwargs: Extra keyword arguments forwarded to ``llm.generate``.

        Returns:
            The documents built from the first generation of the result.
        """
        prompt_batch: List[List[BaseMessage]] = [[HumanMessage(content=query)]]
        llm_result = self.llm.generate(
            prompt_batch,
            connectors=self.connectors,
            callbacks=run_manager.get_child(),
            **kwargs,
        )
        # Single prompt, single candidate: take the first generation.
        first_generation = llm_result.generations[0][0]
        return _get_docs(first_generation)

    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: AsyncCallbackManagerForRetrieverRun,
        **kwargs: Any,
    ) -> List[Document]:
        """Asynchronously query the chat model and return its documents.

        Args:
            query: Text sent to the model as one human message.
            run_manager: Callback manager; its child is passed as callbacks.
            **kwargs: Extra keyword arguments forwarded to ``llm.agenerate``.

        Returns:
            The documents built from the first generation of the result.
        """
        prompt_batch: List[List[BaseMessage]] = [[HumanMessage(content=query)]]
        llm_result = await self.llm.agenerate(
            prompt_batch,
            connectors=self.connectors,
            callbacks=run_manager.get_child(),
            **kwargs,
        )
        # Single prompt, single candidate: take the first generation.
        first_generation = llm_result.generations[0][0]
        return _get_docs(first_generation)
|