langchain/libs/community/tests/integration_tests/retrievers/test_thirdai_neuraldb.py

59 lines
1.8 KiB
Python
Raw Normal View History

import os
import shutil
from typing import Generator
import pytest
from langchain_community.retrievers import NeuralDBRetriever
@pytest.fixture(scope="session")
def test_csv() -> Generator[str, None, None]:
    """Create a two-line CSV file in the working directory and yield its path.

    The file holds a header row (``column_1,column_2``) and a single data row.
    The fixture is declared with session scope; the statement after the
    ``yield`` is the teardown step and deletes the file again.

    Yields:
        The relative path of the CSV file that was written.
    """
    csv_path = "thirdai-test.csv"
    rows = (
        "column_1,column_2\n",
        "column one,column two\n",
    )
    with open(csv_path, "w") as handle:
        handle.writelines(rows)

    yield csv_path

    # Teardown: remove the file written above.
    os.remove(csv_path)
def assert_result_correctness(documents: list) -> None:
    """Assert that ``documents`` holds exactly one document with the expected text.

    Args:
        documents: Retrieved documents; each item must expose a
            ``page_content`` attribute.

    Raises:
        AssertionError: If the list does not contain exactly one item, or if
            that item's ``page_content`` differs from the expected string.
    """
    expected_content = "column_1: column one\n\ncolumn_2: column two"

    assert len(documents) == 1
    only_document = documents[0]
    assert only_document.page_content == expected_content
@pytest.mark.requires("thirdai[neural_db]")
def test_neuraldb_retriever_from_scratch(test_csv: str) -> None:
    """Build a retriever from scratch, insert the CSV, and check a query result.

    Args:
        test_csv: Path to the CSV file supplied by the ``test_csv`` fixture.
    """
    fresh_retriever = NeuralDBRetriever.from_scratch()
    fresh_retriever.insert([test_csv])

    # Query with "column" and validate the returned documents.
    assert_result_correctness(fresh_retriever.invoke("column"))
@pytest.mark.requires("thirdai[neural_db]")
def test_neuraldb_retriever_from_checkpoint(test_csv: str) -> None:
    """Save a retriever to a checkpoint, reload it, and check a query result.

    Any checkpoint left at the target path is removed before the test body
    runs and again afterwards, whether or not the body raised.

    Args:
        test_csv: Path to the CSV file supplied by the ``test_csv`` fixture.
    """
    checkpoint = "thirdai-test-save.ndb"

    def discard_checkpoint() -> None:
        """Delete the checkpoint directory tree if it exists."""
        if os.path.exists(checkpoint):
            shutil.rmtree(checkpoint)

    # Start from a clean slate in case an earlier run left a checkpoint behind.
    discard_checkpoint()
    try:
        source_retriever = NeuralDBRetriever.from_scratch()
        source_retriever.insert([test_csv])
        source_retriever.save(checkpoint)

        # Query the reloaded retriever, not the one that was saved.
        restored_retriever = NeuralDBRetriever.from_checkpoint(checkpoint)
        assert_result_correctness(restored_retriever.invoke("column"))
    finally:
        discard_checkpoint()
@pytest.mark.requires("thirdai[neural_db]")
def test_neuraldb_retriever_other_methods(test_csv: str) -> None:
    """Smoke-test the associate and upvote methods of the retriever.

    Nothing is asserted about the results; the test passes as long as none of
    the calls raises.

    Args:
        test_csv: Path to the CSV file supplied by the ``test_csv`` fixture.
    """
    neural_db = NeuralDBRetriever.from_scratch()
    neural_db.insert([test_csv])

    association_pairs = [("A", "B"), ("C", "D")]
    upvote_pairs = [("A", 0), ("B", 0)]

    # Only checking that these calls complete without raising.
    neural_db.associate("A", "B")
    neural_db.associate_batch(association_pairs)
    neural_db.upvote("A", 0)
    neural_db.upvote_batch(upvote_pairs)