community[patch]: Add async methods to VectorStoreQATool (#16949)

pull/16822/head
Christophe Bornet 8 months ago committed by GitHub
parent fb7552bfcf
commit ab025507bc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -3,7 +3,10 @@
import json
from typing import Any, Dict, Optional
from langchain_core.callbacks import CallbackManagerForToolRun
from langchain_core.callbacks import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from langchain_core.language_models import BaseLanguageModel
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.tools import BaseTool
@ -51,9 +54,30 @@ class VectorStoreQATool(BaseVectorStoreTool, BaseTool):
chain = RetrievalQA.from_chain_type(
self.llm, retriever=self.vectorstore.as_retriever()
)
return chain.run(
query, callbacks=run_manager.get_child() if run_manager else None
return chain.invoke(
{chain.input_key: query},
config={"callbacks": [run_manager.get_child() if run_manager else None]},
)[chain.output_key]
async def _arun(
self,
query: str,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> str:
"""Use the tool asynchronously."""
from langchain.chains.retrieval_qa.base import RetrievalQA
chain = RetrievalQA.from_chain_type(
self.llm, retriever=self.vectorstore.as_retriever()
)
return (
await chain.ainvoke(
{chain.input_key: query},
config={
"callbacks": [run_manager.get_child() if run_manager else None]
},
)
)[chain.output_key]
class VectorStoreQAWithSourcesTool(BaseVectorStoreTool, BaseTool):
@ -87,7 +111,28 @@ class VectorStoreQAWithSourcesTool(BaseVectorStoreTool, BaseTool):
self.llm, retriever=self.vectorstore.as_retriever()
)
return json.dumps(
chain(
chain.invoke(
{chain.question_key: query},
return_only_outputs=True,
callbacks=run_manager.get_child() if run_manager else None,
)
)
async def _arun(
self,
query: str,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> str:
"""Use the tool asynchronously."""
from langchain.chains.qa_with_sources.retrieval import (
RetrievalQAWithSourcesChain,
)
chain = RetrievalQAWithSourcesChain.from_chain_type(
self.llm, retriever=self.vectorstore.as_retriever()
)
return json.dumps(
await chain.ainvoke(
{chain.question_key: query},
return_only_outputs=True,
callbacks=run_manager.get_child() if run_manager else None,

Loading…
Cancel
Save