forked from Archives/langchain
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
82 lines
3.0 KiB
Python
82 lines
3.0 KiB
Python
"""Tools for interacting with vectorstores."""
|
|
|
|
import json
|
|
from typing import Any, Dict
|
|
|
|
from pydantic import BaseModel, Field
|
|
|
|
from langchain.chains.qa_with_sources.vector_db import VectorDBQAWithSourcesChain
|
|
from langchain.chains.vector_db_qa.base import VectorDBQA
|
|
from langchain.llms.base import BaseLLM
|
|
from langchain.llms.openai import OpenAI
|
|
from langchain.tools.base import BaseTool
|
|
from langchain.vectorstores.base import VectorStore
|
|
|
|
|
|
class BaseVectorStoreTool(BaseModel):
    """Base class for tools that use a VectorStore.

    Mixed in ahead of ``BaseTool`` by the tool classes in this module, so each
    tool instance carries the vector store it queries and the LLM used to
    answer questions over the retrieved documents.
    """

    # Store the tool retrieves documents from. ``exclude=True`` keeps it out of
    # pydantic serialization output (``.dict()`` / ``.json()``).
    vectorstore: VectorStore = Field(exclude=True)
    # LLM handed to the QA chain. When not supplied, each instance gets its own
    # freshly built ``OpenAI(temperature=0)`` via ``default_factory``.
    llm: BaseLLM = Field(default_factory=lambda: OpenAI(temperature=0))

    class Config(BaseTool.Config):
        """Configuration for this pydantic object."""

        # Lets pydantic accept field types it has no built-in validator for
        # (they are checked with isinstance instead); presumably needed for
        # ``VectorStore``. Extending ``BaseTool.Config`` presumably keeps the
        # settings consistent with ``BaseTool`` when both are bases of a tool.
        arbitrary_types_allowed = True
|
|
|
|
|
|
def _create_description_from_template(values: Dict[str, Any]) -> Dict[str, Any]:
|
|
values["description"] = values["template"].format(name=values["name"])
|
|
return values
|
|
|
|
|
|
class VectorStoreQATool(BaseVectorStoreTool, BaseTool):
    """Tool that answers questions with a ``VectorDBQA`` chain.

    Initialize it with a ``vectorstore`` to retrieve documents from and,
    optionally, an ``llm`` (both fields come from ``BaseVectorStoreTool``),
    plus whatever fields ``BaseTool`` itself requires (not visible here;
    presumably ``name`` and ``description``). ``get_description`` can build a
    suitable description string.
    """

    @staticmethod
    def get_description(name: str, description: str) -> str:
        """Build the tool description shown to an agent.

        Args:
            name: Short name of the document collection; fills ``{name}``.
            description: What the collection covers; fills ``{description}``.

        Returns:
            The formatted description string.
        """
        template: str = (
            "Useful for when you need to answer questions about {name}. "
            "Whenever you need information about {description} "
            "you should ALWAYS use this. "
            "Input should be a fully formed question."
        )
        return template.format(name=name, description=description)

    def _run(self, query: str) -> str:
        """Answer ``query`` with a ``VectorDBQA`` chain over ``self.vectorstore``.

        Args:
            query: The question to answer.

        Returns:
            The chain's answer text.
        """
        # NOTE(review): the chain is rebuilt on every call; consider caching it
        # if construction turns out to be expensive.
        chain = VectorDBQA.from_chain_type(self.llm, vectorstore=self.vectorstore)
        return chain.run(query)

    async def _arun(self, query: str) -> str:
        """Use the tool asynchronously (not supported).

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("VectorStoreQATool does not support async")
|
|
|
|
|
|
class VectorStoreQAWithSourcesTool(BaseVectorStoreTool, BaseTool):
    """Tool that answers questions, with sources, via ``VectorDBQAWithSourcesChain``.

    Initialize it with a ``vectorstore`` to retrieve documents from and,
    optionally, an ``llm`` (both fields come from ``BaseVectorStoreTool``),
    plus whatever fields ``BaseTool`` itself requires (not visible here;
    presumably ``name`` and ``description``).
    """

    @staticmethod
    def get_description(name: str, description: str) -> str:
        """Build the tool description shown to an agent.

        Args:
            name: Short name of the document collection; fills ``{name}``.
            description: What the collection covers; fills ``{description}``.

        Returns:
            The formatted description string.
        """
        # Each fragment ends with exactly one trailing space so the
        # concatenated sentences are separated by a single space.
        template: str = (
            "Useful for when you need to answer questions about {name} and the sources "
            "used to construct the answer. "
            "Whenever you need information about {description} "
            "you should ALWAYS use this. "
            "Input should be a fully formed question. "
            "Output is a json serialized dictionary with keys `answer` and `sources`. "
            "Only use this tool if the user explicitly asks for sources."
        )
        return template.format(name=name, description=description)

    def _run(self, query: str) -> str:
        """Answer ``query`` and return the chain's outputs as a JSON string.

        Args:
            query: The question to answer.

        Returns:
            ``json.dumps`` of the chain's output dict. Because
            ``return_only_outputs=True``, the input question is not echoed
            back. Per the tool description, the expected keys are ``answer``
            and ``sources``.
        """
        # NOTE(review): the chain is rebuilt on every call; consider caching it
        # if construction turns out to be expensive.
        chain = VectorDBQAWithSourcesChain.from_chain_type(
            self.llm, vectorstore=self.vectorstore
        )
        # Key the input by the chain's own question key rather than a
        # hard-coded name.
        return json.dumps(chain({chain.question_key: query}, return_only_outputs=True))

    async def _arun(self, query: str) -> str:
        """Use the tool asynchronously (not supported).

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "VectorStoreQAWithSourcesTool does not support async"
        )
|