forked from Archives/langchain
supported async retriever (#2149)
This commit is contained in:
parent
65c0c73597
commit
35a3218e84
@ -86,6 +86,10 @@ class BaseConversationalRetrievalChain(Chain, BaseModel):
|
|||||||
else:
|
else:
|
||||||
return {self.output_key: answer}
|
return {self.output_key: answer}
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def _aget_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
|
||||||
|
"""Get docs."""
|
||||||
|
|
||||||
async def _acall(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
async def _acall(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
question = inputs["question"]
|
question = inputs["question"]
|
||||||
get_chat_history = self.get_chat_history or _get_chat_history
|
get_chat_history = self.get_chat_history or _get_chat_history
|
||||||
@ -96,8 +100,7 @@ class BaseConversationalRetrievalChain(Chain, BaseModel):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
new_question = question
|
new_question = question
|
||||||
# TODO: This blocks the event loop, but it's not clear how to avoid it.
|
docs = await self._aget_docs(new_question, inputs)
|
||||||
docs = self._get_docs(new_question, inputs)
|
|
||||||
new_inputs = inputs.copy()
|
new_inputs = inputs.copy()
|
||||||
new_inputs["question"] = new_question
|
new_inputs["question"] = new_question
|
||||||
new_inputs["chat_history"] = chat_history_str
|
new_inputs["chat_history"] = chat_history_str
|
||||||
@ -143,6 +146,10 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain, BaseModel):
|
|||||||
docs = self.retriever.get_relevant_documents(question)
|
docs = self.retriever.get_relevant_documents(question)
|
||||||
return self._reduce_tokens_below_limit(docs)
|
return self._reduce_tokens_below_limit(docs)
|
||||||
|
|
||||||
|
async def _aget_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
|
||||||
|
docs = await self.retriever.aget_relevant_documents(question)
|
||||||
|
return self._reduce_tokens_below_limit(docs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_llm(
|
def from_llm(
|
||||||
cls,
|
cls,
|
||||||
@ -194,6 +201,9 @@ class ChatVectorDBChain(BaseConversationalRetrievalChain, BaseModel):
|
|||||||
question, k=self.top_k_docs_for_context, **full_kwargs
|
question, k=self.top_k_docs_for_context, **full_kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def _aget_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
|
||||||
|
raise NotImplementedError("ChatVectorDBChain does not support async")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_llm(
|
def from_llm(
|
||||||
cls,
|
cls,
|
||||||
|
@ -129,6 +129,25 @@ class BaseQAWithSourcesChain(Chain, BaseModel, ABC):
|
|||||||
result["source_documents"] = docs
|
result["source_documents"] = docs
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def _aget_docs(self, inputs: Dict[str, Any]) -> List[Document]:
|
||||||
|
"""Get docs to run questioning over."""
|
||||||
|
|
||||||
|
async def _acall(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
docs = await self._aget_docs(inputs)
|
||||||
|
answer, _ = await self.combine_documents_chain.acombine_docs(docs, **inputs)
|
||||||
|
if re.search(r"SOURCES:\s", answer):
|
||||||
|
answer, sources = re.split(r"SOURCES:\s", answer)
|
||||||
|
else:
|
||||||
|
sources = ""
|
||||||
|
result: Dict[str, Any] = {
|
||||||
|
self.answer_key: answer,
|
||||||
|
self.sources_answer_key: sources,
|
||||||
|
}
|
||||||
|
if self.return_source_documents:
|
||||||
|
result["source_documents"] = docs
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
class QAWithSourcesChain(BaseQAWithSourcesChain, BaseModel):
|
class QAWithSourcesChain(BaseQAWithSourcesChain, BaseModel):
|
||||||
"""Question answering with sources over documents."""
|
"""Question answering with sources over documents."""
|
||||||
@ -146,6 +165,9 @@ class QAWithSourcesChain(BaseQAWithSourcesChain, BaseModel):
|
|||||||
def _get_docs(self, inputs: Dict[str, Any]) -> List[Document]:
|
def _get_docs(self, inputs: Dict[str, Any]) -> List[Document]:
|
||||||
return inputs.pop(self.input_docs_key)
|
return inputs.pop(self.input_docs_key)
|
||||||
|
|
||||||
|
async def _aget_docs(self, inputs: Dict[str, Any]) -> List[Document]:
|
||||||
|
return inputs.pop(self.input_docs_key)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _chain_type(self) -> str:
|
def _chain_type(self) -> str:
|
||||||
return "qa_with_sources_chain"
|
return "qa_with_sources_chain"
|
||||||
|
@ -44,3 +44,8 @@ class RetrievalQAWithSourcesChain(BaseQAWithSourcesChain, BaseModel):
|
|||||||
question = inputs[self.question_key]
|
question = inputs[self.question_key]
|
||||||
docs = self.retriever.get_relevant_documents(question)
|
docs = self.retriever.get_relevant_documents(question)
|
||||||
return self._reduce_tokens_below_limit(docs)
|
return self._reduce_tokens_below_limit(docs)
|
||||||
|
|
||||||
|
async def _aget_docs(self, inputs: Dict[str, Any]) -> List[Document]:
|
||||||
|
question = inputs[self.question_key]
|
||||||
|
docs = await self.retriever.aget_relevant_documents(question)
|
||||||
|
return self._reduce_tokens_below_limit(docs)
|
||||||
|
@ -52,6 +52,9 @@ class VectorDBQAWithSourcesChain(BaseQAWithSourcesChain, BaseModel):
|
|||||||
)
|
)
|
||||||
return self._reduce_tokens_below_limit(docs)
|
return self._reduce_tokens_below_limit(docs)
|
||||||
|
|
||||||
|
async def _aget_docs(self, inputs: Dict[str, Any]) -> List[Document]:
|
||||||
|
raise NotImplementedError("VectorDBQAWithSourcesChain does not support async")
|
||||||
|
|
||||||
@root_validator()
|
@root_validator()
|
||||||
def raise_deprecation(cls, values: Dict) -> Dict:
|
def raise_deprecation(cls, values: Dict) -> Dict:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
|
@ -114,6 +114,34 @@ class BaseRetrievalQA(Chain, BaseModel):
|
|||||||
else:
|
else:
|
||||||
return {self.output_key: answer}
|
return {self.output_key: answer}
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def _aget_docs(self, question: str) -> List[Document]:
|
||||||
|
"""Get documents to do question answering over."""
|
||||||
|
|
||||||
|
async def _acall(self, inputs: Dict[str, str]) -> Dict[str, Any]:
|
||||||
|
"""Run get_relevant_text and llm on input query.
|
||||||
|
|
||||||
|
If chain has 'return_source_documents' as 'True', returns
|
||||||
|
the retrieved documents as well under the key 'source_documents'.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
res = indexqa({'query': 'This is my query'})
|
||||||
|
answer, docs = res['result'], res['source_documents']
|
||||||
|
"""
|
||||||
|
question = inputs[self.input_key]
|
||||||
|
|
||||||
|
docs = await self._aget_docs(question)
|
||||||
|
answer, _ = await self.combine_documents_chain.acombine_docs(
|
||||||
|
docs, question=question
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.return_source_documents:
|
||||||
|
return {self.output_key: answer, "source_documents": docs}
|
||||||
|
else:
|
||||||
|
return {self.output_key: answer}
|
||||||
|
|
||||||
|
|
||||||
class RetrievalQA(BaseRetrievalQA, BaseModel):
|
class RetrievalQA(BaseRetrievalQA, BaseModel):
|
||||||
"""Chain for question-answering against an index.
|
"""Chain for question-answering against an index.
|
||||||
@ -134,6 +162,9 @@ class RetrievalQA(BaseRetrievalQA, BaseModel):
|
|||||||
def _get_docs(self, question: str) -> List[Document]:
|
def _get_docs(self, question: str) -> List[Document]:
|
||||||
return self.retriever.get_relevant_documents(question)
|
return self.retriever.get_relevant_documents(question)
|
||||||
|
|
||||||
|
async def _aget_docs(self, question: str) -> List[Document]:
|
||||||
|
return await self.retriever.aget_relevant_documents(question)
|
||||||
|
|
||||||
|
|
||||||
class VectorDBQA(BaseRetrievalQA, BaseModel):
|
class VectorDBQA(BaseRetrievalQA, BaseModel):
|
||||||
"""Chain for question-answering against a vector database."""
|
"""Chain for question-answering against a vector database."""
|
||||||
@ -177,6 +208,9 @@ class VectorDBQA(BaseRetrievalQA, BaseModel):
|
|||||||
raise ValueError(f"search_type of {self.search_type} not allowed.")
|
raise ValueError(f"search_type of {self.search_type} not allowed.")
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
|
async def _aget_docs(self, question: str) -> List[Document]:
|
||||||
|
raise NotImplementedError("VectorDBQA does not support async")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _chain_type(self) -> str:
|
def _chain_type(self) -> str:
|
||||||
"""Return the chain type."""
|
"""Return the chain type."""
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
from typing import List
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
import requests
|
import requests
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
@ -9,6 +10,12 @@ from langchain.schema import BaseRetriever, Document
|
|||||||
class ChatGPTPluginRetriever(BaseRetriever, BaseModel):
|
class ChatGPTPluginRetriever(BaseRetriever, BaseModel):
|
||||||
url: str
|
url: str
|
||||||
bearer_token: str
|
bearer_token: str
|
||||||
|
aiosession: Optional[aiohttp.ClientSession] = None
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
"""Configuration for this pydantic object."""
|
||||||
|
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
def get_relevant_documents(self, query: str) -> List[Document]:
|
def get_relevant_documents(self, query: str) -> List[Document]:
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
@ -25,3 +32,28 @@ class ChatGPTPluginRetriever(BaseRetriever, BaseModel):
|
|||||||
content = d.pop("text")
|
content = d.pop("text")
|
||||||
docs.append(Document(page_content=content, metadata=d))
|
docs.append(Document(page_content=content, metadata=d))
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
|
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
||||||
|
url = f"{self.url}/query"
|
||||||
|
json = {"queries": [{"query": query}]}
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Authorization": f"Bearer {self.bearer_token}",
|
||||||
|
}
|
||||||
|
|
||||||
|
if not self.aiosession:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(url, headers=headers, json=json) as response:
|
||||||
|
res = await response.json()
|
||||||
|
else:
|
||||||
|
async with self.aiosession.post(
|
||||||
|
url, headers=headers, json=json
|
||||||
|
) as response:
|
||||||
|
res = await response.json()
|
||||||
|
|
||||||
|
results = res["results"][0]["results"]
|
||||||
|
docs = []
|
||||||
|
for d in results:
|
||||||
|
content = d.pop("text")
|
||||||
|
docs.append(Document(page_content=content, metadata=d))
|
||||||
|
return docs
|
||||||
|
@ -33,6 +33,9 @@ class LlamaIndexRetriever(BaseRetriever, BaseModel):
|
|||||||
)
|
)
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
|
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
||||||
|
raise NotImplementedError("LlamaIndexRetriever does not support async")
|
||||||
|
|
||||||
|
|
||||||
class LlamaIndexGraphRetriever(BaseRetriever, BaseModel):
|
class LlamaIndexGraphRetriever(BaseRetriever, BaseModel):
|
||||||
"""Question-answering with sources over an LlamaIndex graph data structure."""
|
"""Question-answering with sources over an LlamaIndex graph data structure."""
|
||||||
@ -69,3 +72,6 @@ class LlamaIndexGraphRetriever(BaseRetriever, BaseModel):
|
|||||||
Document(page_content=source_node.source_text, metadata=metadata)
|
Document(page_content=source_node.source_text, metadata=metadata)
|
||||||
)
|
)
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
|
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
||||||
|
raise NotImplementedError("LlamaIndexGraphRetriever does not support async")
|
||||||
|
@ -310,6 +310,17 @@ class BaseRetriever(ABC):
|
|||||||
List of relevant documents
|
List of relevant documents
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
||||||
|
"""Get documents relevant for a query.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: string to find relevant documents for
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of relevant documents
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
# For backwards compatibility
|
# For backwards compatibility
|
||||||
|
|
||||||
|
@ -159,3 +159,6 @@ class VectorStoreRetriever(BaseRetriever, BaseModel):
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"search_type of {self.search_type} not allowed.")
|
raise ValueError(f"search_type of {self.search_type} not allowed.")
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
|
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
||||||
|
raise NotImplementedError("VectorStoreRetriever does not support async")
|
||||||
|
@ -375,3 +375,6 @@ class RedisVectorStoreRetriever(BaseRetriever, BaseModel):
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"search_type of {self.search_type} not allowed.")
|
raise ValueError(f"search_type of {self.search_type} not allowed.")
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
|
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
||||||
|
raise NotImplementedError("RedisVectorStoreRetriever does not support async")
|
||||||
|
Loading…
Reference in New Issue
Block a user