langchain/libs/community/tests/integration_tests/retrievers/test_thirdai_neuraldb.py
ccurme c010ec8b71
patch: deprecate (a)get_relevant_documents (#20477)
- `.get_relevant_documents(query)` -> `.invoke(query)`
- `.get_relevant_documents(query=query)` -> `.invoke(query)`
- `.get_relevant_documents(query, callbacks=callbacks)` ->
`.invoke(query, config={"callbacks": callbacks})`
- `.get_relevant_documents(query, **kwargs)` -> `.invoke(query,
**kwargs)`

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
2024-04-22 11:14:53 -04:00

59 lines
1.8 KiB
Python

import os
import shutil
from typing import Generator
import pytest
from langchain_community.retrievers import NeuralDBRetriever
@pytest.fixture(scope="session")
def test_csv() -> Generator[str, None, None]:
    """Provide a tiny one-row CSV file shared by every test in the session.

    The file is created under a fixed relative name in the current working
    directory, handed to the tests as that name, and deleted once the
    session is finished with it.
    """
    csv_path = "thirdai-test.csv"
    rows = ("column_1,column_2\n", "column one,column two\n")
    # Close (and therefore flush) the file before any test gets to read it.
    with open(csv_path, "w") as csv_file:
        csv_file.writelines(rows)
    yield csv_path
    # Teardown: executed by pytest after the last consumer of the fixture.
    os.remove(csv_path)
def assert_result_correctness(documents: list) -> None:
    """Assert that a query returned exactly one document with the expected text.

    The expected text pairs each header of the fixture CSV with its value,
    formatted as ``"<header>: <value>"`` entries separated by a blank line.

    Args:
        documents: The documents returned by a retriever query.

    Raises:
        AssertionError: If there is not exactly one document, or if its
            ``page_content`` differs from the expected text.
    """
    expected_content = "column_1: column one\n\ncolumn_2: column two"
    assert len(documents) == 1
    only_document = documents[0]
    assert only_document.page_content == expected_content
@pytest.mark.requires("thirdai[neural_db]")
def test_neuraldb_retriever_from_scratch(test_csv: str) -> None:
    """Index the fixture CSV in a brand-new retriever and verify the single hit."""
    fresh_retriever = NeuralDBRetriever.from_scratch()
    fresh_retriever.insert([test_csv])
    assert_result_correctness(fresh_retriever.invoke("column"))
@pytest.mark.requires("thirdai[neural_db]")
def test_neuraldb_retriever_from_checkpoint(test_csv: str) -> None:
    """Round-trip a retriever through ``save`` / ``from_checkpoint`` and query it.

    The checkpoint lives under a fixed relative name in the current working
    directory. It is removed with ``shutil.rmtree``, so it is expected to be
    a directory. Cleanup runs before the test, to clear stale leftovers, and
    again afterwards, even when an assertion fails.
    """
    checkpoint_dir = "thirdai-test-save.ndb"

    def remove_checkpoint() -> None:
        # Tolerate the checkpoint not existing (first run / save never happened).
        if os.path.exists(checkpoint_dir):
            shutil.rmtree(checkpoint_dir)

    remove_checkpoint()
    try:
        original_retriever = NeuralDBRetriever.from_scratch()
        original_retriever.insert([test_csv])
        original_retriever.save(checkpoint_dir)
        restored_retriever = NeuralDBRetriever.from_checkpoint(checkpoint_dir)
        assert_result_correctness(restored_retriever.invoke("column"))
    finally:
        remove_checkpoint()
@pytest.mark.requires("thirdai[neural_db]")
def test_neuraldb_retriever_other_methods(test_csv: str) -> None:
    """Smoke-test the association and upvote methods on a populated retriever.

    No return values are inspected: the test passes as long as every call
    completes without raising.
    """
    retriever = NeuralDBRetriever.from_scratch()
    retriever.insert([test_csv])
    association_pairs = [("A", "B"), ("C", "D")]
    upvote_pairs = [("A", 0), ("B", 0)]
    # Single-item calls first, then their batched counterparts, in the
    # same order as before (the retriever is stateful).
    retriever.associate("A", "B")
    retriever.associate_batch(association_pairs)
    retriever.upvote("A", 0)
    retriever.upvote_batch(upvote_pairs)