Harrison/add vald (#10807)

Co-authored-by: datelier <57349093+datelier@users.noreply.github.com>
pull/10790/head
Harrison Chase 1 year ago committed by GitHub
parent bbc3fe259b
commit d2bee34d4c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,175 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "25bce5eb-8599-40fe-947e-4932cfae8184",
"metadata": {},
"source": [
"# Vald\n",
"\n",
"> [Vald](https://github.com/vdaas/vald) is a highly scalable distributed fast approximate nearest neighbor (ANN) dense vector search engine.\n",
"\n",
"This notebook shows how to use functionality related to the `Vald` database.\n",
"\n",
"To run this notebook you need a running Vald cluster.\n",
"Check [Get Started](https://github.com/vdaas/vald#get-started) for more information.\n",
"\n",
"See the [installation instructions](https://github.com/vdaas/vald-client-python#install)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f45f46f2-7229-4859-9797-30bbead1b8e0",
"metadata": {},
"outputs": [],
"source": [
"!pip install vald-client-python"
]
},
{
"cell_type": "markdown",
"id": "2f65caa9-8383-409a-bccb-6e91fc8d5e8f",
"metadata": {},
"source": [
"## Basic Example"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "eab0b1e4-9793-4be7-a2ba-e4455c21ea22",
"metadata": {},
"outputs": [],
"source": [
"from langchain.document_loaders import TextLoader\n",
"from langchain.embeddings import HuggingFaceEmbeddings\n",
"from langchain.text_splitter import CharacterTextSplitter\n",
"from langchain.vectorstores import Vald\n",
"\n",
"raw_documents = TextLoader('state_of_the_union.txt').load()\n",
"text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n",
"documents = text_splitter.split_documents(raw_documents)\n",
"embeddings = HuggingFaceEmbeddings()\n",
"db = Vald.from_documents(documents, embeddings, host=\"localhost\", port=8080)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b0a6797c-2bb0-45db-a636-5d2437f7a4c0",
"metadata": {},
"outputs": [],
"source": [
"query = \"What did the president say about Ketanji Brown Jackson\"\n",
"docs = db.similarity_search(query)\n",
"docs[0].page_content"
]
},
{
"cell_type": "markdown",
"id": "c4c4e06d-6def-44ce-ac9a-4c01673c29a2",
"metadata": {},
"source": [
"### Similarity search by vector"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1eb72610-d451-4158-880c-9f0d45fa5909",
"metadata": {},
"outputs": [],
"source": [
"embedding_vector = embeddings.embed_query(query)\n",
"docs = db.similarity_search_by_vector(embedding_vector)\n",
"docs[0].page_content"
]
},
{
"cell_type": "markdown",
"id": "d33588d4-67c2-4bd3-b251-76ae783cbafb",
"metadata": {},
"source": [
"### Similarity search with score"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1a41e382-0336-4e6d-b2ef-44cc77db2696",
"metadata": {},
"outputs": [],
"source": [
"docs_and_scores = db.similarity_search_with_score(query)\n",
"docs_and_scores[0]"
]
},
{
"cell_type": "markdown",
"id": "57f930f2-41a0-4795-ad9e-44a33c8f88ec",
"metadata": {},
"source": [
"## Maximal Marginal Relevance Search (MMR)"
]
},
{
"cell_type": "markdown",
"id": "4790e437-3207-45cb-b121-d857ab5aabd8",
"metadata": {},
"source": [
"In addition to using similarity search in the retriever object, you can also use `mmr` as retriever."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "495754b1-5cdb-4af6-9733-f68700bb7232",
"metadata": {},
"outputs": [],
"source": [
"retriever = db.as_retriever(search_type=\"mmr\")\n",
"retriever.get_relevant_documents(query)"
]
},
{
"cell_type": "markdown",
"id": "e213d957-e439-4bd6-90f2-8909323f5f09",
"metadata": {},
"source": [
"Or use `max_marginal_relevance_search` directly:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "99d928d0-3b79-4588-925e-32230e12af47",
"metadata": {},
"outputs": [],
"source": [
"db.max_marginal_relevance_search(query, k=2, fetch_k=10)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

@ -71,6 +71,7 @@ from langchain.vectorstores.tencentvectordb import TencentVectorDB
from langchain.vectorstores.tigris import Tigris from langchain.vectorstores.tigris import Tigris
from langchain.vectorstores.typesense import Typesense from langchain.vectorstores.typesense import Typesense
from langchain.vectorstores.usearch import USearch from langchain.vectorstores.usearch import USearch
from langchain.vectorstores.vald import Vald
from langchain.vectorstores.vectara import Vectara from langchain.vectorstores.vectara import Vectara
from langchain.vectorstores.weaviate import Weaviate from langchain.vectorstores.weaviate import Weaviate
from langchain.vectorstores.zep import ZepVectorStore from langchain.vectorstores.zep import ZepVectorStore
@ -133,6 +134,7 @@ __all__ = [
"Tigris", "Tigris",
"Typesense", "Typesense",
"USearch", "USearch",
"Vald",
"Vectara", "Vectara",
"VectorStore", "VectorStore",
"Weaviate", "Weaviate",

@ -0,0 +1,375 @@
"""Wrapper around Vald vector database."""
from __future__ import annotations
from typing import Any, Iterable, List, Optional, Tuple, Type
import numpy as np
from langchain.docstore.document import Document
from langchain.schema.embeddings import Embeddings
from langchain.vectorstores.base import VectorStore
from langchain.vectorstores.utils import maximal_marginal_relevance
class Vald(VectorStore):
"""Wrapper around Vald vector database.
To use, you should have the ``vald-client-python`` python package installed.
Example:
.. code-block:: python
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Vald
texts = ['foo', 'bar', 'baz']
vald = Vald.from_texts(
texts=texts,
embedding=HuggingFaceEmbeddings(),
host="localhost",
port=8080,
skip_strict_exist_check=False,
)
"""
def __init__(
self,
embedding: Embeddings,
host: str = "localhost",
port: int = 8080,
grpc_options: Tuple = (
("grpc.keepalive_time_ms", 1000 * 10),
("grpc.keepalive_timeout_ms", 1000 * 10),
),
):
self._embedding = embedding
self.target = host + ":" + str(port)
self.grpc_options = grpc_options
@property
def embeddings(self) -> Optional[Embeddings]:
return self._embedding
def add_texts(
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
skip_strict_exist_check: bool = False,
**kwargs: Any,
) -> List[str]:
"""
Args:
skip_strict_exist_check: Deprecated. This is not used basically.
"""
try:
import grpc
from vald.v1.payload import payload_pb2
from vald.v1.vald import upsert_pb2_grpc
except ImportError:
raise ValueError(
"Could not import vald-client-python python package. "
"Please install it with `pip install vald-client-python`."
)
channel = grpc.insecure_channel(self.target, options=self.grpc_options)
# Depending on the network quality,
# it is necessary to wait for ChannelConnectivity.READY.
# _ = grpc.channel_ready_future(channel).result(timeout=10)
stub = upsert_pb2_grpc.UpsertStub(channel)
cfg = payload_pb2.Upsert.Config(skip_strict_exist_check=skip_strict_exist_check)
ids = []
embs = self._embedding.embed_documents(list(texts))
for text, emb in zip(texts, embs):
vec = payload_pb2.Object.Vector(id=text, vector=emb)
res = stub.Upsert(payload_pb2.Upsert.Request(vector=vec, config=cfg))
ids.append(res.uuid)
channel.close()
return ids
def delete(
self,
ids: Optional[List[str]] = None,
skip_strict_exist_check: bool = False,
**kwargs: Any,
) -> Optional[bool]:
"""
Args:
skip_strict_exist_check: Deprecated. This is not used basically.
"""
try:
import grpc
from vald.v1.payload import payload_pb2
from vald.v1.vald import remove_pb2_grpc
except ImportError:
raise ValueError(
"Could not import vald-client-python python package. "
"Please install it with `pip install vald-client-python`."
)
if ids is None:
raise ValueError("No ids provided to delete")
channel = grpc.insecure_channel(self.target, options=self.grpc_options)
# Depending on the network quality,
# it is necessary to wait for ChannelConnectivity.READY.
# _ = grpc.channel_ready_future(channel).result(timeout=10)
stub = remove_pb2_grpc.RemoveStub(channel)
cfg = payload_pb2.Remove.Config(skip_strict_exist_check=skip_strict_exist_check)
for _id in ids:
oid = payload_pb2.Object.ID(id=_id)
_ = stub.Remove(payload_pb2.Remove.Request(id=oid, config=cfg))
channel.close()
return True
def similarity_search(
self,
query: str,
k: int = 4,
radius: float = -1.0,
epsilon: float = 0.01,
timeout: int = 3000000000,
**kwargs: Any,
) -> List[Document]:
docs_and_scores = self.similarity_search_with_score(
query, k, radius, epsilon, timeout
)
docs = []
for doc, _ in docs_and_scores:
docs.append(doc)
return docs
def similarity_search_with_score(
self,
query: str,
k: int = 4,
radius: float = -1.0,
epsilon: float = 0.01,
timeout: int = 3000000000,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
emb = self._embedding.embed_query(query)
docs_and_scores = self.similarity_search_with_score_by_vector(
emb, k, radius, epsilon, timeout
)
return docs_and_scores
def similarity_search_by_vector(
self,
embedding: List[float],
k: int = 4,
radius: float = -1.0,
epsilon: float = 0.01,
timeout: int = 3000000000,
**kwargs: Any,
) -> List[Document]:
docs_and_scores = self.similarity_search_with_score_by_vector(
embedding, k, radius, epsilon, timeout
)
docs = []
for doc, _ in docs_and_scores:
docs.append(doc)
return docs
def similarity_search_with_score_by_vector(
self,
embedding: List[float],
k: int = 4,
radius: float = -1.0,
epsilon: float = 0.01,
timeout: int = 3000000000,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
try:
import grpc
from vald.v1.payload import payload_pb2
from vald.v1.vald import search_pb2_grpc
except ImportError:
raise ValueError(
"Could not import vald-client-python python package. "
"Please install it with `pip install vald-client-python`."
)
channel = grpc.insecure_channel(self.target, options=self.grpc_options)
# Depending on the network quality,
# it is necessary to wait for ChannelConnectivity.READY.
# _ = grpc.channel_ready_future(channel).result(timeout=10)
stub = search_pb2_grpc.SearchStub(channel)
cfg = payload_pb2.Search.Config(
num=k, radius=radius, epsilon=epsilon, timeout=timeout
)
res = stub.Search(payload_pb2.Search.Request(vector=embedding, config=cfg))
docs_and_scores = []
for result in res.results:
docs_and_scores.append((Document(page_content=result.id), result.distance))
channel.close()
return docs_and_scores
def max_marginal_relevance_search(
self,
query: str,
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
radius: float = -1.0,
epsilon: float = 0.01,
timeout: int = 3000000000,
**kwargs: Any,
) -> List[Document]:
emb = self._embedding.embed_query(query)
docs = self.max_marginal_relevance_search_by_vector(
emb,
k=k,
fetch_k=fetch_k,
radius=radius,
epsilon=epsilon,
timeout=timeout,
lambda_mult=lambda_mult,
)
return docs
def max_marginal_relevance_search_by_vector(
self,
embedding: List[float],
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
radius: float = -1.0,
epsilon: float = 0.01,
timeout: int = 3000000000,
**kwargs: Any,
) -> List[Document]:
try:
import grpc
from vald.v1.payload import payload_pb2
from vald.v1.vald import object_pb2_grpc
except ImportError:
raise ValueError(
"Could not import vald-client-python python package. "
"Please install it with `pip install vald-client-python`."
)
channel = grpc.insecure_channel(self.target, options=self.grpc_options)
# Depending on the network quality,
# it is necessary to wait for ChannelConnectivity.READY.
# _ = grpc.channel_ready_future(channel).result(timeout=10)
stub = object_pb2_grpc.ObjectStub(channel)
docs_and_scores = self.similarity_search_with_score_by_vector(
embedding, fetch_k=fetch_k, radius=radius, epsilon=epsilon, timeout=timeout
)
docs = []
embs = []
for doc, _ in docs_and_scores:
vec = stub.GetObject(
payload_pb2.Object.VectorRequest(
id=payload_pb2.Object.ID(id=doc.page_content)
)
)
embs.append(vec.vector)
docs.append(doc)
mmr = maximal_marginal_relevance(
np.array(embedding),
embs,
lambda_mult=lambda_mult,
k=k,
)
channel.close()
return [docs[i] for i in mmr]
@classmethod
def from_texts(
cls: Type[Vald],
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
host: str = "localhost",
port: int = 8080,
grpc_options: Tuple = (
("grpc.keepalive_time_ms", 1000 * 10),
("grpc.keepalive_timeout_ms", 1000 * 10),
),
skip_strict_exist_check: bool = False,
**kwargs: Any,
) -> Vald:
"""
Args:
skip_strict_exist_check: Deprecated. This is not used basically.
"""
vald = cls(
embedding=embedding,
host=host,
port=port,
grpc_options=grpc_options,
**kwargs,
)
vald.add_texts(
texts=texts,
metadatas=metadatas,
skip_strict_exist_check=skip_strict_exist_check,
)
return vald
"""We will support if there are any requests."""
# async def aadd_texts(
# self,
# texts: Iterable[str],
# metadatas: Optional[List[dict]] = None,
# **kwargs: Any,
# ) -> List[str]:
# pass
#
# def _select_relevance_score_fn(self) -> Callable[[float], float]:
# pass
#
# def _similarity_search_with_relevance_scores(
# self,
# query: str,
# k: int = 4,
# **kwargs: Any,
# ) -> List[Tuple[Document, float]]:
# pass
#
# def similarity_search_with_relevance_scores(
# self,
# query: str,
# k: int = 4,
# **kwargs: Any,
# ) -> List[Tuple[Document, float]]:
# pass
#
# async def amax_marginal_relevance_search_by_vector(
# self,
# embedding: List[float],
# k: int = 4,
# fetch_k: int = 20,
# lambda_mult: float = 0.5,
# **kwargs: Any,
# ) -> List[Document]:
# pass
#
# @classmethod
# async def afrom_texts(
# cls: Type[VST],
# texts: List[str],
# embedding: Embeddings,
# metadatas: Optional[List[dict]] = None,
# **kwargs: Any,
# ) -> VST:
# pass

@ -0,0 +1,170 @@
"""Test Vald functionality."""
import time
from typing import List, Optional
from langchain.docstore.document import Document
from langchain.vectorstores import Vald
from tests.integration_tests.vectorstores.fake_embeddings import (
FakeEmbeddings,
fake_texts,
)
"""
To run, you should have a Vald cluster.
https://github.com/vdaas/vald/blob/main/docs/tutorial/get-started.md
"""
WAIT_TIME = 90
def _vald_from_texts(
metadatas: Optional[List[dict]] = None,
host: str = "localhost",
port: int = 8080,
skip_strict_exist_check: bool = True,
) -> Vald:
return Vald.from_texts(
fake_texts,
FakeEmbeddings(),
metadatas=metadatas,
host=host,
port=port,
skip_strict_exist_check=skip_strict_exist_check,
)
def test_vald_add_texts() -> None:
texts = ["foo", "bar", "baz"]
metadatas = [{"page": i} for i in range(len(texts))]
docsearch = _vald_from_texts(metadatas=metadatas)
time.sleep(WAIT_TIME) # Wait for CreateIndex
output = docsearch.similarity_search("foo", k=10)
assert len(output) == 3
texts = ["a", "b", "c"]
metadatas = [{"page": i} for i in range(len(texts))]
docsearch.add_texts(texts, metadatas)
time.sleep(WAIT_TIME) # Wait for CreateIndex
output = docsearch.similarity_search("foo", k=10)
assert len(output) == 6
def test_vald_delete() -> None:
texts = ["foo", "bar", "baz"]
metadatas = [{"page": i} for i in range(len(texts))]
docsearch = _vald_from_texts(metadatas=metadatas)
time.sleep(WAIT_TIME)
output = docsearch.similarity_search("foo", k=10)
assert len(output) == 3
docsearch.delete(["foo"])
time.sleep(WAIT_TIME)
output = docsearch.similarity_search("foo", k=10)
assert len(output) == 2
def test_vald_search() -> None:
"""Test end to end construction and search."""
docsearch = _vald_from_texts()
time.sleep(WAIT_TIME)
output = docsearch.similarity_search("foo", k=3)
assert output == [
Document(page_content="foo"),
Document(page_content="bar"),
Document(page_content="baz"),
]
def test_vald_search_with_score() -> None:
"""Test end to end construction and search with scores."""
texts = ["foo", "bar", "baz"]
metadatas = [{"page": i} for i in range(len(texts))]
docsearch = _vald_from_texts(metadatas=metadatas)
time.sleep(WAIT_TIME)
output = docsearch.similarity_search_with_score("foo", k=3)
docs = [o[0] for o in output]
scores = [o[1] for o in output]
assert docs == [
Document(page_content="foo"),
Document(page_content="bar"),
Document(page_content="baz"),
]
assert scores[0] < scores[1] < scores[2]
def test_vald_search_by_vector() -> None:
"""Test end to end construction and search by vector."""
texts = ["foo", "bar", "baz"]
metadatas = [{"page": i} for i in range(len(texts))]
docsearch = _vald_from_texts(metadatas=metadatas)
time.sleep(WAIT_TIME)
embedding = FakeEmbeddings().embed_query("foo")
output = docsearch.similarity_search_by_vector(embedding, k=3)
assert output == [
Document(page_content="foo"),
Document(page_content="bar"),
Document(page_content="baz"),
]
def test_vald_search_with_score_by_vector() -> None:
"""Test end to end construction and search with scores by vector."""
texts = ["foo", "bar", "baz"]
metadatas = [{"page": i} for i in range(len(texts))]
docsearch = _vald_from_texts(metadatas=metadatas)
time.sleep(WAIT_TIME)
embedding = FakeEmbeddings().embed_query("foo")
output = docsearch.similarity_search_with_score_by_vector(embedding, k=3)
docs = [o[0] for o in output]
scores = [o[1] for o in output]
assert docs == [
Document(page_content="foo"),
Document(page_content="bar"),
Document(page_content="baz"),
]
assert scores[0] < scores[1] < scores[2]
def test_vald_max_marginal_relevance_search() -> None:
"""Test end to end construction and MRR search."""
texts = ["foo", "bar", "baz"]
metadatas = [{"page": i} for i in range(len(texts))]
docsearch = _vald_from_texts(metadatas=metadatas)
time.sleep(WAIT_TIME)
output = docsearch.max_marginal_relevance_search("foo", k=2, fetch_k=3)
assert output == [
Document(page_content="foo"),
Document(page_content="bar"),
]
def test_vald_max_marginal_relevance_search_by_vector() -> None:
"""Test end to end construction and MRR search by vector."""
texts = ["foo", "bar", "baz"]
metadatas = [{"page": i} for i in range(len(texts))]
docsearch = _vald_from_texts(metadatas=metadatas)
time.sleep(WAIT_TIME)
embedding = FakeEmbeddings().embed_query("foo")
output = docsearch.max_marginal_relevance_search_by_vector(
embedding, k=2, fetch_k=3
)
assert output == [
Document(page_content="foo"),
Document(page_content="bar"),
]
Loading…
Cancel
Save