"""Chain for question-answering against a vector database."""
from __future__ import annotations

from abc import abstractmethod
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Extra, Field, root_validator

from langchain.chains.base import Chain
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.chains.question_answering import load_qa_chain
from langchain.chains.question_answering.stuff_prompt import PROMPT_SELECTOR
from langchain.prompts import PromptTemplate
from langchain.schema import BaseLanguageModel, BaseRetriever, Document
from langchain.vectorstores.base import VectorStore

class BaseRetrievalQA(Chain, BaseModel):
    """Base chain for question-answering over retrieved documents."""

    combine_documents_chain: BaseCombineDocumentsChain
    """Chain to use to combine the documents."""
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:
    return_source_documents: bool = False
    """Return the source documents."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True
        allow_population_by_field_name = True

    @property
    def input_keys(self) -> List[str]:
        """Return the input keys.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return the output keys.

        :meta private:
        """
        _output_keys = [self.output_key]
        if self.return_source_documents:
            _output_keys = _output_keys + ["source_documents"]
        return _output_keys

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        prompt: Optional[PromptTemplate] = None,
        **kwargs: Any,
    ) -> BaseRetrievalQA:
        """Initialize from LLM."""
        _prompt = prompt or PROMPT_SELECTOR.get_prompt(llm)
        llm_chain = LLMChain(llm=llm, prompt=_prompt)
        document_prompt = PromptTemplate(
            input_variables=["page_content"], template="Context:\n{page_content}"
        )
        combine_documents_chain = StuffDocumentsChain(
            llm_chain=llm_chain,
            document_variable_name="context",
            document_prompt=document_prompt,
        )

        return cls(combine_documents_chain=combine_documents_chain, **kwargs)
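
    # Illustrative sketch (not part of the original source): `from_llm` accepts an
    # optional PromptTemplate, so a custom QA prompt can replace the default chosen
    # by PROMPT_SELECTOR. `OpenAI` and `docsearch` are assumed names here.
    #
    #     prompt = PromptTemplate(
    #         input_variables=["context", "question"],
    #         template="Use the context below to answer.\n\n{context}\n\nQuestion: {question}",
    #     )
    #     qa = RetrievalQA.from_llm(
    #         llm=OpenAI(), prompt=prompt, retriever=docsearch.as_retriever()
    #     )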

    @classmethod
    def from_chain_type(
        cls,
        llm: BaseLanguageModel,
        chain_type: str = "stuff",
        chain_type_kwargs: Optional[dict] = None,
        **kwargs: Any,
    ) -> BaseRetrievalQA:
        """Load chain from chain type."""
        _chain_type_kwargs = chain_type_kwargs or {}
        combine_documents_chain = load_qa_chain(
            llm, chain_type=chain_type, **_chain_type_kwargs
        )
        return cls(combine_documents_chain=combine_documents_chain, **kwargs)
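
    # Illustrative sketch (assumed names, not in the original file): any chain type
    # supported by `load_qa_chain` ("stuff", "map_reduce", "refine", "map_rerank")
    # can be used, and extra kwargs such as `return_source_documents` pass through
    # to the constructor.
    #
    #     qa = RetrievalQA.from_chain_type(
    #         llm=OpenAI(),
    #         chain_type="map_reduce",
    #         retriever=docsearch.as_retriever(),
    #         return_source_documents=True,
    #     )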

    @abstractmethod
    def _get_docs(self, question: str) -> List[Document]:
        """Get documents to do question answering over."""

    def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]:
        """Run _get_docs and the combine documents chain on the input query.

        If the chain has 'return_source_documents' set to True, the retrieved
        documents are also returned under the key 'source_documents'.

        Example:
            .. code-block:: python

                res = indexqa({'query': 'This is my query'})
                answer, docs = res['result'], res['source_documents']
        """
        question = inputs[self.input_key]

        docs = self._get_docs(question)
        answer, _ = self.combine_documents_chain.combine_docs(docs, question=question)

        if self.return_source_documents:
            return {self.output_key: answer, "source_documents": docs}
        else:
            return {self.output_key: answer}


class RetrievalQA(BaseRetrievalQA, BaseModel):
    """Chain for question-answering against an index.

    Example:
        .. code-block:: python

            from langchain.llms import OpenAI
            from langchain.chains import RetrievalQA
            from langchain.vectorstores import FAISS

            vectordb = FAISS(...)
            retrievalQA = RetrievalQA.from_llm(
                llm=OpenAI(), retriever=vectordb.as_retriever()
            )

    """

    retriever: BaseRetriever = Field(exclude=True)

    def _get_docs(self, question: str) -> List[Document]:
        return self.retriever.get_relevant_documents(question)
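
# Illustrative sketch (hypothetical, not part of the original source): RetrievalQA
# accepts any BaseRetriever implementation, not only vector-store retrievers. The
# `corpus` variable and the KeywordRetriever class are assumptions for illustration.
#
#     class KeywordRetriever(BaseRetriever):
#         def get_relevant_documents(self, query: str) -> List[Document]:
#             return [d for d in corpus if query.lower() in d.page_content.lower()]
#
#     qa = RetrievalQA.from_llm(llm=OpenAI(), retriever=KeywordRetriever())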


class VectorDBQA(BaseRetrievalQA, BaseModel):
    """Chain for question-answering against a vector database."""

    vectorstore: VectorStore = Field(exclude=True, alias="vectorstore")
    """Vector Database to connect to."""
    k: int = 4
    """Number of documents to query for."""
    search_type: str = "similarity"
    """Search type to use over vectorstore. `similarity` or `mmr`."""
    search_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Extra search args."""

    @root_validator()
    def validate_search_type(cls, values: Dict) -> Dict:
        """Validate search type."""
        if "search_type" in values:
            search_type = values["search_type"]
            if search_type not in ("similarity", "mmr"):
                raise ValueError(f"search_type of {search_type} not allowed.")
        return values

    def _get_docs(self, question: str) -> List[Document]:
        if self.search_type == "similarity":
            docs = self.vectorstore.similarity_search(
                question, k=self.k, **self.search_kwargs
            )
        elif self.search_type == "mmr":
            docs = self.vectorstore.max_marginal_relevance_search(
                question, k=self.k, **self.search_kwargs
            )
        else:
            raise ValueError(f"search_type of {self.search_type} not allowed.")
        return docs
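
    # Illustrative configuration sketch (assumed names, not part of the original file):
    # `search_type="mmr"` routes queries through max_marginal_relevance_search, and
    # `search_kwargs` is forwarded to the underlying vector store search call.
    #
    #     qa = VectorDBQA.from_chain_type(
    #         llm=OpenAI(),
    #         vectorstore=docsearch,
    #         k=6,
    #         search_type="mmr",
    #     )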

    @property
    def _chain_type(self) -> str:
        """Return the chain type."""
        return "vector_db_qa"