mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
28 lines
1.1 KiB
Python
28 lines
1.1 KiB
Python
|
"""Test RetrievalQA functionality."""
|
||
|
from pathlib import Path
|
||
|
|
||
|
from langchain.chains import RetrievalQA
|
||
|
from langchain.chains.loading import load_chain
|
||
|
from langchain.document_loaders import TextLoader
|
||
|
from langchain.embeddings.openai import OpenAIEmbeddings
|
||
|
from langchain.llms import OpenAI
|
||
|
from langchain.text_splitter import CharacterTextSplitter
|
||
|
from langchain.vectorstores import Chroma
|
||
|
|
||
|
|
||
|
def test_retrieval_qa_saving_loading(tmp_path: Path) -> None:
    """Check that a RetrievalQA chain survives a save/load round trip.

    The chain is built over a Chroma index of the State of the Union text,
    written to a YAML file under ``tmp_path``, read back with ``load_chain``,
    and compared for equality with the chain that was saved.
    """
    # The document path is relative, so it resolves against the working
    # directory the test run is started from.
    raw_documents = TextLoader("docs/modules/state_of_the_union.txt").load()
    splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    chunks = splitter.split_documents(raw_documents)
    vector_store = Chroma.from_documents(chunks, OpenAIEmbeddings())

    llm = OpenAI()
    original_chain = RetrievalQA.from_llm(
        llm=llm,
        retriever=vector_store.as_retriever(),
    )

    yaml_path = tmp_path / "RetrievalQA_chain.yaml"
    original_chain.save(file_path=yaml_path)

    # A retriever is handed to load_chain explicitly -- presumably because it
    # is not part of the saved YAML; confirm against load_chain's contract.
    restored_chain = load_chain(yaml_path, retriever=vector_store.as_retriever())

    assert restored_chain == original_chain
|