From 35a3218e84accfa7a28a6884dedd1e3590540975 Mon Sep 17 00:00:00 2001 From: Kei Kamikawa Date: Thu, 30 Mar 2023 23:14:05 +0900 Subject: [PATCH] supported async retriever (#2149) --- .../chains/conversational_retrieval/base.py | 14 ++++++-- langchain/chains/qa_with_sources/base.py | 22 ++++++++++++ langchain/chains/qa_with_sources/retrieval.py | 5 +++ langchain/chains/qa_with_sources/vector_db.py | 3 ++ langchain/chains/retrieval_qa/base.py | 34 +++++++++++++++++++ .../retrievers/chatgpt_plugin_retriever.py | 34 ++++++++++++++++++- langchain/retrievers/llama_index.py | 6 ++++ langchain/schema.py | 11 ++++++ langchain/vectorstores/base.py | 3 ++ langchain/vectorstores/redis.py | 3 ++ 10 files changed, 132 insertions(+), 3 deletions(-) diff --git a/langchain/chains/conversational_retrieval/base.py b/langchain/chains/conversational_retrieval/base.py index 5cec8cc257..e58382171e 100644 --- a/langchain/chains/conversational_retrieval/base.py +++ b/langchain/chains/conversational_retrieval/base.py @@ -86,6 +86,10 @@ class BaseConversationalRetrievalChain(Chain, BaseModel): else: return {self.output_key: answer} + @abstractmethod + async def _aget_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]: + """Get docs.""" + async def _acall(self, inputs: Dict[str, Any]) -> Dict[str, Any]: question = inputs["question"] get_chat_history = self.get_chat_history or _get_chat_history @@ -96,8 +100,7 @@ class BaseConversationalRetrievalChain(Chain, BaseModel): ) else: new_question = question - # TODO: This blocks the event loop, but it's not clear how to avoid it. 
- docs = self._get_docs(new_question, inputs) + docs = await self._aget_docs(new_question, inputs) new_inputs = inputs.copy() new_inputs["question"] = new_question new_inputs["chat_history"] = chat_history_str @@ -143,6 +146,10 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain, BaseModel): docs = self.retriever.get_relevant_documents(question) return self._reduce_tokens_below_limit(docs) + async def _aget_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]: + docs = await self.retriever.aget_relevant_documents(question) + return self._reduce_tokens_below_limit(docs) + @classmethod def from_llm( cls, @@ -194,6 +201,9 @@ class ChatVectorDBChain(BaseConversationalRetrievalChain, BaseModel): question, k=self.top_k_docs_for_context, **full_kwargs ) + async def _aget_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]: + raise NotImplementedError("ChatVectorDBChain does not support async") + @classmethod def from_llm( cls, diff --git a/langchain/chains/qa_with_sources/base.py b/langchain/chains/qa_with_sources/base.py index 499e4e2531..ae5efd11d8 100644 --- a/langchain/chains/qa_with_sources/base.py +++ b/langchain/chains/qa_with_sources/base.py @@ -129,6 +129,25 @@ class BaseQAWithSourcesChain(Chain, BaseModel, ABC): result["source_documents"] = docs return result + @abstractmethod + async def _aget_docs(self, inputs: Dict[str, Any]) -> List[Document]: + """Get docs to run questioning over.""" + + async def _acall(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + docs = await self._aget_docs(inputs) + answer, _ = await self.combine_documents_chain.acombine_docs(docs, **inputs) + if re.search(r"SOURCES:\s", answer): + answer, sources = re.split(r"SOURCES:\s", answer) + else: + sources = "" + result: Dict[str, Any] = { + self.answer_key: answer, + self.sources_answer_key: sources, + } + if self.return_source_documents: + result["source_documents"] = docs + return result + class 
QAWithSourcesChain(BaseQAWithSourcesChain, BaseModel): """Question answering with sources over documents.""" @@ -146,6 +165,9 @@ class QAWithSourcesChain(BaseQAWithSourcesChain, BaseModel): def _get_docs(self, inputs: Dict[str, Any]) -> List[Document]: return inputs.pop(self.input_docs_key) + async def _aget_docs(self, inputs: Dict[str, Any]) -> List[Document]: + return inputs.pop(self.input_docs_key) + @property def _chain_type(self) -> str: return "qa_with_sources_chain" diff --git a/langchain/chains/qa_with_sources/retrieval.py b/langchain/chains/qa_with_sources/retrieval.py index c9e9072212..6d1944f0ee 100644 --- a/langchain/chains/qa_with_sources/retrieval.py +++ b/langchain/chains/qa_with_sources/retrieval.py @@ -44,3 +44,8 @@ class RetrievalQAWithSourcesChain(BaseQAWithSourcesChain, BaseModel): question = inputs[self.question_key] docs = self.retriever.get_relevant_documents(question) return self._reduce_tokens_below_limit(docs) + + async def _aget_docs(self, inputs: Dict[str, Any]) -> List[Document]: + question = inputs[self.question_key] + docs = await self.retriever.aget_relevant_documents(question) + return self._reduce_tokens_below_limit(docs) diff --git a/langchain/chains/qa_with_sources/vector_db.py b/langchain/chains/qa_with_sources/vector_db.py index 5f95d8c1e6..4f53393964 100644 --- a/langchain/chains/qa_with_sources/vector_db.py +++ b/langchain/chains/qa_with_sources/vector_db.py @@ -52,6 +52,9 @@ class VectorDBQAWithSourcesChain(BaseQAWithSourcesChain, BaseModel): ) return self._reduce_tokens_below_limit(docs) + async def _aget_docs(self, inputs: Dict[str, Any]) -> List[Document]: + raise NotImplementedError("VectorDBQAWithSourcesChain does not support async") + @root_validator() def raise_deprecation(cls, values: Dict) -> Dict: warnings.warn( diff --git a/langchain/chains/retrieval_qa/base.py b/langchain/chains/retrieval_qa/base.py index 40ad34d482..a6bf675d55 100644 --- a/langchain/chains/retrieval_qa/base.py +++ 
b/langchain/chains/retrieval_qa/base.py @@ -114,6 +114,34 @@ class BaseRetrievalQA(Chain, BaseModel): else: return {self.output_key: answer} + @abstractmethod + async def _aget_docs(self, question: str) -> List[Document]: + """Asynchronously get documents to do question answering over.""" + + async def _acall(self, inputs: Dict[str, str]) -> Dict[str, Any]: + """Asynchronously retrieve documents and run the llm on the input query. + + If chain has 'return_source_documents' as 'True', returns + the retrieved documents as well under the key 'source_documents'. + + Example: + .. code-block:: python + + res = await indexqa.acall({'query': 'This is my query'}) + answer, docs = res['result'], res['source_documents'] + """ + question = inputs[self.input_key] + + docs = await self._aget_docs(question) + answer, _ = await self.combine_documents_chain.acombine_docs( + docs, question=question + ) + + if self.return_source_documents: + return {self.output_key: answer, "source_documents": docs} + else: + return {self.output_key: answer} + class RetrievalQA(BaseRetrievalQA, BaseModel): """Chain for question-answering against an index. 
@@ -134,6 +162,9 @@ class RetrievalQA(BaseRetrievalQA, BaseModel): def _get_docs(self, question: str) -> List[Document]: return self.retriever.get_relevant_documents(question) + async def _aget_docs(self, question: str) -> List[Document]: + return await self.retriever.aget_relevant_documents(question) + class VectorDBQA(BaseRetrievalQA, BaseModel): """Chain for question-answering against a vector database.""" @@ -177,6 +208,9 @@ class VectorDBQA(BaseRetrievalQA, BaseModel): raise ValueError(f"search_type of {self.search_type} not allowed.") return docs + async def _aget_docs(self, question: str) -> List[Document]: + raise NotImplementedError("VectorDBQA does not support async") + @property def _chain_type(self) -> str: """Return the chain type.""" diff --git a/langchain/retrievers/chatgpt_plugin_retriever.py b/langchain/retrievers/chatgpt_plugin_retriever.py index 610a79a281..d2d41327e0 100644 --- a/langchain/retrievers/chatgpt_plugin_retriever.py +++ b/langchain/retrievers/chatgpt_plugin_retriever.py @@ -1,5 +1,6 @@ -from typing import List +from typing import List, Optional +import aiohttp import requests from pydantic import BaseModel @@ -9,6 +10,12 @@ from langchain.schema import BaseRetriever, Document class ChatGPTPluginRetriever(BaseRetriever, BaseModel): url: str bearer_token: str + aiosession: Optional[aiohttp.ClientSession] = None + + class Config: + """Configuration for this pydantic object.""" + + arbitrary_types_allowed = True def get_relevant_documents(self, query: str) -> List[Document]: response = requests.post( @@ -25,3 +32,28 @@ class ChatGPTPluginRetriever(BaseRetriever, BaseModel): content = d.pop("text") docs.append(Document(page_content=content, metadata=d)) return docs + + async def aget_relevant_documents(self, query: str) -> List[Document]: + url = f"{self.url}/query" + json = {"queries": [{"query": query}]} + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {self.bearer_token}", + } + + if not self.aiosession: 
+ async with aiohttp.ClientSession() as session: + async with session.post(url, headers=headers, json=json) as response: + res = await response.json() + else: + async with self.aiosession.post( + url, headers=headers, json=json + ) as response: + res = await response.json() + + results = res["results"][0]["results"] + docs = [] + for d in results: + content = d.pop("text") + docs.append(Document(page_content=content, metadata=d)) + return docs diff --git a/langchain/retrievers/llama_index.py b/langchain/retrievers/llama_index.py index 4b3ee6453f..9e650c53c9 100644 --- a/langchain/retrievers/llama_index.py +++ b/langchain/retrievers/llama_index.py @@ -33,6 +33,9 @@ class LlamaIndexRetriever(BaseRetriever, BaseModel): ) return docs + async def aget_relevant_documents(self, query: str) -> List[Document]: + raise NotImplementedError("LlamaIndexRetriever does not support async") + class LlamaIndexGraphRetriever(BaseRetriever, BaseModel): """Question-answering with sources over an LlamaIndex graph data structure.""" @@ -69,3 +72,6 @@ class LlamaIndexGraphRetriever(BaseRetriever, BaseModel): Document(page_content=source_node.source_text, metadata=metadata) ) return docs + + async def aget_relevant_documents(self, query: str) -> List[Document]: + raise NotImplementedError("LlamaIndexGraphRetriever does not support async") diff --git a/langchain/schema.py b/langchain/schema.py index 2c09b65759..cc8e77bedd 100644 --- a/langchain/schema.py +++ b/langchain/schema.py @@ -310,6 +310,17 @@ class BaseRetriever(ABC): List of relevant documents """ + @abstractmethod + async def aget_relevant_documents(self, query: str) -> List[Document]: + """Asynchronously get documents relevant for a query. 
+ + Args: + query: string to find relevant documents for + + Returns: + List of relevant documents + """ + # For backwards compatibility diff --git a/langchain/vectorstores/base.py b/langchain/vectorstores/base.py index 87caddd12e..f3939da95e 100644 --- a/langchain/vectorstores/base.py +++ b/langchain/vectorstores/base.py @@ -159,3 +159,6 @@ class VectorStoreRetriever(BaseRetriever, BaseModel): else: raise ValueError(f"search_type of {self.search_type} not allowed.") return docs + + async def aget_relevant_documents(self, query: str) -> List[Document]: + raise NotImplementedError("VectorStoreRetriever does not support async") diff --git a/langchain/vectorstores/redis.py b/langchain/vectorstores/redis.py index 5e0a6986a0..c436480144 100644 --- a/langchain/vectorstores/redis.py +++ b/langchain/vectorstores/redis.py @@ -375,3 +375,6 @@ class RedisVectorStoreRetriever(BaseRetriever, BaseModel): else: raise ValueError(f"search_type of {self.search_type} not allowed.") return docs + + async def aget_relevant_documents(self, query: str) -> List[Document]: + raise NotImplementedError("RedisVectorStoreRetriever does not support async")