Support async retriever (#2149)

This commit is contained in:
Kei Kamikawa 2023-03-30 23:14:05 +09:00 committed by GitHub
parent 65c0c73597
commit 35a3218e84
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 132 additions and 3 deletions

View File

@ -86,6 +86,10 @@ class BaseConversationalRetrievalChain(Chain, BaseModel):
else:
return {self.output_key: answer}
@abstractmethod
async def _aget_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
"""Get docs."""
async def _acall(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
question = inputs["question"]
get_chat_history = self.get_chat_history or _get_chat_history
@ -96,8 +100,7 @@ class BaseConversationalRetrievalChain(Chain, BaseModel):
)
else:
new_question = question
# TODO: This blocks the event loop, but it's not clear how to avoid it.
docs = self._get_docs(new_question, inputs)
docs = await self._aget_docs(new_question, inputs)
new_inputs = inputs.copy()
new_inputs["question"] = new_question
new_inputs["chat_history"] = chat_history_str
@ -143,6 +146,10 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain, BaseModel):
docs = self.retriever.get_relevant_documents(question)
return self._reduce_tokens_below_limit(docs)
async def _aget_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
docs = await self.retriever.aget_relevant_documents(question)
return self._reduce_tokens_below_limit(docs)
@classmethod
def from_llm(
cls,
@ -194,6 +201,9 @@ class ChatVectorDBChain(BaseConversationalRetrievalChain, BaseModel):
question, k=self.top_k_docs_for_context, **full_kwargs
)
async def _aget_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
raise NotImplementedError("ChatVectorDBChain does not support async")
@classmethod
def from_llm(
cls,

View File

@ -129,6 +129,25 @@ class BaseQAWithSourcesChain(Chain, BaseModel, ABC):
result["source_documents"] = docs
return result
@abstractmethod
async def _aget_docs(self, inputs: Dict[str, Any]) -> List[Document]:
"""Get docs to run questioning over."""
async def _acall(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
docs = await self._aget_docs(inputs)
answer, _ = await self.combine_documents_chain.acombine_docs(docs, **inputs)
if re.search(r"SOURCES:\s", answer):
answer, sources = re.split(r"SOURCES:\s", answer)
else:
sources = ""
result: Dict[str, Any] = {
self.answer_key: answer,
self.sources_answer_key: sources,
}
if self.return_source_documents:
result["source_documents"] = docs
return result
class QAWithSourcesChain(BaseQAWithSourcesChain, BaseModel):
"""Question answering with sources over documents."""
@ -146,6 +165,9 @@ class QAWithSourcesChain(BaseQAWithSourcesChain, BaseModel):
def _get_docs(self, inputs: Dict[str, Any]) -> List[Document]:
return inputs.pop(self.input_docs_key)
async def _aget_docs(self, inputs: Dict[str, Any]) -> List[Document]:
return inputs.pop(self.input_docs_key)
@property
def _chain_type(self) -> str:
return "qa_with_sources_chain"

View File

@ -44,3 +44,8 @@ class RetrievalQAWithSourcesChain(BaseQAWithSourcesChain, BaseModel):
question = inputs[self.question_key]
docs = self.retriever.get_relevant_documents(question)
return self._reduce_tokens_below_limit(docs)
async def _aget_docs(self, inputs: Dict[str, Any]) -> List[Document]:
question = inputs[self.question_key]
docs = await self.retriever.aget_relevant_documents(question)
return self._reduce_tokens_below_limit(docs)

View File

@ -52,6 +52,9 @@ class VectorDBQAWithSourcesChain(BaseQAWithSourcesChain, BaseModel):
)
return self._reduce_tokens_below_limit(docs)
async def _aget_docs(self, inputs: Dict[str, Any]) -> List[Document]:
raise NotImplementedError("VectorDBQAWithSourcesChain does not support async")
@root_validator()
def raise_deprecation(cls, values: Dict) -> Dict:
warnings.warn(

View File

@ -114,6 +114,34 @@ class BaseRetrievalQA(Chain, BaseModel):
else:
return {self.output_key: answer}
@abstractmethod
async def _aget_docs(self, question: str) -> List[Document]:
"""Get documents to do question answering over."""
async def _acall(self, inputs: Dict[str, str]) -> Dict[str, Any]:
"""Run get_relevant_text and llm on input query.
If chain has 'return_source_documents' as 'True', returns
the retrieved documents as well under the key 'source_documents'.
Example:
.. code-block:: python
res = indexqa({'query': 'This is my query'})
answer, docs = res['result'], res['source_documents']
"""
question = inputs[self.input_key]
docs = await self._aget_docs(question)
answer, _ = await self.combine_documents_chain.acombine_docs(
docs, question=question
)
if self.return_source_documents:
return {self.output_key: answer, "source_documents": docs}
else:
return {self.output_key: answer}
class RetrievalQA(BaseRetrievalQA, BaseModel):
"""Chain for question-answering against an index.
@ -134,6 +162,9 @@ class RetrievalQA(BaseRetrievalQA, BaseModel):
def _get_docs(self, question: str) -> List[Document]:
return self.retriever.get_relevant_documents(question)
async def _aget_docs(self, question: str) -> List[Document]:
return await self.retriever.aget_relevant_documents(question)
class VectorDBQA(BaseRetrievalQA, BaseModel):
"""Chain for question-answering against a vector database."""
@ -177,6 +208,9 @@ class VectorDBQA(BaseRetrievalQA, BaseModel):
raise ValueError(f"search_type of {self.search_type} not allowed.")
return docs
async def _aget_docs(self, question: str) -> List[Document]:
raise NotImplementedError("VectorDBQA does not support async")
@property
def _chain_type(self) -> str:
"""Return the chain type."""

View File

@ -1,5 +1,6 @@
from typing import List
from typing import List, Optional
import aiohttp
import requests
from pydantic import BaseModel
@ -9,6 +10,12 @@ from langchain.schema import BaseRetriever, Document
class ChatGPTPluginRetriever(BaseRetriever, BaseModel):
url: str
bearer_token: str
aiosession: Optional[aiohttp.ClientSession] = None
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
def get_relevant_documents(self, query: str) -> List[Document]:
response = requests.post(
@ -25,3 +32,28 @@ class ChatGPTPluginRetriever(BaseRetriever, BaseModel):
content = d.pop("text")
docs.append(Document(page_content=content, metadata=d))
return docs
async def aget_relevant_documents(self, query: str) -> List[Document]:
    """Asynchronously query the plugin's /query endpoint for relevant documents.

    Uses ``self.aiosession`` when one was provided, otherwise opens a
    throwaway ``aiohttp.ClientSession`` for this single request.

    Args:
        query: string to find relevant documents for

    Returns:
        List of relevant documents
    """
    url = f"{self.url}/query"
    # Named ``body`` rather than ``json`` so the stdlib ``json`` module
    # name is not shadowed.
    body = {"queries": [{"query": query}]}
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {self.bearer_token}",
    }
    if self.aiosession is None:
        async with aiohttp.ClientSession() as session:
            async with session.post(url, headers=headers, json=body) as response:
                res = await response.json()
    else:
        async with self.aiosession.post(
            url, headers=headers, json=body
        ) as response:
            res = await response.json()
    # Same response shape as the sync get_relevant_documents: the text
    # field becomes page_content, everything else becomes metadata.
    docs = []
    for d in res["results"][0]["results"]:
        content = d.pop("text")
        docs.append(Document(page_content=content, metadata=d))
    return docs

View File

@ -33,6 +33,9 @@ class LlamaIndexRetriever(BaseRetriever, BaseModel):
)
return docs
async def aget_relevant_documents(self, query: str) -> List[Document]:
    """Async retrieval is not implemented for this retriever."""
    raise NotImplementedError("LlamaIndexRetriever does not support async")
class LlamaIndexGraphRetriever(BaseRetriever, BaseModel):
"""Question-answering with sources over an LlamaIndex graph data structure."""
@ -69,3 +72,6 @@ class LlamaIndexGraphRetriever(BaseRetriever, BaseModel):
Document(page_content=source_node.source_text, metadata=metadata)
)
return docs
async def aget_relevant_documents(self, query: str) -> List[Document]:
    """Async retrieval is not implemented for this graph retriever."""
    raise NotImplementedError("LlamaIndexGraphRetriever does not support async")

View File

@ -310,6 +310,17 @@ class BaseRetriever(ABC):
List of relevant documents
"""
@abstractmethod
async def aget_relevant_documents(self, query: str) -> List[Document]:
    """Asynchronously get documents relevant for a query.

    Args:
        query: string to find relevant documents for

    Returns:
        List of relevant documents
    """
# For backwards compatibility

View File

@ -159,3 +159,6 @@ class VectorStoreRetriever(BaseRetriever, BaseModel):
else:
raise ValueError(f"search_type of {self.search_type} not allowed.")
return docs
async def aget_relevant_documents(self, query: str) -> List[Document]:
    """Async retrieval is not implemented for this vector-store retriever."""
    raise NotImplementedError("VectorStoreRetriever does not support async")

View File

@ -375,3 +375,6 @@ class RedisVectorStoreRetriever(BaseRetriever, BaseModel):
else:
raise ValueError(f"search_type of {self.search_type} not allowed.")
return docs
async def aget_relevant_documents(self, query: str) -> List[Document]:
    """Async retrieval is not implemented for the Redis-backed retriever."""
    raise NotImplementedError("RedisVectorStoreRetriever does not support async")