mirror of
https://github.com/hwchase17/langchain
synced 2024-11-11 19:11:02 +00:00
420 lines
13 KiB
Python
420 lines
13 KiB
Python
|
"""Wrapper around Vald vector database."""
|
||
|
from __future__ import annotations
|
||
|
|
||
|
from typing import Any, Iterable, List, Optional, Tuple, Type
|
||
|
|
||
|
import numpy as np
|
||
|
from langchain_core.documents import Document
|
||
|
from langchain_core.embeddings import Embeddings
|
||
|
from langchain_core.vectorstores import VectorStore
|
||
|
|
||
|
from langchain_community.vectorstores.utils import maximal_marginal_relevance
|
||
|
|
||
|
|
||
|
class Vald(VectorStore):
    """Wrapper around Vald vector database.

    To use, you should have the ``vald-client-python`` python package installed.

    This wrapper sends only ``(id, vector)`` pairs to Vald and uses each text
    itself as its Vald object id. Search results are therefore returned as
    ``Document`` objects whose ``page_content`` is that id, and any
    ``metadatas`` passed in are not persisted.

    Example:
        .. code-block:: python

            from langchain_community.embeddings import HuggingFaceEmbeddings
            from langchain_community.vectorstores import Vald

            texts = ['foo', 'bar', 'baz']
            vald = Vald.from_texts(
                texts=texts,
                embedding=HuggingFaceEmbeddings(),
                host="localhost",
                port=8080,
                skip_strict_exist_check=False,
            )
    """

    def __init__(
        self,
        embedding: Embeddings,
        host: str = "localhost",
        port: int = 8080,
        grpc_options: Tuple = (
            ("grpc.keepalive_time_ms", 1000 * 10),
            ("grpc.keepalive_timeout_ms", 1000 * 10),
        ),
        grpc_use_secure: bool = False,
        grpc_credentials: Optional[Any] = None,
    ):
        """Store connection settings; no gRPC channel is opened here.

        Args:
            embedding: Model used to embed documents and queries.
            host: Vald gRPC host name.
            port: Vald gRPC port.
            grpc_options: Channel options passed to ``grpc.secure_channel`` /
                ``grpc.insecure_channel``.
            grpc_use_secure: If True, open a secure channel using
                ``grpc_credentials``; otherwise open an insecure channel.
            grpc_credentials: Channel credentials used when
                ``grpc_use_secure`` is True.
        """
        self._embedding = embedding
        self.target = host + ":" + str(port)
        self.grpc_options = grpc_options
        self.grpc_use_secure = grpc_use_secure
        self.grpc_credentials = grpc_credentials

    @property
    def embeddings(self) -> Optional[Embeddings]:
        """Return the embedding model given at construction time."""
        return self._embedding

    def _get_channel(self) -> Any:
        """Open a new gRPC channel to ``self.target``.

        A fresh channel is created on every call; callers are responsible for
        closing it.

        Raises:
            ValueError: If ``grpcio`` is not installed.
        """
        try:
            import grpc
        except ImportError as e:
            raise ValueError(
                "Could not import grpcio python package. "
                "Please install it with `pip install grpcio`."
            ) from e

        return (
            grpc.secure_channel(
                self.target, self.grpc_credentials, options=self.grpc_options
            )
            if self.grpc_use_secure
            else grpc.insecure_channel(self.target, options=self.grpc_options)
        )

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        grpc_metadata: Optional[Any] = None,
        skip_strict_exist_check: bool = False,
        **kwargs: Any,
    ) -> List[str]:
        """Embed ``texts`` and upsert them into Vald, one request per text.

        Args:
            texts: Texts to store. Each text is also used as its Vald object
                id, so identical texts upsert the same object.
            metadatas: Ignored; accepted for ``VectorStore`` compatibility.
            grpc_metadata: Optional gRPC call metadata sent with each request.
            skip_strict_exist_check: Deprecated. Forwarded as-is to Vald's
                ``Upsert.Config``.

        Returns:
            The ``uuid`` of each upsert response, in input order.

        Raises:
            ValueError: If ``vald-client-python`` or ``grpcio`` is not installed.
        """
        try:
            from vald.v1.payload import payload_pb2
            from vald.v1.vald import upsert_pb2_grpc
        except ImportError as e:
            raise ValueError(
                "Could not import vald-client-python python package. "
                "Please install it with `pip install vald-client-python`."
            ) from e

        # Materialize once: ``texts`` may be a one-shot iterator, and it is
        # needed twice (for embedding and for pairing each text with its vector).
        text_list = list(texts)

        channel = self._get_channel()
        try:
            # Depending on the network quality,
            # it is necessary to wait for ChannelConnectivity.READY.
            # _ = grpc.channel_ready_future(channel).result(timeout=10)
            stub = upsert_pb2_grpc.UpsertStub(channel)
            cfg = payload_pb2.Upsert.Config(
                skip_strict_exist_check=skip_strict_exist_check
            )

            ids = []
            embs = self._embedding.embed_documents(text_list)
            for text, emb in zip(text_list, embs):
                vec = payload_pb2.Object.Vector(id=text, vector=emb)
                res = stub.Upsert(
                    payload_pb2.Upsert.Request(vector=vec, config=cfg),
                    metadata=grpc_metadata,
                )
                ids.append(res.uuid)
        finally:
            # Close even if embedding or an RPC fails, so the channel never leaks.
            channel.close()

        return ids

    def delete(
        self,
        ids: Optional[List[str]] = None,
        skip_strict_exist_check: bool = False,
        grpc_metadata: Optional[Any] = None,
        **kwargs: Any,
    ) -> Optional[bool]:
        """Remove the objects with the given ids from Vald, one request per id.

        Args:
            ids: Vald object ids to remove (for this wrapper, the stored texts).
            skip_strict_exist_check: Deprecated. Forwarded as-is to Vald's
                ``Remove.Config``.
            grpc_metadata: Optional gRPC call metadata sent with each request.

        Returns:
            True once every remove request has been issued.

        Raises:
            ValueError: If ``ids`` is None, or if ``vald-client-python`` or
                ``grpcio`` is not installed.
        """
        try:
            from vald.v1.payload import payload_pb2
            from vald.v1.vald import remove_pb2_grpc
        except ImportError as e:
            raise ValueError(
                "Could not import vald-client-python python package. "
                "Please install it with `pip install vald-client-python`."
            ) from e

        if ids is None:
            raise ValueError("No ids provided to delete")

        channel = self._get_channel()
        try:
            # Depending on the network quality,
            # it is necessary to wait for ChannelConnectivity.READY.
            # _ = grpc.channel_ready_future(channel).result(timeout=10)
            stub = remove_pb2_grpc.RemoveStub(channel)
            cfg = payload_pb2.Remove.Config(
                skip_strict_exist_check=skip_strict_exist_check
            )

            for _id in ids:
                oid = payload_pb2.Object.ID(id=_id)
                _ = stub.Remove(
                    payload_pb2.Remove.Request(id=oid, config=cfg),
                    metadata=grpc_metadata,
                )
        finally:
            channel.close()

        return True

    def similarity_search(
        self,
        query: str,
        k: int = 4,
        radius: float = -1.0,
        epsilon: float = 0.01,
        timeout: int = 3000000000,
        grpc_metadata: Optional[Any] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Return up to ``k`` documents nearest to ``query``.

        See ``similarity_search_with_score_by_vector`` for the meaning of
        ``radius``, ``epsilon``, ``timeout`` and ``grpc_metadata``.
        """
        docs_and_scores = self.similarity_search_with_score(
            query, k, radius, epsilon, timeout, grpc_metadata
        )
        return [doc for doc, _ in docs_and_scores]

    def similarity_search_with_score(
        self,
        query: str,
        k: int = 4,
        radius: float = -1.0,
        epsilon: float = 0.01,
        timeout: int = 3000000000,
        grpc_metadata: Optional[Any] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Embed ``query`` and return ``(document, distance)`` pairs.

        See ``similarity_search_with_score_by_vector`` for details.
        """
        emb = self._embedding.embed_query(query)
        return self.similarity_search_with_score_by_vector(
            emb, k, radius, epsilon, timeout, grpc_metadata
        )

    def similarity_search_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        radius: float = -1.0,
        epsilon: float = 0.01,
        timeout: int = 3000000000,
        grpc_metadata: Optional[Any] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Return up to ``k`` documents nearest to ``embedding``.

        See ``similarity_search_with_score_by_vector`` for details.
        """
        docs_and_scores = self.similarity_search_with_score_by_vector(
            embedding, k, radius, epsilon, timeout, grpc_metadata
        )
        return [doc for doc, _ in docs_and_scores]

    def similarity_search_with_score_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        radius: float = -1.0,
        epsilon: float = 0.01,
        timeout: int = 3000000000,
        grpc_metadata: Optional[Any] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Run a Vald ``Search`` for ``embedding``.

        Args:
            embedding: Query vector.
            k: Number of results requested (``Search.Config.num``).
            radius: Forwarded to ``Search.Config.radius``. The default of -1.0
                presumably means "no radius limit"; confirm against the Vald docs.
            epsilon: Forwarded to ``Search.Config.epsilon``.
            timeout: Forwarded to ``Search.Config.timeout``. The unit is
                presumably nanoseconds (the default 3e9 would then be 3 s);
                confirm against the Vald docs.
            grpc_metadata: Optional gRPC call metadata.

        Returns:
            One ``(Document(page_content=<object id>), distance)`` pair per
            result, in the order Vald returned them.

        Raises:
            ValueError: If ``vald-client-python`` or ``grpcio`` is not installed.
        """
        try:
            from vald.v1.payload import payload_pb2
            from vald.v1.vald import search_pb2_grpc
        except ImportError as e:
            raise ValueError(
                "Could not import vald-client-python python package. "
                "Please install it with `pip install vald-client-python`."
            ) from e

        channel = self._get_channel()
        try:
            # Depending on the network quality,
            # it is necessary to wait for ChannelConnectivity.READY.
            # _ = grpc.channel_ready_future(channel).result(timeout=10)
            stub = search_pb2_grpc.SearchStub(channel)
            cfg = payload_pb2.Search.Config(
                num=k, radius=radius, epsilon=epsilon, timeout=timeout
            )

            res = stub.Search(
                payload_pb2.Search.Request(vector=embedding, config=cfg),
                metadata=grpc_metadata,
            )

            docs_and_scores = [
                (Document(page_content=result.id), result.distance)
                for result in res.results
            ]
        finally:
            channel.close()

        return docs_and_scores

    def max_marginal_relevance_search(
        self,
        query: str,
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        radius: float = -1.0,
        epsilon: float = 0.01,
        timeout: int = 3000000000,
        grpc_metadata: Optional[Any] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Embed ``query`` and return ``k`` documents selected by MMR.

        See ``max_marginal_relevance_search_by_vector`` for details.
        """
        emb = self._embedding.embed_query(query)
        return self.max_marginal_relevance_search_by_vector(
            emb,
            k=k,
            fetch_k=fetch_k,
            radius=radius,
            epsilon=epsilon,
            timeout=timeout,
            lambda_mult=lambda_mult,
            grpc_metadata=grpc_metadata,
        )

    def max_marginal_relevance_search_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        radius: float = -1.0,
        epsilon: float = 0.01,
        timeout: int = 3000000000,
        grpc_metadata: Optional[Any] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Fetch ``fetch_k`` candidates, then pick ``k`` of them via MMR.

        Each candidate's stored vector is retrieved with a Vald ``GetObject``
        call so that ``maximal_marginal_relevance`` can compare candidates.

        Args:
            embedding: Query vector.
            k: Number of documents to return.
            fetch_k: Number of nearest candidates to fetch before reranking.
            lambda_mult: MMR trade-off passed to ``maximal_marginal_relevance``.
            radius, epsilon, timeout: Forwarded to the candidate search.
            grpc_metadata: Optional gRPC call metadata.

        Raises:
            ValueError: If ``vald-client-python`` or ``grpcio`` is not installed.
        """
        try:
            from vald.v1.payload import payload_pb2
            from vald.v1.vald import object_pb2_grpc
        except ImportError as e:
            raise ValueError(
                "Could not import vald-client-python python package. "
                "Please install it with `pip install vald-client-python`."
            ) from e

        # ``fetch_k`` must be passed as ``k``: the search method has no
        # ``fetch_k`` parameter, so passing it by that name would be swallowed
        # by **kwargs and only the default 4 candidates would be fetched.
        docs_and_scores = self.similarity_search_with_score_by_vector(
            embedding,
            k=fetch_k,
            radius=radius,
            epsilon=epsilon,
            timeout=timeout,
            grpc_metadata=grpc_metadata,
        )

        channel = self._get_channel()
        try:
            # Depending on the network quality,
            # it is necessary to wait for ChannelConnectivity.READY.
            # _ = grpc.channel_ready_future(channel).result(timeout=10)
            stub = object_pb2_grpc.ObjectStub(channel)

            docs = []
            embs = []
            for doc, _ in docs_and_scores:
                vec = stub.GetObject(
                    payload_pb2.Object.VectorRequest(
                        id=payload_pb2.Object.ID(id=doc.page_content)
                    ),
                    metadata=grpc_metadata,
                )
                embs.append(list(vec.vector))
                docs.append(doc)
        finally:
            channel.close()

        mmr = maximal_marginal_relevance(
            np.array(embedding),
            embs,
            lambda_mult=lambda_mult,
            k=k,
        )
        return [docs[i] for i in mmr]

    @classmethod
    def from_texts(
        cls: Type[Vald],
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        host: str = "localhost",
        port: int = 8080,
        grpc_options: Tuple = (
            ("grpc.keepalive_time_ms", 1000 * 10),
            ("grpc.keepalive_timeout_ms", 1000 * 10),
        ),
        grpc_use_secure: bool = False,
        grpc_credentials: Optional[Any] = None,
        grpc_metadata: Optional[Any] = None,
        skip_strict_exist_check: bool = False,
        **kwargs: Any,
    ) -> Vald:
        """Construct a ``Vald`` store and upsert ``texts`` into it.

        Args:
            texts: Texts to store (each is also used as its Vald object id).
            embedding: Model used to embed documents and queries.
            metadatas: Ignored by ``add_texts``; accepted for compatibility.
            host, port, grpc_options, grpc_use_secure, grpc_credentials:
                Connection settings forwarded to the constructor.
            grpc_metadata: Optional gRPC call metadata for the upserts.
            skip_strict_exist_check: Deprecated. Forwarded to ``add_texts``.
            **kwargs: Forwarded to the constructor.
        """
        vald = cls(
            embedding=embedding,
            host=host,
            port=port,
            grpc_options=grpc_options,
            grpc_use_secure=grpc_use_secure,
            grpc_credentials=grpc_credentials,
            **kwargs,
        )
        vald.add_texts(
            texts=texts,
            metadatas=metadatas,
            grpc_metadata=grpc_metadata,
            skip_strict_exist_check=skip_strict_exist_check,
        )
        return vald

    # TODO: not implemented yet; to be supported if there are requests:
    # aadd_texts, _select_relevance_score_fn,
    # _similarity_search_with_relevance_scores,
    # similarity_search_with_relevance_scores,
    # amax_marginal_relevance_search_by_vector, afrom_texts.
|