|
|
|
@ -5,8 +5,7 @@ from typing import Any, Dict
|
|
|
|
|
|
|
|
|
|
from pydantic import BaseModel, Field
|
|
|
|
|
|
|
|
|
|
from langchain.chains.qa_with_sources.vector_db import VectorDBQAWithSourcesChain
|
|
|
|
|
from langchain.chains.retrieval_qa.base import VectorDBQA
|
|
|
|
|
from langchain.chains import RetrievalQA, RetrievalQAWithSourcesChain
|
|
|
|
|
from langchain.llms.base import BaseLLM
|
|
|
|
|
from langchain.llms.openai import OpenAI
|
|
|
|
|
from langchain.tools.base import BaseTool
|
|
|
|
@ -45,12 +44,14 @@ class VectorStoreQATool(BaseVectorStoreTool, BaseTool):
|
|
|
|
|
|
|
|
|
|
def _run(self, query: str) -> str:
|
|
|
|
|
"""Use the tool."""
|
|
|
|
|
chain = VectorDBQA.from_chain_type(self.llm, vectorstore=self.vectorstore)
|
|
|
|
|
chain = RetrievalQA.from_chain_type(
|
|
|
|
|
self.llm, retriever=self.vectorstore.as_retriever()
|
|
|
|
|
)
|
|
|
|
|
return chain.run(query)
|
|
|
|
|
|
|
|
|
|
async def _arun(self, query: str) -> str:
|
|
|
|
|
"""Use the tool asynchronously."""
|
|
|
|
|
raise NotImplementedError("VectorDBQATool does not support async")
|
|
|
|
|
raise NotImplementedError("VectorStoreQATool does not support async")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class VectorStoreQAWithSourcesTool(BaseVectorStoreTool, BaseTool):
|
|
|
|
@ -71,11 +72,11 @@ class VectorStoreQAWithSourcesTool(BaseVectorStoreTool, BaseTool):
|
|
|
|
|
|
|
|
|
|
def _run(self, query: str) -> str:
|
|
|
|
|
"""Use the tool."""
|
|
|
|
|
chain = VectorDBQAWithSourcesChain.from_chain_type(
|
|
|
|
|
self.llm, vectorstore=self.vectorstore
|
|
|
|
|
chain = RetrievalQAWithSourcesChain.from_chain_type(
|
|
|
|
|
self.llm, retriever=self.vectorstore.as_retriever()
|
|
|
|
|
)
|
|
|
|
|
return json.dumps(chain({chain.question_key: query}, return_only_outputs=True))
|
|
|
|
|
|
|
|
|
|
async def _arun(self, query: str) -> str:
|
|
|
|
|
"""Use the tool asynchronously."""
|
|
|
|
|
raise NotImplementedError("VectorDBQATool does not support async")
|
|
|
|
|
raise NotImplementedError("VectorStoreQAWithSourcesTool does not support async")
|
|
|
|
|