mirror of https://github.com/hwchase17/langchain
vector db qa (#71)
parent
4c0b684f79
commit
47af2bcee4
@ -0,0 +1 @@
|
||||
"""Chain for question-answering against a vector database."""
|
@ -0,0 +1,80 @@
|
||||
"""Chain for question-answering against a vector database."""
|
||||
from typing import Dict, List
|
||||
|
||||
from pydantic import BaseModel, Extra
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.vector_db_qa.prompt import prompt
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
|
||||
|
||||
class VectorDBQA(Chain, BaseModel):
|
||||
"""Chain for question-answering against a vector database.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain import OpenAI, VectorDBQA
|
||||
from langchain.faiss import FAISS
|
||||
vectordb = FAISS(...)
|
||||
vectordbQA = VectorDBQA(llm=OpenAI(), vector_db=vectordb)
|
||||
|
||||
"""
|
||||
|
||||
llm: LLM
|
||||
"""LLM wrapper to use."""
|
||||
vectorstore: VectorStore
|
||||
"""Vector Database to connect to."""
|
||||
input_key: str = "query" #: :meta private:
|
||||
output_key: str = "result" #: :meta private:
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Return the singular input key.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.input_key]
|
||||
|
||||
@property
|
||||
def output_keys(self) -> List[str]:
|
||||
"""Return the singular output key.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.output_key]
|
||||
|
||||
def _run(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||
question = inputs[self.input_key]
|
||||
llm_chain = LLMChain(llm=self.llm, prompt=prompt)
|
||||
docs = self.vectorstore.similarity_search(question)
|
||||
contexts = []
|
||||
for j, doc in enumerate(docs):
|
||||
contexts.append(f"Context {j}:\n{doc.page_content}")
|
||||
# TODO: handle cases where this context is too long.
|
||||
answer = llm_chain.predict(question=question, context="\n\n".join(contexts))
|
||||
return {self.output_key: answer}
|
||||
|
||||
def run(self, question: str) -> str:
|
||||
"""Run Question-Answering on a vector database.
|
||||
|
||||
Args:
|
||||
question: Question to get the answer for.
|
||||
|
||||
Returns:
|
||||
The final answer
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
answer = vectordbqa.run("What is the capital of Idaho?")
|
||||
"""
|
||||
return self({self.input_key: question})[self.output_key]
|
@ -0,0 +1,10 @@
|
||||
# flake8: noqa
|
||||
from langchain.prompts import Prompt
|
||||
|
||||
prompt_template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
|
||||
|
||||
{context}
|
||||
|
||||
Question: {question}
|
||||
Helpful Answer:"""
|
||||
prompt = Prompt(template=prompt_template, input_variables=["context", "question"])
|
Loading…
Reference in New Issue