From b3a8fc7cb17170b8a272a0a1366b0f18d8f7ab4a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Markus=20Tretzm=C3=BCller?= Date: Sat, 9 Sep 2023 01:57:10 +0200 Subject: [PATCH] enable serde retrieval qa with sources (#10132) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit #3983 mentions serialization/deserialization issues with both `RetrievalQA` & `RetrievalQAWithSourcesChain`. `RetrievalQA` has already been fixed in #5818. Mimicing #5818, I added the logic for `RetrievalQAWithSourcesChain`. --------- Co-authored-by: Markus Tretzmüller Co-authored-by: Bagatur --- libs/langchain/langchain/chains/loading.py | 26 +++++++++++++++++ .../chains/qa_with_sources/retrieval.py | 5 ++++ .../chains/test_retrieval_qa_with_sources.py | 28 +++++++++++++++++++ 3 files changed, 59 insertions(+) create mode 100644 libs/langchain/tests/integration_tests/chains/test_retrieval_qa_with_sources.py diff --git a/libs/langchain/langchain/chains/loading.py b/libs/langchain/langchain/chains/loading.py index c2e0b81397..9543f62988 100644 --- a/libs/langchain/langchain/chains/loading.py +++ b/libs/langchain/langchain/chains/loading.py @@ -20,6 +20,7 @@ from langchain.chains.llm_checker.base import LLMCheckerChain from langchain.chains.llm_math.base import LLMMathChain from langchain.chains.llm_requests import LLMRequestsChain from langchain.chains.qa_with_sources.base import QAWithSourcesChain +from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain from langchain.chains.qa_with_sources.vector_db import VectorDBQAWithSourcesChain from langchain.chains.retrieval_qa.base import RetrievalQA, VectorDBQA from langchain.llms.loading import load_llm, load_llm_from_config @@ -424,6 +425,30 @@ def _load_retrieval_qa(config: dict, **kwargs: Any) -> RetrievalQA: ) +def _load_retrieval_qa_with_sources_chain( + config: dict, **kwargs: Any +) -> RetrievalQAWithSourcesChain: + if "retriever" in kwargs: + retriever = kwargs.pop("retriever") + else: + raise ValueError("`retriever` must be present.") + if "combine_documents_chain" in config: + combine_documents_chain_config = config.pop("combine_documents_chain") + combine_documents_chain = load_chain_from_config(combine_documents_chain_config) + elif "combine_documents_chain_path" in config: + combine_documents_chain = load_chain(config.pop("combine_documents_chain_path")) + else: + raise ValueError( + "One of `combine_documents_chain` or " + "`combine_documents_chain_path` must be present." + ) + return RetrievalQAWithSourcesChain( + combine_documents_chain=combine_documents_chain, + retriever=retriever, + **config, + ) + + def _load_vector_db_qa(config: dict, **kwargs: Any) -> VectorDBQA: if "vectorstore" in kwargs: vectorstore = kwargs.pop("vectorstore") @@ -537,6 +562,7 @@ type_to_loader_dict = { "vector_db_qa_with_sources_chain": _load_vector_db_qa_with_sources_chain, "vector_db_qa": _load_vector_db_qa, "retrieval_qa": _load_retrieval_qa, + "retrieval_qa_with_sources_chain": _load_retrieval_qa_with_sources_chain, "graph_cypher_chain": _load_graph_cypher_chain, } diff --git a/libs/langchain/langchain/chains/qa_with_sources/retrieval.py b/libs/langchain/langchain/chains/qa_with_sources/retrieval.py index c5d587b464..80018950d9 100644 --- a/libs/langchain/langchain/chains/qa_with_sources/retrieval.py +++ b/libs/langchain/langchain/chains/qa_with_sources/retrieval.py @@ -60,3 +60,8 @@ class RetrievalQAWithSourcesChain(BaseQAWithSourcesChain): question, callbacks=run_manager.get_child() ) return self._reduce_tokens_below_limit(docs) + + @property + def _chain_type(self) -> str: + """Return the chain type.""" + return "retrieval_qa_with_sources_chain" diff --git a/libs/langchain/tests/integration_tests/chains/test_retrieval_qa_with_sources.py b/libs/langchain/tests/integration_tests/chains/test_retrieval_qa_with_sources.py new file mode 100644 index 0000000000..70ee98513e --- /dev/null +++ b/libs/langchain/tests/integration_tests/chains/test_retrieval_qa_with_sources.py @@ -0,0 +1,28 @@ +"""Test RetrievalQA functionality.""" +from langchain.chains import RetrievalQAWithSourcesChain +from langchain.chains.loading import load_chain +from langchain.document_loaders import DirectoryLoader +from langchain.embeddings.openai import OpenAIEmbeddings +from langchain.llms import OpenAI +from langchain.text_splitter import CharacterTextSplitter +from langchain.vectorstores import FAISS + + +def test_retrieval_qa_with_sources_chain_saving_loading(tmp_path: str) -> None: + """Test saving and loading.""" + loader = DirectoryLoader("docs/extras/modules/", glob="*.txt") + documents = loader.load() + text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0) + texts = text_splitter.split_documents(documents) + embeddings = OpenAIEmbeddings() + docsearch = FAISS.from_documents(texts, embeddings) + qa = RetrievalQAWithSourcesChain.from_llm( + llm=OpenAI(), retriever=docsearch.as_retriever() + ) + qa("What did the president say about Ketanji Brown Jackson?") + + file_path = tmp_path + "/RetrievalQAWithSourcesChain.yaml" + qa.save(file_path=file_path) + qa_loaded = load_chain(file_path, retriever=docsearch.as_retriever()) + + assert qa_loaded == qa