mirror of https://github.com/hwchase17/langchain
enable serde retrieval qa with sources (#10132)
#3983 mentions serialization/deserialization issues with both `RetrievalQA` & `RetrievalQAWithSourcesChain`. `RetrievalQA` has already been fixed in #5818. Mimicing #5818, I added the logic for `RetrievalQAWithSourcesChain`. --------- Co-authored-by: Markus Tretzmüller <markus.tretzmueller@cortecs.at> Co-authored-by: Bagatur <baskaryan@gmail.com>pull/10291/head
parent
62fa2bc518
commit
b3a8fc7cb1
@ -0,0 +1,28 @@
|
|||||||
|
"""Test RetrievalQA functionality."""
|
||||||
|
from langchain.chains import RetrievalQAWithSourcesChain
|
||||||
|
from langchain.chains.loading import load_chain
|
||||||
|
from langchain.document_loaders import DirectoryLoader
|
||||||
|
from langchain.embeddings.openai import OpenAIEmbeddings
|
||||||
|
from langchain.llms import OpenAI
|
||||||
|
from langchain.text_splitter import CharacterTextSplitter
|
||||||
|
from langchain.vectorstores import FAISS
|
||||||
|
|
||||||
|
|
||||||
|
def test_retrieval_qa_with_sources_chain_saving_loading(tmp_path: str) -> None:
|
||||||
|
"""Test saving and loading."""
|
||||||
|
loader = DirectoryLoader("docs/extras/modules/", glob="*.txt")
|
||||||
|
documents = loader.load()
|
||||||
|
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
|
||||||
|
texts = text_splitter.split_documents(documents)
|
||||||
|
embeddings = OpenAIEmbeddings()
|
||||||
|
docsearch = FAISS.from_documents(texts, embeddings)
|
||||||
|
qa = RetrievalQAWithSourcesChain.from_llm(
|
||||||
|
llm=OpenAI(), retriever=docsearch.as_retriever()
|
||||||
|
)
|
||||||
|
qa("What did the president say about Ketanji Brown Jackson?")
|
||||||
|
|
||||||
|
file_path = tmp_path + "/RetrievalQAWithSourcesChain.yaml"
|
||||||
|
qa.save(file_path=file_path)
|
||||||
|
qa_loaded = load_chain(file_path, retriever=docsearch.as_retriever())
|
||||||
|
|
||||||
|
assert qa_loaded == qa
|
Loading…
Reference in New Issue