|
|
|
@ -114,6 +114,34 @@ class BaseRetrievalQA(Chain, BaseModel):
|
|
|
|
|
else:
|
|
|
|
|
return {self.output_key: answer}
|
|
|
|
|
|
|
|
|
|
@abstractmethod
|
|
|
|
|
async def _aget_docs(self, question: str) -> List[Document]:
|
|
|
|
|
"""Get documents to do question answering over."""
|
|
|
|
|
|
|
|
|
|
async def _acall(self, inputs: Dict[str, str]) -> Dict[str, Any]:
|
|
|
|
|
"""Run get_relevant_text and llm on input query.
|
|
|
|
|
|
|
|
|
|
If chain has 'return_source_documents' as 'True', returns
|
|
|
|
|
the retrieved documents as well under the key 'source_documents'.
|
|
|
|
|
|
|
|
|
|
Example:
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
|
|
res = indexqa({'query': 'This is my query'})
|
|
|
|
|
answer, docs = res['result'], res['source_documents']
|
|
|
|
|
"""
|
|
|
|
|
question = inputs[self.input_key]
|
|
|
|
|
|
|
|
|
|
docs = await self._aget_docs(question)
|
|
|
|
|
answer, _ = await self.combine_documents_chain.acombine_docs(
|
|
|
|
|
docs, question=question
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if self.return_source_documents:
|
|
|
|
|
return {self.output_key: answer, "source_documents": docs}
|
|
|
|
|
else:
|
|
|
|
|
return {self.output_key: answer}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RetrievalQA(BaseRetrievalQA, BaseModel):
|
|
|
|
|
"""Chain for question-answering against an index.
|
|
|
|
@ -134,6 +162,9 @@ class RetrievalQA(BaseRetrievalQA, BaseModel):
|
|
|
|
|
def _get_docs(self, question: str) -> List[Document]:
|
|
|
|
|
return self.retriever.get_relevant_documents(question)
|
|
|
|
|
|
|
|
|
|
async def _aget_docs(self, question: str) -> List[Document]:
|
|
|
|
|
return await self.retriever.aget_relevant_documents(question)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class VectorDBQA(BaseRetrievalQA, BaseModel):
|
|
|
|
|
"""Chain for question-answering against a vector database."""
|
|
|
|
@ -177,6 +208,9 @@ class VectorDBQA(BaseRetrievalQA, BaseModel):
|
|
|
|
|
raise ValueError(f"search_type of {self.search_type} not allowed.")
|
|
|
|
|
return docs
|
|
|
|
|
|
|
|
|
|
async def _aget_docs(self, question: str) -> List[Document]:
|
|
|
|
|
raise NotImplementedError("VectorDBQA does not support async")
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def _chain_type(self) -> str:
|
|
|
|
|
"""Return the chain type."""
|
|
|
|
|