Mirror of https://github.com/hwchase17/langchain, synced 2024-11-06 03:20:49 +00:00
Commit 94c7899257
This PR adds Rockset as a vectorstore for langchain. [Rockset](https://rockset.com/blog/introducing-vector-search-on-rockset/) is a real-time OLAP database that provides fast and efficient vector search. Furthermore, since it is entirely schemaless, it can store metadata in separate columns, allowing fast metadata filtering during vector similarity search (as opposed to storing all metadata in a single JSON column). It currently supports three distance functions: `COSINE_SIMILARITY`, `EUCLIDEAN_DISTANCE`, and `DOT_PRODUCT`. This PR adds the `rockset` client as an optional dependency. We would love a Twitter shoutout; our handle is https://twitter.com/RocksetCloud

---------

Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
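As a rough usage sketch drawn from the integration test below: the collection name `langchain_demo`, the `description`/`description_embedding` columns, the `usw2a1` region, and the 10-dimensional `ConsistentFakeEmbeddings` test helper are assumptions tied to that test setup, not requirements of the API.

```python
import os

import rockset

from langchain.vectorstores.rocksetdb import Rockset
from tests.integration_tests.vectorstores.fake_embeddings import (
    ConsistentFakeEmbeddings,
)

# Connect to Rockset; the region host and API key here are placeholders.
client = rockset.RocksetClient(rockset.Regions.usw2a1, os.environ["ROCKSET_API_KEY"])

# ConsistentFakeEmbeddings emits 10-dimensional vectors, matching the test
# collection's VECTOR_ENFORCE(..., 10, 'float') ingest transformation; a real
# embedding model would be used in practice.
vectorstore = Rockset(
    client,
    ConsistentFakeEmbeddings(),
    "langchain_demo",          # collection name
    "description",             # column holding the document text
    "description_embedding",   # column holding the embedding vector
)

# Index a few documents; metadata lands in its own columns, not one JSON blob.
vectorstore.add_texts(
    texts=["foo", "bar", "baz"],
    metadatas=[{"metadata_index": i} for i in range(3)],
)

# Nearest-neighbour search by cosine similarity with a SQL metadata filter.
docs = vectorstore.similarity_search(
    query="foo",
    distance_func=Rockset.DistanceFunction.COSINE_SIM,
    k=1,
    where_str="metadata_index != 0",
)
```

The `distance_func` and `where_str` arguments mirror how the tests below exercise cosine similarity and metadata filtering.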
156 lines
5.1 KiB
Python
import logging
import os

import rockset
import rockset.models

from langchain.docstore.document import Document
from langchain.vectorstores.rocksetdb import Rockset
from tests.integration_tests.vectorstores.fake_embeddings import (
    ConsistentFakeEmbeddings,
    fake_texts,
)

logger = logging.getLogger(__name__)


# To run these tests, make sure you have a collection with the name `langchain_demo`
# and the following ingest transformation:
#
#   SELECT
#       _input.* EXCEPT(_meta),
#       VECTOR_ENFORCE(_input.description_embedding, 10, 'float') as
#           description_embedding
#   FROM
#       _input
#
# We're using FakeEmbeddings utility to create text embeddings.
# It generates vector embeddings of length 10.
#
# Set env ROCKSET_DELETE_DOCS_ON_START=1 if you want to delete all docs from
# the collection before running any test. Be careful, this will delete any
# existing documents in your Rockset collection.
#
# See https://rockset.com/blog/introducing-vector-search-on-rockset/ for more details.
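#
# A minimal environment for running these tests locally might look like the
# following (illustrative values; supply your own key and region):
#   ROCKSET_API_KEY=<your Rockset API key>
#   ROCKSET_REGION=usw2a1            # also accepts use1a1, euc1a1, or "dev"
#   ROCKSET_DELETE_DOCS_ON_START=1   # optional, see the warning above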

collection_name = "langchain_demo"
text_key = "description"
embedding_key = "description_embedding"


class TestRockset:
    rockset_vectorstore: Rockset

    @classmethod
    def setup_class(cls) -> None:
        assert os.environ.get("ROCKSET_API_KEY") is not None
        assert os.environ.get("ROCKSET_REGION") is not None

        api_key = os.environ.get("ROCKSET_API_KEY")
        region = os.environ.get("ROCKSET_REGION")
        if region == "use1a1":
            host = rockset.Regions.use1a1
        elif region == "usw2a1":
            host = rockset.Regions.usw2a1
        elif region == "euc1a1":
            host = rockset.Regions.euc1a1
        elif region == "dev":
            host = rockset.DevRegions.usw2a1
        else:
            logger.warning(
                "Using ROCKSET_REGION:%s as it is. "
                "You should know what you're doing...",
                region,
            )

            host = region

        client = rockset.RocksetClient(host, api_key)
        if os.environ.get("ROCKSET_DELETE_DOCS_ON_START") == "1":
            logger.info(
                "Deleting all existing documents from the Rockset collection %s",
                collection_name,
            )

            query_response = client.Queries.query(
                sql={"query": "select _id from {}".format(collection_name)}
            )
            ids = [
                str(r["_id"])
                for r in getattr(
                    query_response, query_response.attribute_map["results"]
                )
            ]
            logger.info("Existing ids in collection: %s", ids)
            client.Documents.delete_documents(
                collection=collection_name,
                data=[rockset.models.DeleteDocumentsRequestData(id=i) for i in ids],
            )

        embeddings = ConsistentFakeEmbeddings()
        embeddings.embed_documents(fake_texts)
        cls.rockset_vectorstore = Rockset(
            client, embeddings, collection_name, text_key, embedding_key
        )

    def test_rockset_insert_and_search(self) -> None:
        """Test end to end vector search in Rockset"""

        texts = ["foo", "bar", "baz"]
        metadatas = [{"metadata_index": i} for i in range(len(texts))]
        ids = self.rockset_vectorstore.add_texts(
            texts=texts,
            metadatas=metadatas,
        )
        assert len(ids) == len(texts)
        # Test that `foo` is closest to `foo`
        output = self.rockset_vectorstore.similarity_search(
            query="foo", distance_func=Rockset.DistanceFunction.COSINE_SIM, k=1
        )
        assert output == [Document(page_content="foo", metadata={"metadata_index": 0})]

        # Find closest vector to `foo` which is not `foo`
        output = self.rockset_vectorstore.similarity_search(
            query="foo",
            distance_func=Rockset.DistanceFunction.COSINE_SIM,
            k=1,
            where_str="metadata_index != 0",
        )
        assert output == [Document(page_content="bar", metadata={"metadata_index": 1})]

    def test_build_query_sql(self) -> None:
        vector = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
        q_str = self.rockset_vectorstore._build_query_sql(
            vector,
            Rockset.DistanceFunction.COSINE_SIM,
            4,
        )
        vector_str = ",".join(map(str, vector))
        expected = f"""\
SELECT * EXCEPT(description_embedding), \
COSINE_SIM(description_embedding, [{vector_str}]) as dist
FROM langchain_demo
ORDER BY dist DESC
LIMIT 4
"""
        assert q_str == expected

    def test_build_query_sql_with_where(self) -> None:
        vector = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
        q_str = self.rockset_vectorstore._build_query_sql(
            vector,
            Rockset.DistanceFunction.COSINE_SIM,
            4,
            "age >= 10",
        )
        vector_str = ",".join(map(str, vector))
        expected = f"""\
SELECT * EXCEPT(description_embedding), \
COSINE_SIM(description_embedding, [{vector_str}]) as dist
FROM langchain_demo
WHERE age >= 10
ORDER BY dist DESC
LIMIT 4
"""
        assert q_str == expected