diff --git a/docs/examples/demos/vector_db_qa.ipynb b/docs/examples/demos/vector_db_qa.ipynb index bf1fcbf9..c69d9624 100644 --- a/docs/examples/demos/vector_db_qa.ipynb +++ b/docs/examples/demos/vector_db_qa.ipynb @@ -25,8 +25,8 @@ }, { "cell_type": "code", - "execution_count": 3, - "id": "5c7049db", + "execution_count": 2, + "id": "0d71cf4f", "metadata": {}, "outputs": [], "source": [ @@ -35,8 +35,17 @@ "text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n", "texts = text_splitter.split_text(state_of_the_union)\n", "\n", - "embeddings = OpenAIEmbeddings()\n", - "docsearch = FAISS.from_texts(texts, embeddings)" + "embeddings = OpenAIEmbeddings()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "5c7049db", + "metadata": {}, + "outputs": [], + "source": [ + "docsearch = FAISS.from_texts(texts, embeddings, metadatas=[{\"source\": f\"www.{i}.com\"} for i in range(len(texts))])" ] }, { @@ -58,7 +67,7 @@ { "data": { "text/plain": [ - "' The President said that Ketanji Brown Jackson is a consensus builder and has received a broad range of support since she was nominated.'" + "' The President said that Judge Ketanji Brown Jackson is \"One of our nation’s top legal minds, who will continue Justice Breyer’s legacy of excellence.\"'" ] }, "execution_count": 5, @@ -96,7 +105,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.6" + "version": "3.8.7" } }, "nbformat": 4, diff --git a/langchain/chains/vector_db_qa/base.py b/langchain/chains/vector_db_qa/base.py index 3e010710..a13746bb 100644 --- a/langchain/chains/vector_db_qa/base.py +++ b/langchain/chains/vector_db_qa/base.py @@ -58,7 +58,10 @@ class VectorDBQA(Chain, BaseModel): docs = self.vectorstore.similarity_search(question) contexts = [] for j, doc in enumerate(docs): - contexts.append(f"Context {j}:\n{doc.page_content}") + context_str = f"Context {j}:\n{doc.page_content}" + if doc.metadata is not None: + context_str += 
f"\nSource: {doc.metadata['source']}" + contexts.append(context_str) # TODO: handle cases where this context is too long. answer = llm_chain.predict(question=question, context="\n\n".join(contexts)) return {self.output_key: answer} diff --git a/langchain/chains/vector_db_qa/prompt.py b/langchain/chains/vector_db_qa/prompt.py index 54c4d7f6..feda2a30 100644 --- a/langchain/chains/vector_db_qa/prompt.py +++ b/langchain/chains/vector_db_qa/prompt.py @@ -1,7 +1,12 @@ # flake8: noqa from langchain.prompts import Prompt -prompt_template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. +prompt_template = """Use the following pieces of context to answer the question at the end. +If you don't know the answer, just say that you don't know, don't try to make up an answer. +If a source is provided for a particular piece of context, please cite that source when responding. + +For example: +"Answer goes here..." 
- [source goes here] {context} diff --git a/langchain/docstore/document.py b/langchain/docstore/document.py index 2c6e04bb..db1588a5 100644 --- a/langchain/docstore/document.py +++ b/langchain/docstore/document.py @@ -1,5 +1,5 @@ """Interface for interacting with a document.""" -from typing import List +from typing import List, Optional from pydantic import BaseModel @@ -10,6 +10,7 @@ class Document(BaseModel): page_content: str lookup_str: str = "" lookup_index = 0 + metadata: Optional[dict] = None @property def paragraphs(self) -> List[str]: diff --git a/langchain/vectorstores/faiss.py b/langchain/vectorstores/faiss.py index 937ad80e..b95b9ae3 100644 --- a/langchain/vectorstores/faiss.py +++ b/langchain/vectorstores/faiss.py @@ -1,5 +1,5 @@ """Wrapper around FAISS vector database.""" -from typing import Any, Callable, List +from typing import Any, Callable, List, Optional import numpy as np @@ -54,7 +54,7 @@ class FAISS(VectorStore): @classmethod def from_texts( - cls, texts: List[str], embedding: Embeddings, **kwargs: Any + cls, texts: List[str], embedding: Embeddings, metadatas: Optional[List[dict]] = None, **kwargs: Any ) -> "FAISS": """Construct FAISS wrapper from raw documents. @@ -84,6 +84,8 @@ class FAISS(VectorStore): embeddings = embedding.embed_documents(texts) index = faiss.IndexFlatL2(len(embeddings[0])) index.add(np.array(embeddings, dtype=np.float32)) - documents = [Document(page_content=text) for text in texts] + if metadatas is None: + metadatas = [None] * len(texts) + documents = [Document(page_content=text, metadata=metadatas[i]) for i, text in enumerate(texts)] docstore = InMemoryDocstore({str(i): doc for i, doc in enumerate(documents)}) return cls(embedding.embed_query, index, docstore)