forked from Archives/langchain
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
106 lines
3.6 KiB
Python
106 lines
3.6 KiB
Python
"""Chain for question-answering against a vector database."""
|
|
from __future__ import annotations
|
|
|
|
from typing import Any, Dict, List
|
|
|
|
from pydantic import BaseModel, Extra, root_validator
|
|
|
|
from langchain.chains.base import Chain
|
|
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
|
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
|
|
from langchain.chains.llm import LLMChain
|
|
from langchain.chains.vector_db_qa.prompt import PROMPT
|
|
from langchain.llms.base import BaseLLM
|
|
from langchain.prompts import PromptTemplate
|
|
from langchain.vectorstores.base import VectorStore
|
|
|
|
|
|
class VectorDBQA(Chain, BaseModel):
|
|
"""Chain for question-answering against a vector database.
|
|
|
|
Example:
|
|
.. code-block:: python
|
|
|
|
from langchain import OpenAI, VectorDBQA
|
|
from langchain.faiss import FAISS
|
|
vectordb = FAISS(...)
|
|
vectordbQA = VectorDBQA(llm=OpenAI(), vector_db=vectordb)
|
|
|
|
"""
|
|
|
|
vectorstore: VectorStore
|
|
"""Vector Database to connect to."""
|
|
k: int = 4
|
|
"""Number of documents to query for."""
|
|
combine_documents_chain: BaseCombineDocumentsChain
|
|
"""Chain to use to combine the documents."""
|
|
input_key: str = "query" #: :meta private:
|
|
output_key: str = "result" #: :meta private:
|
|
|
|
class Config:
|
|
"""Configuration for this pydantic object."""
|
|
|
|
extra = Extra.forbid
|
|
arbitrary_types_allowed = True
|
|
|
|
@property
|
|
def input_keys(self) -> List[str]:
|
|
"""Return the singular input key.
|
|
|
|
:meta private:
|
|
"""
|
|
return [self.input_key]
|
|
|
|
@property
|
|
def output_keys(self) -> List[str]:
|
|
"""Return the singular output key.
|
|
|
|
:meta private:
|
|
"""
|
|
return [self.output_key]
|
|
|
|
# TODO: deprecate this
|
|
@root_validator(pre=True)
|
|
def load_combine_documents_chain(cls, values: Dict) -> Dict:
|
|
"""Validate question chain."""
|
|
if "combine_documents_chain" not in values:
|
|
if "llm" not in values:
|
|
raise ValueError(
|
|
"If `combine_documents_chain` not provided, `llm` should be."
|
|
)
|
|
prompt = values.pop("prompt", PROMPT)
|
|
llm = values.pop("llm")
|
|
llm_chain = LLMChain(llm=llm, prompt=prompt)
|
|
document_prompt = PromptTemplate(
|
|
input_variables=["page_content"], template="Context:\n{page_content}"
|
|
)
|
|
combine_documents_chain = StuffDocumentsChain(
|
|
llm_chain=llm_chain,
|
|
document_variable_name="context",
|
|
document_prompt=document_prompt,
|
|
)
|
|
values["combine_documents_chain"] = combine_documents_chain
|
|
return values
|
|
|
|
@classmethod
|
|
def from_llm(
|
|
cls, llm: BaseLLM, prompt: PromptTemplate = PROMPT, **kwargs: Any
|
|
) -> VectorDBQA:
|
|
"""Initialize from LLM."""
|
|
llm_chain = LLMChain(llm=llm, prompt=prompt)
|
|
document_prompt = PromptTemplate(
|
|
input_variables=["page_content"], template="Context:\n{page_content}"
|
|
)
|
|
combine_documents_chain = StuffDocumentsChain(
|
|
llm_chain=llm_chain,
|
|
document_variable_name="context",
|
|
document_prompt=document_prompt,
|
|
)
|
|
return cls(combine_documents_chain=combine_documents_chain, **kwargs)
|
|
|
|
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
|
question = inputs[self.input_key]
|
|
docs = self.vectorstore.similarity_search(question, k=self.k)
|
|
answer = self.combine_documents_chain.combine_docs(docs, question=question)
|
|
return {self.output_key: answer}
|