Update VectorDBQA to RetrievalQA in tools (#3698)

Because `VectorDBQA` and `VectorDBQAWithSourcesChain` are deprecated
fix_agent_callbacks
Zach Schillaci 1 year ago committed by GitHub
parent 32793f94fd
commit 1bf1c37c0c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -5,8 +5,7 @@ from typing import Any, Dict
from pydantic import BaseModel, Field
from langchain.chains.qa_with_sources.vector_db import VectorDBQAWithSourcesChain
from langchain.chains.retrieval_qa.base import VectorDBQA
from langchain.chains import RetrievalQA, RetrievalQAWithSourcesChain
from langchain.llms.base import BaseLLM
from langchain.llms.openai import OpenAI
from langchain.tools.base import BaseTool
@ -45,12 +44,14 @@ class VectorStoreQATool(BaseVectorStoreTool, BaseTool):
def _run(self, query: str) -> str:
"""Use the tool."""
chain = VectorDBQA.from_chain_type(self.llm, vectorstore=self.vectorstore)
chain = RetrievalQA.from_chain_type(
self.llm, retriever=self.vectorstore.as_retriever()
)
return chain.run(query)
async def _arun(self, query: str) -> str:
"""Use the tool asynchronously."""
raise NotImplementedError("VectorDBQATool does not support async")
raise NotImplementedError("VectorStoreQATool does not support async")
class VectorStoreQAWithSourcesTool(BaseVectorStoreTool, BaseTool):
@ -71,11 +72,11 @@ class VectorStoreQAWithSourcesTool(BaseVectorStoreTool, BaseTool):
def _run(self, query: str) -> str:
"""Use the tool."""
chain = VectorDBQAWithSourcesChain.from_chain_type(
self.llm, vectorstore=self.vectorstore
chain = RetrievalQAWithSourcesChain.from_chain_type(
self.llm, retriever=self.vectorstore.as_retriever()
)
return json.dumps(chain({chain.question_key: query}, return_only_outputs=True))
async def _arun(self, query: str) -> str:
"""Use the tool asynchronously."""
raise NotImplementedError("VectorDBQATool does not support async")
raise NotImplementedError("VectorStoreQAWithSourcesTool does not support async")

Loading…
Cancel
Save