mirror of
https://github.com/hwchase17/langchain
synced 2024-11-18 09:25:54 +00:00
community: Added integrations for ThirdAI's NeuralDB with Retriever and VectorStore frameworks (#15280)
**Description:** Adds ThirdAI NeuralDB retriever and vectorstore integration. NeuralDB is a CPU-friendly and fine-tunable text retrieval engine.
This commit is contained in:
parent
815896ff13
commit
f3fdc5c5da
160
docs/docs/integrations/vectorstores/thirdai_neuraldb.ipynb
Normal file
160
docs/docs/integrations/vectorstores/thirdai_neuraldb.ipynb
Normal file
@ -0,0 +1,160 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# **NeuralDB**\n",
|
||||
"NeuralDB is a CPU-friendly and fine-tunable vector store developed by ThirdAI.\n",
|
||||
"\n",
|
||||
"### **Initialization**\n",
|
||||
"There are three initialization methods:\n",
|
||||
"- From Scratch: Basic model\n",
|
||||
"- From Bazaar: Download a pretrained base model from our model bazaar for better performance\n",
|
||||
"- From Checkpoint: Load a model that was previously saved\n",
|
||||
"\n",
|
||||
"For all of the following initialization methods, the `thirdai_key` parameter can be ommitted if the `THIRDAI_KEY` environment variable is set.\n",
|
||||
"\n",
|
||||
"ThirdAI API keys can be obtained at https://www.thirdai.com/try-bolt/"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.vectorstores import NeuralDBVectorStore\n",
|
||||
"\n",
|
||||
"# From scratch\n",
|
||||
"vectorstore = NeuralDBVectorStore.from_scratch(thirdai_key=\"your-thirdai-key\")\n",
|
||||
"\n",
|
||||
"# From bazaar\n",
|
||||
"vectorstore = NeuralDBVectorStore.from_bazaar(\n",
|
||||
" # Name of base model to be downloaded from model bazaar.\n",
|
||||
" # \"General QnA\" gives better performance on question-answering.\n",
|
||||
" base=\"General QnA\",\n",
|
||||
" # Path to a directory that caches models to prevent repeated downloading.\n",
|
||||
" # Defaults to {CWD}/model_bazaar\n",
|
||||
" bazaar_cache=\"/path/to/bazaar_cache\",\n",
|
||||
" thirdai_key=\"your-thirdai-key\",\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# From checkpoint\n",
|
||||
"vectorstore = NeuralDBVectorStore.from_checkpoint(\n",
|
||||
" # Path to a NeuralDB checkpoint. For example, if you call\n",
|
||||
" # vectorstore.save(\"/path/to/checkpoint.ndb\") in one script, then you can\n",
|
||||
" # call NeuralDBVectorStore.from_checkpoint(\"/path/to/checkpoint.ndb\") in\n",
|
||||
" # another script to load the saved model.\n",
|
||||
" checkpoint=\"/path/to/checkpoint.ndb\",\n",
|
||||
" thirdai_key=\"your-thirdai-key\",\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### **Inserting document sources**"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"vectorstore.insert(\n",
|
||||
" # If you have PDF, DOCX, or CSV files, you can directly pass the paths to the documents\n",
|
||||
" sources=[\"/path/to/doc.pdf\", \"/path/to/doc.docx\", \"/path/to/doc.csv\"],\n",
|
||||
" # When True this means that the underlying model in the NeuralDB will\n",
|
||||
" # undergo unsupervised pretraining on the inserted files. Defaults to True.\n",
|
||||
" train=True,\n",
|
||||
" # Much faster insertion with a slight drop in performance. Defaults to True.\n",
|
||||
" fast_mode=True,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"from thirdai import neural_db as ndb\n",
|
||||
"\n",
|
||||
"vectorstore.insert(\n",
|
||||
" # If you have files in other formats, or prefer to configure how\n",
|
||||
" # your files are parsed, then you can pass in NeuralDB document objects\n",
|
||||
" # like this.\n",
|
||||
" sources=[\n",
|
||||
" ndb.PDF(\n",
|
||||
" \"/path/to/doc.pdf\",\n",
|
||||
" version=\"v2\",\n",
|
||||
" chunk_size=100,\n",
|
||||
" metadata={\"published\": 2022},\n",
|
||||
" ),\n",
|
||||
" ndb.Unstructured(\"/path/to/deck.pptx\"),\n",
|
||||
" ]\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### **Similarity search**\n",
|
||||
"To query the vectorstore, you can use the standard LangChain vectorstore method `similarity_search`, which returns a list of LangChain Document objects. Each document object represents a chunk of text from the indexed files. For example, it may contain a paragraph from one of the indexed PDF files. In addition to the text, the document's metadata field contains information such as the document's ID, the source of this document (which file it came from), and the score of the document."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# This returns a list of LangChain Document objects\n",
|
||||
"documents = vectorstore.similarity_search(\"query\", k=10)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### **Fine tuning**\n",
|
||||
"NeuralDBVectorStore can be fine-tuned to user behavior and domain-specific knowledge. It can be fine-tuned in two ways:\n",
|
||||
"1. Association: the vectorstore associates a source phrase with a target phrase. When the vectorstore sees the source phrase, it will also consider results that are relevant to the target phrase.\n",
|
||||
"2. Upvoting: the vectorstore upweights the score of a document for a specific query. This is useful when you want to fine-tune the vectorstore to user behavior. For example, if a user searches \"how is a car manufactured\" and likes the returned document with id 52, then we can upvote the document with id 52 for the query \"how is a car manufactured\"."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"vectorstore.associate(source=\"source phrase\", target=\"target phrase\")\n",
|
||||
"vectorstore.associate_batch(\n",
|
||||
" [\n",
|
||||
" (\"source phrase 1\", \"target phrase 1\"),\n",
|
||||
" (\"source phrase 2\", \"target phrase 2\"),\n",
|
||||
" ]\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"vectorstore.upvote(query=\"how is a car manufactured\", document_id=52)\n",
|
||||
"vectorstore.upvote_batch(\n",
|
||||
" [\n",
|
||||
" (\"query 1\", 52),\n",
|
||||
" (\"query 2\", 20),\n",
|
||||
" ]\n",
|
||||
")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "langchain",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"name": "python",
|
||||
"version": "3.10.0"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
@ -470,6 +470,12 @@ def _import_zilliz() -> Any:
|
||||
return Zilliz
|
||||
|
||||
|
||||
def _import_neuraldb() -> Any:
|
||||
from langchain_community.vectorstores.thirdai_neuraldb import NeuralDBVectorStore
|
||||
|
||||
return NeuralDBVectorStore
|
||||
|
||||
|
||||
def _import_lantern() -> Any:
|
||||
from langchain_community.vectorstores.lantern import Lantern
|
||||
|
||||
@ -621,6 +627,8 @@ def __getattr__(name: str) -> Any:
|
||||
return _import_zilliz()
|
||||
elif name == "VespaStore":
|
||||
return _import_vespa()
|
||||
elif name == "NeuralDBVectorStore":
|
||||
return _import_neuraldb()
|
||||
elif name == "Lantern":
|
||||
return _import_lantern()
|
||||
else:
|
||||
@ -699,5 +707,6 @@ __all__ = [
|
||||
"TencentVectorDB",
|
||||
"AzureCosmosDBVectorSearch",
|
||||
"VectorStore",
|
||||
"NeuralDBVectorStore",
|
||||
"Lantern",
|
||||
]
|
||||
|
@ -0,0 +1,344 @@
|
||||
import importlib
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.pydantic_v1 import Extra, root_validator
|
||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
|
||||
from langchain_core.vectorstores import VectorStore
|
||||
|
||||
|
||||
class NeuralDBVectorStore(VectorStore):
|
||||
"""Vectorstore that uses ThirdAI's NeuralDB."""
|
||||
|
||||
db: Any = None #: :meta private:
|
||||
"""NeuralDB instance"""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
underscore_attrs_are_private = True
|
||||
|
||||
@staticmethod
|
||||
def _verify_thirdai_library(thirdai_key: Optional[str] = None):
|
||||
try:
|
||||
from thirdai import licensing
|
||||
|
||||
importlib.util.find_spec("thirdai.neural_db")
|
||||
|
||||
licensing.activate(thirdai_key or os.getenv("THIRDAI_KEY"))
|
||||
except ImportError:
|
||||
raise ModuleNotFoundError(
|
||||
"Could not import thirdai python package and neuraldb dependencies. "
|
||||
"Please install it with `pip install thirdai[neural_db]`."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_scratch(
|
||||
cls,
|
||||
thirdai_key: Optional[str] = None,
|
||||
**model_kwargs,
|
||||
):
|
||||
"""
|
||||
Create a NeuralDBVectorStore from scratch.
|
||||
|
||||
To use, set the ``THIRDAI_KEY`` environment variable with your ThirdAI
|
||||
API key, or pass ``thirdai_key`` as a named parameter.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.vectorstores import NeuralDBVectorStore
|
||||
|
||||
vectorstore = NeuralDBVectorStore.from_scratch(
|
||||
thirdai_key="your-thirdai-key",
|
||||
)
|
||||
|
||||
vectorstore.insert([
|
||||
"/path/to/doc.pdf",
|
||||
"/path/to/doc.docx",
|
||||
"/path/to/doc.csv",
|
||||
])
|
||||
|
||||
documents = vectorstore.similarity_search("AI-driven music therapy")
|
||||
"""
|
||||
NeuralDBVectorStore._verify_thirdai_library(thirdai_key)
|
||||
from thirdai import neural_db as ndb
|
||||
|
||||
return cls(db=ndb.NeuralDB(**model_kwargs))
|
||||
|
||||
@classmethod
|
||||
def from_bazaar(
|
||||
cls,
|
||||
base: str,
|
||||
bazaar_cache: Optional[str] = None,
|
||||
thirdai_key: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Create a NeuralDBVectorStore with a base model from the ThirdAI
|
||||
model bazaar.
|
||||
|
||||
To use, set the ``THIRDAI_KEY`` environment variable with your ThirdAI
|
||||
API key, or pass ``thirdai_key`` as a named parameter.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.vectorstores import NeuralDBVectorStore
|
||||
|
||||
vectorstore = NeuralDBVectorStore.from_bazaar(
|
||||
base="General QnA",
|
||||
thirdai_key="your-thirdai-key",
|
||||
)
|
||||
|
||||
vectorstore.insert([
|
||||
"/path/to/doc.pdf",
|
||||
"/path/to/doc.docx",
|
||||
"/path/to/doc.csv",
|
||||
])
|
||||
|
||||
documents = vectorstore.similarity_search("AI-driven music therapy")
|
||||
"""
|
||||
NeuralDBVectorStore._verify_thirdai_library(thirdai_key)
|
||||
from thirdai import neural_db as ndb
|
||||
|
||||
cache = bazaar_cache or str(Path(os.getcwd()) / "model_bazaar")
|
||||
if not os.path.exists(cache):
|
||||
os.mkdir(cache)
|
||||
model_bazaar = ndb.Bazaar(cache)
|
||||
model_bazaar.fetch()
|
||||
return cls(db=model_bazaar.get_model(base))
|
||||
|
||||
@classmethod
|
||||
def from_checkpoint(
|
||||
cls,
|
||||
checkpoint: Union[str, Path],
|
||||
thirdai_key: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Create a NeuralDBVectorStore with a base model from a saved checkpoint
|
||||
|
||||
To use, set the ``THIRDAI_KEY`` environment variable with your ThirdAI
|
||||
API key, or pass ``thirdai_key`` as a named parameter.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.vectorstores import NeuralDBVectorStore
|
||||
|
||||
vectorstore = NeuralDBVectorStore.from_checkpoint(
|
||||
checkpoint="/path/to/checkpoint.ndb",
|
||||
thirdai_key="your-thirdai-key",
|
||||
)
|
||||
|
||||
vectorstore.insert([
|
||||
"/path/to/doc.pdf",
|
||||
"/path/to/doc.docx",
|
||||
"/path/to/doc.csv",
|
||||
])
|
||||
|
||||
documents = vectorstore.similarity_search("AI-driven music therapy")
|
||||
"""
|
||||
NeuralDBVectorStore._verify_thirdai_library(thirdai_key)
|
||||
from thirdai import neural_db as ndb
|
||||
|
||||
return cls(db=ndb.NeuralDB.from_checkpoint(checkpoint))
|
||||
|
||||
@classmethod
|
||||
def from_texts(
|
||||
cls,
|
||||
texts: List[str],
|
||||
embedding: Embeddings,
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
**kwargs: Any,
|
||||
) -> "NeuralDBVectorStore":
|
||||
"""Return VectorStore initialized from texts and embeddings."""
|
||||
model_kwargs = {}
|
||||
if "thirdai_key" in kwargs:
|
||||
model_kwargs["thirdai_key"] = kwargs["thirdai_key"]
|
||||
del kwargs["thirdai_key"]
|
||||
vectorstore = cls.from_scratch(**model_kwargs)
|
||||
vectorstore.add_texts(texts, metadatas, **kwargs)
|
||||
return vectorstore
|
||||
|
||||
def add_texts(
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[str]:
|
||||
"""Run more texts through the embeddings and add to the vectorstore.
|
||||
|
||||
Args:
|
||||
texts: Iterable of strings to add to the vectorstore.
|
||||
metadatas: Optional list of metadatas associated with the texts.
|
||||
kwargs: vectorstore specific parameters
|
||||
|
||||
Returns:
|
||||
List of ids from adding the texts into the vectorstore.
|
||||
"""
|
||||
import pandas as pd
|
||||
from thirdai import neural_db as ndb
|
||||
|
||||
df = pd.DataFrame({"texts": texts})
|
||||
if metadatas:
|
||||
df = pd.concat([df, pd.DataFrame.from_records(metadatas)], axis=1)
|
||||
temp = tempfile.NamedTemporaryFile("w", delete=False, delete_on_close=False)
|
||||
df.to_csv(temp)
|
||||
source_id = self.insert([ndb.CSV(temp.name)], **kwargs)[0]
|
||||
offset = self.db._savable_state.documents.get_source_by_id(source_id)[1]
|
||||
return [str(offset + i) for i in range(len(texts))]
|
||||
|
||||
@root_validator()
|
||||
def validate_environments(cls, values: Dict) -> Dict:
|
||||
"""Validate ThirdAI environment variables."""
|
||||
values["thirdai_key"] = convert_to_secret_str(
|
||||
get_from_dict_or_env(
|
||||
values,
|
||||
"thirdai_key",
|
||||
"THIRDAI_KEY",
|
||||
)
|
||||
)
|
||||
return values
|
||||
|
||||
def insert(
|
||||
self,
|
||||
sources: List[Any],
|
||||
train: bool = True,
|
||||
fast_mode: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
"""Inserts files / document sources into the vectorstore.
|
||||
|
||||
Args:
|
||||
train: When True this means that the underlying model in the
|
||||
NeuralDB will undergo unsupervised pretraining on the inserted files.
|
||||
Defaults to True.
|
||||
fast_mode: Much faster insertion with a slight drop in performance.
|
||||
Defaults to True.
|
||||
"""
|
||||
sources = self._preprocess_sources(sources)
|
||||
self.db.insert(
|
||||
sources=sources,
|
||||
train=train,
|
||||
fast_approximation=fast_mode,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _preprocess_sources(self, sources):
|
||||
"""Checks if the provided sources are string paths. If they are, convert
|
||||
to NeuralDB document objects.
|
||||
|
||||
Args:
|
||||
sources: list of either string paths to PDF, DOCX or CSV files, or
|
||||
NeuralDB document objects.
|
||||
"""
|
||||
from thirdai import neural_db as ndb
|
||||
|
||||
if not sources:
|
||||
return sources
|
||||
preprocessed_sources = []
|
||||
for doc in sources:
|
||||
if not isinstance(doc, str):
|
||||
preprocessed_sources.append(doc)
|
||||
else:
|
||||
if doc.lower().endswith(".pdf"):
|
||||
preprocessed_sources.append(ndb.PDF(doc))
|
||||
elif doc.lower().endswith(".docx"):
|
||||
preprocessed_sources.append(ndb.DOCX(doc))
|
||||
elif doc.lower().endswith(".csv"):
|
||||
preprocessed_sources.append(ndb.CSV(doc))
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Could not automatically load {doc}. Only files "
|
||||
"with .pdf, .docx, or .csv extensions can be loaded "
|
||||
"automatically. For other formats, please use the "
|
||||
"appropriate document object from the ThirdAI library."
|
||||
)
|
||||
return preprocessed_sources
|
||||
|
||||
def upvote(self, query: str, document_id: Union[int, str]):
|
||||
"""The vectorstore upweights the score of a document for a specific query.
|
||||
This is useful for fine-tuning the vectorstore to user behavior.
|
||||
|
||||
Args:
|
||||
query: text to associate with `document_id`
|
||||
document_id: id of the document to associate query with.
|
||||
"""
|
||||
self.db.text_to_result(query, int(document_id))
|
||||
|
||||
def upvote_batch(self, query_id_pairs: List[Tuple[str, int]]):
|
||||
"""Given a batch of (query, document id) pairs, the vectorstore upweights
|
||||
the scores of the document for the corresponding queries.
|
||||
This is useful for fine-tuning the vectorstore to user behavior.
|
||||
|
||||
Args:
|
||||
query_id_pairs: list of (query, document id) pairs. For each pair in
|
||||
this list, the model will upweight the document id for the query.
|
||||
"""
|
||||
self.db.text_to_result_batch(
|
||||
[(query, int(doc_id)) for query, doc_id in query_id_pairs]
|
||||
)
|
||||
|
||||
def associate(self, source: str, target: str):
|
||||
"""The vectorstore associates a source phrase with a target phrase.
|
||||
When the vectorstore sees the source phrase, it will also consider results
|
||||
that are relevant to the target phrase.
|
||||
|
||||
Args:
|
||||
source: text to associate to `target`.
|
||||
target: text to associate `source` to.
|
||||
"""
|
||||
self.db.associate(source, target)
|
||||
|
||||
def associate_batch(self, text_pairs: List[Tuple[str, str]]):
|
||||
"""Given a batch of (source, target) pairs, the vectorstore associates
|
||||
each source phrase with the corresponding target phrase.
|
||||
|
||||
Args:
|
||||
text_pairs: list of (source, target) text pairs. For each pair in
|
||||
this list, the source will be associated with the target.
|
||||
"""
|
||||
self.db.associate_batch(text_pairs)
|
||||
|
||||
def similarity_search(
|
||||
self, query: str, k: int = 10, **kwargs: Any
|
||||
) -> List[Document]:
|
||||
"""Retrieve {k} contexts with for a given query
|
||||
|
||||
Args:
|
||||
query: Query to submit to the model
|
||||
k: The max number of context results to retrieve. Defaults to 10.
|
||||
"""
|
||||
try:
|
||||
references = self.db.search(query=query, top_k=k, **kwargs)
|
||||
return [
|
||||
Document(
|
||||
page_content=ref.text,
|
||||
metadata={
|
||||
"id": ref.id,
|
||||
"upvote_ids": ref.upvote_ids,
|
||||
"text": ref.text,
|
||||
"source": ref.source,
|
||||
"metadata": ref.metadata,
|
||||
"score": ref.score,
|
||||
"context": ref.context(1),
|
||||
},
|
||||
)
|
||||
for ref in references
|
||||
]
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error while retrieving documents: {e}") from e
|
||||
|
||||
def save(self, path: str):
|
||||
"""Saves a NeuralDB instance to disk. Can be loaded into memory by
|
||||
calling NeuralDB.from_checkpoint(path)
|
||||
|
||||
Args:
|
||||
path: path on disk to save the NeuralDB instance to.
|
||||
"""
|
||||
self.db.save(path)
|
@ -0,0 +1,65 @@
|
||||
import os
|
||||
import shutil
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain_community.vectorstores import NeuralDBVectorStore
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def test_csv():
|
||||
csv = "thirdai-test.csv"
|
||||
with open(csv, "w") as o:
|
||||
o.write("column_1,column_2\n")
|
||||
o.write("column one,column two\n")
|
||||
yield csv
|
||||
os.remove(csv)
|
||||
|
||||
|
||||
def assert_result_correctness(documents):
|
||||
assert len(documents) == 1
|
||||
assert documents[0].page_content == "column_1: column one\n\ncolumn_2: column two"
|
||||
|
||||
|
||||
@pytest.mark.requires("thirdai[neural_db]")
|
||||
def test_neuraldb_retriever_from_scratch(test_csv):
|
||||
retriever = NeuralDBVectorStore.from_scratch()
|
||||
retriever.insert([test_csv])
|
||||
documents = retriever.similarity_search("column")
|
||||
assert_result_correctness(documents)
|
||||
|
||||
|
||||
@pytest.mark.requires("thirdai[neural_db]")
|
||||
def test_neuraldb_retriever_from_checkpoint(test_csv):
|
||||
checkpoint = "thirdai-test-save.ndb"
|
||||
if os.path.exists(checkpoint):
|
||||
shutil.rmtree(checkpoint)
|
||||
try:
|
||||
retriever = NeuralDBVectorStore.from_scratch()
|
||||
retriever.insert([test_csv])
|
||||
retriever.save(checkpoint)
|
||||
loaded_retriever = NeuralDBVectorStore.from_checkpoint(checkpoint)
|
||||
documents = loaded_retriever.similarity_search("column")
|
||||
assert_result_correctness(documents)
|
||||
finally:
|
||||
if os.path.exists(checkpoint):
|
||||
shutil.rmtree(checkpoint)
|
||||
|
||||
|
||||
@pytest.mark.requires("thirdai[neural_db]")
|
||||
def test_neuraldb_retriever_from_bazaar(test_csv):
|
||||
retriever = NeuralDBVectorStore.from_bazaar("General QnA")
|
||||
retriever.insert([test_csv])
|
||||
documents = retriever.similarity_search("column")
|
||||
assert_result_correctness(documents)
|
||||
|
||||
|
||||
@pytest.mark.requires("thirdai[neural_db]")
|
||||
def test_neuraldb_retriever_other_methods(test_csv):
|
||||
retriever = NeuralDBVectorStore.from_scratch()
|
||||
retriever.insert([test_csv])
|
||||
# Make sure they don't throw an error.
|
||||
retriever.associate("A", "B")
|
||||
retriever.associate_batch([("A", "B"), ("C", "D")])
|
||||
retriever.upvote("A", 0)
|
||||
retriever.upvote_batch([("A", 0), ("B", 0)])
|
@ -74,6 +74,7 @@ _EXPECTED = [
|
||||
"AzureCosmosDBVectorSearch",
|
||||
"VectorStore",
|
||||
"Yellowbrick",
|
||||
"NeuralDBVectorStore",
|
||||
]
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user