langchain/libs/partners/cohere/langchain_cohere/rag_retrievers.py
billytrend-cohere 63343b4987
cohere[patch]: add cohere as a partner package (#19049)
Description: adds support for langchain_cohere

---------

Co-authored-by: Harry M <127103098+harry-cohere@users.noreply.github.com>
Co-authored-by: Erick Friis <erick@langchain.dev>
2024-03-25 20:23:47 +00:00

93 lines
2.8 KiB
Python

from __future__ import annotations
from typing import TYPE_CHECKING, Any, Dict, List
from langchain_core.callbacks import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import HumanMessage
from langchain_core.pydantic_v1 import Field
from langchain_core.retrievers import BaseRetriever
if TYPE_CHECKING:
from langchain_core.messages import BaseMessage
def _get_docs(response: Any) -> List[Document]:
    """Convert a single chat generation into a list of documents.

    Args:
        response: A generation object exposing ``generation_info`` (a mapping,
            or ``None`` when the chat model attached no extra information) and
            ``message.content``.

    Returns:
        One ``Document`` per entry found under ``generation_info["documents"]``
        (page content taken from the entry's ``"snippet"``, the whole entry kept
        as metadata), followed by a final ``Document`` holding the model's reply.
        The final document's metadata has ``"type": "model_response"`` plus the
        ``citations``, ``search_results``, ``search_queries`` and ``token_count``
        values from ``generation_info``; any of those that are absent are
        reported as ``None``.
    """
    # ``generation_info`` is optional on generations; treat a missing one as
    # empty so the model reply is still returned instead of raising.
    info = response.generation_info or {}

    # ``or []`` also covers a "documents" key that is present but set to None.
    docs = [
        Document(page_content=doc["snippet"], metadata=doc)
        for doc in info.get("documents") or []
    ]

    # The model's own answer always comes last, tagged so callers can tell it
    # apart from the retrieved snippets.
    docs.append(
        Document(
            page_content=response.message.content,
            metadata={
                "type": "model_response",
                "citations": info.get("citations"),
                "search_results": info.get("search_results"),
                "search_queries": info.get("search_queries"),
                "token_count": info.get("token_count"),
            },
        )
    )
    return docs
class CohereRagRetriever(BaseRetriever):
    """Retriever that answers queries through the Cohere Chat API with RAG.

    Each query is sent to ``llm`` as a single human message together with the
    configured ``connectors``; the first generation of the result is turned
    into documents by ``_get_docs``.
    """

    connectors: List[Dict] = Field(default_factory=lambda: [{"id": "web-search"}])
    """
    Connectors forwarded to the chat model on every call. When specified, the
    model's reply is enriched with information found by querying each of the
    connectors (RAG), and those findings are returned as langchain documents.
    Currently only accepts {"id": "web-search"}.
    """

    llm: BaseChatModel
    """Cohere ChatModel used to run the query."""

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True
        """Allow arbitrary types."""

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
    ) -> List[Document]:
        """Run ``query`` through the chat model and return the resulting documents.

        Args:
            query: Text wrapped in a single ``HumanMessage``.
            run_manager: Callback manager; its child is passed as ``callbacks``.
            **kwargs: Extra keyword arguments forwarded to ``llm.generate``.

        Returns:
            The documents built by ``_get_docs`` from the first generation.
        """
        prompt: List[BaseMessage] = [HumanMessage(content=query)]
        llm_result = self.llm.generate(
            [prompt],
            connectors=self.connectors,
            callbacks=run_manager.get_child(),
            **kwargs,
        )
        # One prompt was sent, so only the first candidate of the first
        # generation list is inspected.
        first_generation = llm_result.generations[0][0]
        return _get_docs(first_generation)

    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: AsyncCallbackManagerForRetrieverRun,
        **kwargs: Any,
    ) -> List[Document]:
        """Async counterpart of ``_get_relevant_documents`` using ``llm.agenerate``.

        Args:
            query: Text wrapped in a single ``HumanMessage``.
            run_manager: Async callback manager; its child is passed as
                ``callbacks``.
            **kwargs: Extra keyword arguments forwarded to ``llm.agenerate``.

        Returns:
            The documents built by ``_get_docs`` from the first generation.
        """
        prompt: List[BaseMessage] = [HumanMessage(content=query)]
        llm_result = await self.llm.agenerate(
            [prompt],
            connectors=self.connectors,
            callbacks=run_manager.get_child(),
            **kwargs,
        )
        # One prompt was sent, so only the first candidate of the first
        # generation list is inspected.
        first_generation = llm_result.generations[0][0]
        return _get_docs(first_generation)