forked from Archives/langchain
supported async retriever (#2149)
This commit is contained in:
parent
65c0c73597
commit
35a3218e84
@ -86,6 +86,10 @@ class BaseConversationalRetrievalChain(Chain, BaseModel):
|
||||
else:
|
||||
return {self.output_key: answer}
|
||||
|
||||
@abstractmethod
|
||||
async def _aget_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
|
||||
"""Get docs."""
|
||||
|
||||
async def _acall(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
question = inputs["question"]
|
||||
get_chat_history = self.get_chat_history or _get_chat_history
|
||||
@ -96,8 +100,7 @@ class BaseConversationalRetrievalChain(Chain, BaseModel):
|
||||
)
|
||||
else:
|
||||
new_question = question
|
||||
# TODO: This blocks the event loop, but it's not clear how to avoid it.
|
||||
docs = self._get_docs(new_question, inputs)
|
||||
docs = await self._aget_docs(new_question, inputs)
|
||||
new_inputs = inputs.copy()
|
||||
new_inputs["question"] = new_question
|
||||
new_inputs["chat_history"] = chat_history_str
|
||||
@ -143,6 +146,10 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain, BaseModel):
|
||||
docs = self.retriever.get_relevant_documents(question)
|
||||
return self._reduce_tokens_below_limit(docs)
|
||||
|
||||
async def _aget_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
|
||||
docs = await self.retriever.aget_relevant_documents(question)
|
||||
return self._reduce_tokens_below_limit(docs)
|
||||
|
||||
@classmethod
|
||||
def from_llm(
|
||||
cls,
|
||||
@ -194,6 +201,9 @@ class ChatVectorDBChain(BaseConversationalRetrievalChain, BaseModel):
|
||||
question, k=self.top_k_docs_for_context, **full_kwargs
|
||||
)
|
||||
|
||||
async def _aget_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
|
||||
raise NotImplementedError("ChatVectorDBChain does not support async")
|
||||
|
||||
@classmethod
|
||||
def from_llm(
|
||||
cls,
|
||||
|
@ -129,6 +129,25 @@ class BaseQAWithSourcesChain(Chain, BaseModel, ABC):
|
||||
result["source_documents"] = docs
|
||||
return result
|
||||
|
||||
@abstractmethod
|
||||
async def _aget_docs(self, inputs: Dict[str, Any]) -> List[Document]:
|
||||
"""Get docs to run questioning over."""
|
||||
|
||||
async def _acall(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
docs = await self._aget_docs(inputs)
|
||||
answer, _ = await self.combine_documents_chain.acombine_docs(docs, **inputs)
|
||||
if re.search(r"SOURCES:\s", answer):
|
||||
answer, sources = re.split(r"SOURCES:\s", answer)
|
||||
else:
|
||||
sources = ""
|
||||
result: Dict[str, Any] = {
|
||||
self.answer_key: answer,
|
||||
self.sources_answer_key: sources,
|
||||
}
|
||||
if self.return_source_documents:
|
||||
result["source_documents"] = docs
|
||||
return result
|
||||
|
||||
|
||||
class QAWithSourcesChain(BaseQAWithSourcesChain, BaseModel):
|
||||
"""Question answering with sources over documents."""
|
||||
@ -146,6 +165,9 @@ class QAWithSourcesChain(BaseQAWithSourcesChain, BaseModel):
|
||||
def _get_docs(self, inputs: Dict[str, Any]) -> List[Document]:
|
||||
return inputs.pop(self.input_docs_key)
|
||||
|
||||
async def _aget_docs(self, inputs: Dict[str, Any]) -> List[Document]:
|
||||
return inputs.pop(self.input_docs_key)
|
||||
|
||||
@property
|
||||
def _chain_type(self) -> str:
|
||||
return "qa_with_sources_chain"
|
||||
|
@ -44,3 +44,8 @@ class RetrievalQAWithSourcesChain(BaseQAWithSourcesChain, BaseModel):
|
||||
question = inputs[self.question_key]
|
||||
docs = self.retriever.get_relevant_documents(question)
|
||||
return self._reduce_tokens_below_limit(docs)
|
||||
|
||||
async def _aget_docs(self, inputs: Dict[str, Any]) -> List[Document]:
|
||||
question = inputs[self.question_key]
|
||||
docs = await self.retriever.aget_relevant_documents(question)
|
||||
return self._reduce_tokens_below_limit(docs)
|
||||
|
@ -52,6 +52,9 @@ class VectorDBQAWithSourcesChain(BaseQAWithSourcesChain, BaseModel):
|
||||
)
|
||||
return self._reduce_tokens_below_limit(docs)
|
||||
|
||||
async def _aget_docs(self, inputs: Dict[str, Any]) -> List[Document]:
|
||||
raise NotImplementedError("VectorDBQAWithSourcesChain does not support async")
|
||||
|
||||
@root_validator()
|
||||
def raise_deprecation(cls, values: Dict) -> Dict:
|
||||
warnings.warn(
|
||||
|
@ -114,6 +114,34 @@ class BaseRetrievalQA(Chain, BaseModel):
|
||||
else:
|
||||
return {self.output_key: answer}
|
||||
|
||||
@abstractmethod
|
||||
async def _aget_docs(self, question: str) -> List[Document]:
|
||||
"""Get documents to do question answering over."""
|
||||
|
||||
async def _acall(self, inputs: Dict[str, str]) -> Dict[str, Any]:
|
||||
"""Run get_relevant_text and llm on input query.
|
||||
|
||||
If chain has 'return_source_documents' as 'True', returns
|
||||
the retrieved documents as well under the key 'source_documents'.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
res = indexqa({'query': 'This is my query'})
|
||||
answer, docs = res['result'], res['source_documents']
|
||||
"""
|
||||
question = inputs[self.input_key]
|
||||
|
||||
docs = await self._aget_docs(question)
|
||||
answer, _ = await self.combine_documents_chain.acombine_docs(
|
||||
docs, question=question
|
||||
)
|
||||
|
||||
if self.return_source_documents:
|
||||
return {self.output_key: answer, "source_documents": docs}
|
||||
else:
|
||||
return {self.output_key: answer}
|
||||
|
||||
|
||||
class RetrievalQA(BaseRetrievalQA, BaseModel):
|
||||
"""Chain for question-answering against an index.
|
||||
@ -134,6 +162,9 @@ class RetrievalQA(BaseRetrievalQA, BaseModel):
|
||||
def _get_docs(self, question: str) -> List[Document]:
|
||||
return self.retriever.get_relevant_documents(question)
|
||||
|
||||
async def _aget_docs(self, question: str) -> List[Document]:
|
||||
return await self.retriever.aget_relevant_documents(question)
|
||||
|
||||
|
||||
class VectorDBQA(BaseRetrievalQA, BaseModel):
|
||||
"""Chain for question-answering against a vector database."""
|
||||
@ -177,6 +208,9 @@ class VectorDBQA(BaseRetrievalQA, BaseModel):
|
||||
raise ValueError(f"search_type of {self.search_type} not allowed.")
|
||||
return docs
|
||||
|
||||
async def _aget_docs(self, question: str) -> List[Document]:
|
||||
raise NotImplementedError("VectorDBQA does not support async")
|
||||
|
||||
@property
|
||||
def _chain_type(self) -> str:
|
||||
"""Return the chain type."""
|
||||
|
@ -1,5 +1,6 @@
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
from pydantic import BaseModel
|
||||
|
||||
@ -9,6 +10,12 @@ from langchain.schema import BaseRetriever, Document
|
||||
class ChatGPTPluginRetriever(BaseRetriever, BaseModel):
|
||||
url: str
|
||||
bearer_token: str
|
||||
aiosession: Optional[aiohttp.ClientSession] = None
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def get_relevant_documents(self, query: str) -> List[Document]:
|
||||
response = requests.post(
|
||||
@ -25,3 +32,28 @@ class ChatGPTPluginRetriever(BaseRetriever, BaseModel):
|
||||
content = d.pop("text")
|
||||
docs.append(Document(page_content=content, metadata=d))
|
||||
return docs
|
||||
|
||||
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
||||
url = f"{self.url}/query"
|
||||
json = {"queries": [{"query": query}]}
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.bearer_token}",
|
||||
}
|
||||
|
||||
if not self.aiosession:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(url, headers=headers, json=json) as response:
|
||||
res = await response.json()
|
||||
else:
|
||||
async with self.aiosession.post(
|
||||
url, headers=headers, json=json
|
||||
) as response:
|
||||
res = await response.json()
|
||||
|
||||
results = res["results"][0]["results"]
|
||||
docs = []
|
||||
for d in results:
|
||||
content = d.pop("text")
|
||||
docs.append(Document(page_content=content, metadata=d))
|
||||
return docs
|
||||
|
@ -33,6 +33,9 @@ class LlamaIndexRetriever(BaseRetriever, BaseModel):
|
||||
)
|
||||
return docs
|
||||
|
||||
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
||||
raise NotImplementedError("LlamaIndexRetriever does not support async")
|
||||
|
||||
|
||||
class LlamaIndexGraphRetriever(BaseRetriever, BaseModel):
|
||||
"""Question-answering with sources over an LlamaIndex graph data structure."""
|
||||
@ -69,3 +72,6 @@ class LlamaIndexGraphRetriever(BaseRetriever, BaseModel):
|
||||
Document(page_content=source_node.source_text, metadata=metadata)
|
||||
)
|
||||
return docs
|
||||
|
||||
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
||||
raise NotImplementedError("LlamaIndexGraphRetriever does not support async")
|
||||
|
@ -310,6 +310,17 @@ class BaseRetriever(ABC):
|
||||
List of relevant documents
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
||||
"""Get documents relevant for a query.
|
||||
|
||||
Args:
|
||||
query: string to find relevant documents for
|
||||
|
||||
Returns:
|
||||
List of relevant documents
|
||||
"""
|
||||
|
||||
|
||||
# For backwards compatibility
|
||||
|
||||
|
@ -159,3 +159,6 @@ class VectorStoreRetriever(BaseRetriever, BaseModel):
|
||||
else:
|
||||
raise ValueError(f"search_type of {self.search_type} not allowed.")
|
||||
return docs
|
||||
|
||||
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
||||
raise NotImplementedError("VectorStoreRetriever does not support async")
|
||||
|
@ -375,3 +375,6 @@ class RedisVectorStoreRetriever(BaseRetriever, BaseModel):
|
||||
else:
|
||||
raise ValueError(f"search_type of {self.search_type} not allowed.")
|
||||
return docs
|
||||
|
||||
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
||||
raise NotImplementedError("RedisVectorStoreRetriever does not support async")
|
||||
|
Loading…
Reference in New Issue
Block a user