enable serde retrieval qa with sources (#10132)

#3983 mentions serialization/deserialization issues with both
`RetrievalQA` & `RetrievalQAWithSourcesChain`.
`RetrievalQA` has already been fixed in #5818. 

Mimicing #5818, I added the logic for `RetrievalQAWithSourcesChain`.

---------

Co-authored-by: Markus Tretzmüller <markus.tretzmueller@cortecs.at>
Co-authored-by: Bagatur <baskaryan@gmail.com>
pull/10291/head
Markus Tretzmüller 12 months ago committed by GitHub
parent 62fa2bc518
commit b3a8fc7cb1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -20,6 +20,7 @@ from langchain.chains.llm_checker.base import LLMCheckerChain
from langchain.chains.llm_math.base import LLMMathChain
from langchain.chains.llm_requests import LLMRequestsChain
from langchain.chains.qa_with_sources.base import QAWithSourcesChain
from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain
from langchain.chains.qa_with_sources.vector_db import VectorDBQAWithSourcesChain
from langchain.chains.retrieval_qa.base import RetrievalQA, VectorDBQA
from langchain.llms.loading import load_llm, load_llm_from_config
@ -424,6 +425,30 @@ def _load_retrieval_qa(config: dict, **kwargs: Any) -> RetrievalQA:
)
def _load_retrieval_qa_with_sources_chain(
config: dict, **kwargs: Any
) -> RetrievalQAWithSourcesChain:
if "retriever" in kwargs:
retriever = kwargs.pop("retriever")
else:
raise ValueError("`retriever` must be present.")
if "combine_documents_chain" in config:
combine_documents_chain_config = config.pop("combine_documents_chain")
combine_documents_chain = load_chain_from_config(combine_documents_chain_config)
elif "combine_documents_chain_path" in config:
combine_documents_chain = load_chain(config.pop("combine_documents_chain_path"))
else:
raise ValueError(
"One of `combine_documents_chain` or "
"`combine_documents_chain_path` must be present."
)
return RetrievalQAWithSourcesChain(
combine_documents_chain=combine_documents_chain,
retriever=retriever,
**config,
)
def _load_vector_db_qa(config: dict, **kwargs: Any) -> VectorDBQA:
if "vectorstore" in kwargs:
vectorstore = kwargs.pop("vectorstore")
@ -537,6 +562,7 @@ type_to_loader_dict = {
"vector_db_qa_with_sources_chain": _load_vector_db_qa_with_sources_chain,
"vector_db_qa": _load_vector_db_qa,
"retrieval_qa": _load_retrieval_qa,
"retrieval_qa_with_sources_chain": _load_retrieval_qa_with_sources_chain,
"graph_cypher_chain": _load_graph_cypher_chain,
}

@ -60,3 +60,8 @@ class RetrievalQAWithSourcesChain(BaseQAWithSourcesChain):
question, callbacks=run_manager.get_child()
)
return self._reduce_tokens_below_limit(docs)
@property
def _chain_type(self) -> str:
"""Return the chain type."""
return "retrieval_qa_with_sources_chain"

@ -0,0 +1,28 @@
"""Test RetrievalQA functionality."""
from langchain.chains import RetrievalQAWithSourcesChain
from langchain.chains.loading import load_chain
from langchain.document_loaders import DirectoryLoader
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms import OpenAI
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import FAISS
def test_retrieval_qa_with_sources_chain_saving_loading(tmp_path: str) -> None:
"""Test saving and loading."""
loader = DirectoryLoader("docs/extras/modules/", glob="*.txt")
documents = loader.load()
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
texts = text_splitter.split_documents(documents)
embeddings = OpenAIEmbeddings()
docsearch = FAISS.from_documents(texts, embeddings)
qa = RetrievalQAWithSourcesChain.from_llm(
llm=OpenAI(), retriever=docsearch.as_retriever()
)
qa("What did the president say about Ketanji Brown Jackson?")
file_path = tmp_path + "/RetrievalQAWithSourcesChain.yaml"
qa.save(file_path=file_path)
qa_loaded = load_chain(file_path, retriever=docsearch.as_retriever())
assert qa_loaded == qa
Loading…
Cancel
Save