def _load_retrieval_qa(config: dict, **kwargs: Any) -> RetrievalQA:
    """Construct a ``RetrievalQA`` chain from a serialized config.

    The retriever itself is not serializable, so the caller must supply it
    through ``kwargs``.  The combine-documents sub-chain is rebuilt either
    from an inline config dict or from a referenced chain file on disk.

    Raises:
        ValueError: if ``retriever`` is missing from ``kwargs``, or if the
            config carries neither an inline combine-documents chain nor a
            path to one.
    """
    # Guard first: a retriever is mandatory and cannot come from config.
    if "retriever" not in kwargs:
        raise ValueError("`retriever` must be present.")
    retriever = kwargs.pop("retriever")

    # Rebuild the combine-documents chain, preferring the inline form.
    if "combine_documents_chain" in config:
        inner_config = config.pop("combine_documents_chain")
        combine_documents_chain = load_chain_from_config(inner_config)
    elif "combine_documents_chain_path" in config:
        chain_path = config.pop("combine_documents_chain_path")
        combine_documents_chain = load_chain(chain_path)
    else:
        raise ValueError(
            "One of `combine_documents_chain` or "
            "`combine_documents_chain_path` must be present."
        )

    # Remaining config entries are forwarded as plain constructor kwargs.
    return RetrievalQA(
        combine_documents_chain=combine_documents_chain,
        retriever=retriever,
        **config,
    )
    @property
    def _chain_type(self) -> str:
        """Return the chain type."""
        # Serialization identifier: load_chain dispatches on this string
        # via the "retrieval_qa" entry in type_to_loader_dict.
        return "retrieval_qa"
def test_retrieval_qa_saving_loading(tmp_path: Path) -> None:
    """Round-trip a RetrievalQA chain through save() and load_chain()."""
    # Build a small vector store over the sample document.
    documents = TextLoader("docs/modules/state_of_the_union.txt").load()
    splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    chunks = splitter.split_documents(documents)
    docsearch = Chroma.from_documents(chunks, OpenAIEmbeddings())
    qa = RetrievalQA.from_llm(llm=OpenAI(), retriever=docsearch.as_retriever())

    # Persist the chain, then reload it with a freshly supplied retriever
    # (retrievers are not serialized, so one must be passed back in).
    file_path = tmp_path / "RetrievalQA_chain.yaml"
    qa.save(file_path=file_path)
    qa_loaded = load_chain(file_path, retriever=docsearch.as_retriever())

    assert qa_loaded == qa