From 1bf1c37c0cccb7c8c73d87ace27cf742f814dbe5 Mon Sep 17 00:00:00 2001 From: Zach Schillaci <40636930+zachschillaci27@users.noreply.github.com> Date: Fri, 28 Apr 2023 16:39:59 +0200 Subject: [PATCH] Update VectorDBQA to RetrievalQA in tools (#3698) Because `VectorDBQA` and `VectorDBQAWithSourcesChain` are deprecated --- langchain/tools/vectorstore/tool.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/langchain/tools/vectorstore/tool.py b/langchain/tools/vectorstore/tool.py index adc35e7f..1dd18fd2 100644 --- a/langchain/tools/vectorstore/tool.py +++ b/langchain/tools/vectorstore/tool.py @@ -5,8 +5,7 @@ from typing import Any, Dict from pydantic import BaseModel, Field -from langchain.chains.qa_with_sources.vector_db import VectorDBQAWithSourcesChain -from langchain.chains.retrieval_qa.base import VectorDBQA +from langchain.chains import RetrievalQA, RetrievalQAWithSourcesChain from langchain.llms.base import BaseLLM from langchain.llms.openai import OpenAI from langchain.tools.base import BaseTool @@ -45,12 +44,14 @@ class VectorStoreQATool(BaseVectorStoreTool, BaseTool): def _run(self, query: str) -> str: """Use the tool.""" - chain = VectorDBQA.from_chain_type(self.llm, vectorstore=self.vectorstore) + chain = RetrievalQA.from_chain_type( + self.llm, retriever=self.vectorstore.as_retriever() + ) return chain.run(query) async def _arun(self, query: str) -> str: """Use the tool asynchronously.""" - raise NotImplementedError("VectorDBQATool does not support async") + raise NotImplementedError("VectorStoreQATool does not support async") class VectorStoreQAWithSourcesTool(BaseVectorStoreTool, BaseTool): @@ -71,11 +72,11 @@ class VectorStoreQAWithSourcesTool(BaseVectorStoreTool, BaseTool): def _run(self, query: str) -> str: """Use the tool.""" - chain = VectorDBQAWithSourcesChain.from_chain_type( - self.llm, vectorstore=self.vectorstore + chain = RetrievalQAWithSourcesChain.from_chain_type( + self.llm, retriever=self.vectorstore.as_retriever() ) return json.dumps(chain({chain.question_key: query}, return_only_outputs=True)) async def _arun(self, query: str) -> str: """Use the tool asynchronously.""" - raise NotImplementedError("VectorDBQATool does not support async") + raise NotImplementedError("VectorStoreQAWithSourcesTool does not support async")