From fd5f549a9e93a58a63ca2c8953aefcba9cd8902d Mon Sep 17 00:00:00 2001 From: Richard Adams Date: Wed, 25 Oct 2023 05:33:34 +0100 Subject: [PATCH] demonstrate use of RetrievalQAWithSourcesChain.from_chain (#12235) **Description:** Documents further usage of RetrievalQAWithSourcesChain in an existing test. I'd not found much documented usage of RetrievalQAWithSourcesChain and how to get the sources out. This additional code will hopefully be useful to other potential users of this retriever. **Issue:** No raised issue **Dependencies:** No new dependencies needed to run the test (it already needs `open-ai`, `faiss-cpu` and `unstructured`). Note - `make lint` showed 8 linting errors in unrelated files --------- Co-authored-by: richarda23 --- .../chains/test_retrieval_qa_with_sources.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/libs/langchain/tests/integration_tests/chains/test_retrieval_qa_with_sources.py b/libs/langchain/tests/integration_tests/chains/test_retrieval_qa_with_sources.py index 70ee98513e..7ace5730ad 100644 --- a/libs/langchain/tests/integration_tests/chains/test_retrieval_qa_with_sources.py +++ b/libs/langchain/tests/integration_tests/chains/test_retrieval_qa_with_sources.py @@ -19,10 +19,20 @@ def test_retrieval_qa_with_sources_chain_saving_loading(tmp_path: str) -> None: qa = RetrievalQAWithSourcesChain.from_llm( llm=OpenAI(), retriever=docsearch.as_retriever() ) - qa("What did the president say about Ketanji Brown Jackson?") - - file_path = tmp_path + "/RetrievalQAWithSourcesChain.yaml" + result = qa("What did the president say about Ketanji Brown Jackson?") + assert "question" in result.keys() + assert "answer" in result.keys() + assert "sources" in result.keys() + file_path = str(tmp_path) + "/RetrievalQAWithSourcesChain.yaml" qa.save(file_path=file_path) qa_loaded = load_chain(file_path, retriever=docsearch.as_retriever()) assert qa_loaded == qa + + qa2 = RetrievalQAWithSourcesChain.from_chain_type( + llm=OpenAI(), retriever=docsearch.as_retriever(), chain_type="stuff" + ) + result2 = qa2("What did the president say about Ketanji Brown Jackson?") + assert "question" in result2.keys() + assert "answer" in result2.keys() + assert "sources" in result2.keys()