From 47af2bcee45f2c0b8017b2702b24d6f99a5f55a9 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Sat, 12 Nov 2022 07:24:49 -0800 Subject: [PATCH] vector db qa (#71) --- examples/vector_db_qa.ipynb | 94 +++++++++++++++++++++++ langchain/__init__.py | 2 + langchain/chains/__init__.py | 2 + langchain/chains/vector_db_qa/__init__.py | 1 + langchain/chains/vector_db_qa/base.py | 80 +++++++++++++++++++ langchain/chains/vector_db_qa/prompt.py | 10 +++ 6 files changed, 189 insertions(+) create mode 100644 examples/vector_db_qa.ipynb create mode 100644 langchain/chains/vector_db_qa/__init__.py create mode 100644 langchain/chains/vector_db_qa/base.py create mode 100644 langchain/chains/vector_db_qa/prompt.py diff --git a/examples/vector_db_qa.ipynb b/examples/vector_db_qa.ipynb new file mode 100644 index 00000000..c6ec0a57 --- /dev/null +++ b/examples/vector_db_qa.ipynb @@ -0,0 +1,94 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "82525493", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.embeddings.openai import OpenAIEmbeddings\n", + "from langchain.vectorstores.faiss import FAISS\n", + "from langchain.text_splitter import CharacterTextSplitter\n", + "from langchain import OpenAI, VectorDBQA" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "5c7049db", + "metadata": {}, + "outputs": [], + "source": [ + "with open('state_of_the_union.txt') as f:\n", + " state_of_the_union = f.read()\n", + "text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n", + "texts = text_splitter.split_text(state_of_the_union)\n", + "\n", + "embeddings = OpenAIEmbeddings()\n", + "docsearch = FAISS.from_texts(texts, embeddings)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "3018f865", + "metadata": {}, + "outputs": [], + "source": [ + "qa = VectorDBQA(llm=OpenAI(), vectorstore=docsearch)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "032a47f8", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "\" The president said that Ketanji Brown Jackson is one of our nation's top legal minds, who will continue Justice Breyer’s legacy of excellence.\"" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "query = \"What did the president say about Ketanji Brown Jackson\"\n", + "qa.run(query)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f0f20b92", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.7" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/langchain/__init__.py b/langchain/__init__.py index 50113af2..4ccd7dc5 100644 --- a/langchain/__init__.py +++ b/langchain/__init__.py @@ -14,6 +14,7 @@ from langchain.chains import ( SelfAskWithSearchChain, SerpAPIChain, SQLDatabaseChain, + VectorDBQA, ) from langchain.docstore import Wikipedia from langchain.llms import Cohere, HuggingFaceHub, OpenAI @@ -39,5 +40,6 @@ __all__ = [ "SQLDatabaseChain", "FAISS", "MRKLChain", + "VectorDBQA", "ElasticVectorSearch", ] diff --git a/langchain/chains/__init__.py b/langchain/chains/__init__.py index 45fb20cb..ae27d37e 100644 --- a/langchain/chains/__init__.py +++ b/langchain/chains/__init__.py @@ -7,6 +7,7 @@ from langchain.chains.react.base import ReActChain from langchain.chains.self_ask_with_search.base import SelfAskWithSearchChain from langchain.chains.serpapi import SerpAPIChain from langchain.chains.sql_database.base import SQLDatabaseChain +from langchain.chains.vector_db_qa.base import VectorDBQA __all__ = [ "LLMChain", @@ -17,4 +18,5 @@ __all__ = [ "ReActChain", "SQLDatabaseChain", "MRKLChain", + "VectorDBQA", ] diff --git a/langchain/chains/vector_db_qa/__init__.py b/langchain/chains/vector_db_qa/__init__.py new file mode 100644 index 00000000..b8e4d9aa --- /dev/null +++ b/langchain/chains/vector_db_qa/__init__.py @@ -0,0 +1 @@ +"""Chain for question-answering against a vector database.""" diff --git a/langchain/chains/vector_db_qa/base.py b/langchain/chains/vector_db_qa/base.py new file mode 100644 index 00000000..208d7f55 --- /dev/null +++ b/langchain/chains/vector_db_qa/base.py @@ -0,0 +1,80 @@ +"""Chain for question-answering against a vector database.""" +from typing import Dict, List + +from pydantic import BaseModel, Extra + +from langchain.chains.base import Chain +from langchain.chains.llm import LLMChain +from langchain.chains.vector_db_qa.prompt import prompt +from langchain.llms.base import LLM +from langchain.vectorstores.base import VectorStore + + +class VectorDBQA(Chain, BaseModel): + """Chain for question-answering against a vector database. + + Example: + .. code-block:: python + + from langchain import OpenAI, VectorDBQA + from langchain.faiss import FAISS + vectordb = FAISS(...) + vectordbQA = VectorDBQA(llm=OpenAI(), vector_db=vectordb) + + """ + + llm: LLM + """LLM wrapper to use.""" + vectorstore: VectorStore + """Vector Database to connect to.""" + input_key: str = "query" #: :meta private: + output_key: str = "result" #: :meta private: + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + arbitrary_types_allowed = True + + @property + def input_keys(self) -> List[str]: + """Return the singular input key. + + :meta private: + """ + return [self.input_key] + + @property + def output_keys(self) -> List[str]: + """Return the singular output key. + + :meta private: + """ + return [self.output_key] + + def _run(self, inputs: Dict[str, str]) -> Dict[str, str]: + question = inputs[self.input_key] + llm_chain = LLMChain(llm=self.llm, prompt=prompt) + docs = self.vectorstore.similarity_search(question) + contexts = [] + for j, doc in enumerate(docs): + contexts.append(f"Context {j}:\n{doc.page_content}") + # TODO: handle cases where this context is too long. + answer = llm_chain.predict(question=question, context="\n\n".join(contexts)) + return {self.output_key: answer} + + def run(self, question: str) -> str: + """Run Question-Answering on a vector database. + + Args: + question: Question to get the answer for. + + Returns: + The final answer + + Example: + .. code-block:: python + + answer = vectordbqa.run("What is the capital of Idaho?") + """ + return self({self.input_key: question})[self.output_key] diff --git a/langchain/chains/vector_db_qa/prompt.py b/langchain/chains/vector_db_qa/prompt.py new file mode 100644 index 00000000..54c4d7f6 --- /dev/null +++ b/langchain/chains/vector_db_qa/prompt.py @@ -0,0 +1,10 @@ +# flake8: noqa +from langchain.prompts import Prompt + +prompt_template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. + +{context} + +Question: {question} +Helpful Answer:""" +prompt = Prompt(template=prompt_template, input_variables=["context", "question"])