mirror of
https://github.com/hwchase17/langchain
synced 2024-11-04 06:00:26 +00:00
310 lines
9.3 KiB
Python
310 lines
9.3 KiB
Python
|
from __future__ import annotations
|
||
|
|
||
|
import json
|
||
|
import logging
|
||
|
import uuid
|
||
|
from typing import Any, Iterable, List, Optional, Type
|
||
|
|
||
|
from langchain_core.documents import Document
|
||
|
from langchain_core.embeddings import Embeddings
|
||
|
from langchain_core.utils import get_from_dict_or_env
|
||
|
from langchain_core.vectorstores import VectorStore
|
||
|
|
||
|
logger = logging.getLogger(__name__)
|
||
|
|
||
|
|
||
|
def _uuid_key() -> str:
|
||
|
return uuid.uuid4().hex
|
||
|
|
||
|
|
||
|
class Tair(VectorStore):
|
||
|
"""`Tair` vector store."""
|
||
|
|
||
|
def __init__(
|
||
|
self,
|
||
|
embedding_function: Embeddings,
|
||
|
url: str,
|
||
|
index_name: str,
|
||
|
content_key: str = "content",
|
||
|
metadata_key: str = "metadata",
|
||
|
search_params: Optional[dict] = None,
|
||
|
**kwargs: Any,
|
||
|
):
|
||
|
self.embedding_function = embedding_function
|
||
|
self.index_name = index_name
|
||
|
try:
|
||
|
from tair import Tair as TairClient
|
||
|
except ImportError:
|
||
|
raise ImportError(
|
||
|
"Could not import tair python package. "
|
||
|
"Please install it with `pip install tair`."
|
||
|
)
|
||
|
try:
|
||
|
# connect to tair from url
|
||
|
client = TairClient.from_url(url, **kwargs)
|
||
|
except ValueError as e:
|
||
|
raise ValueError(f"Tair failed to connect: {e}")
|
||
|
|
||
|
self.client = client
|
||
|
self.content_key = content_key
|
||
|
self.metadata_key = metadata_key
|
||
|
self.search_params = search_params
|
||
|
|
||
|
@property
|
||
|
def embeddings(self) -> Embeddings:
|
||
|
return self.embedding_function
|
||
|
|
||
|
def create_index_if_not_exist(
|
||
|
self,
|
||
|
dim: int,
|
||
|
distance_type: str,
|
||
|
index_type: str,
|
||
|
data_type: str,
|
||
|
**kwargs: Any,
|
||
|
) -> bool:
|
||
|
index = self.client.tvs_get_index(self.index_name)
|
||
|
if index is not None:
|
||
|
logger.info("Index already exists")
|
||
|
return False
|
||
|
self.client.tvs_create_index(
|
||
|
self.index_name,
|
||
|
dim,
|
||
|
distance_type,
|
||
|
index_type,
|
||
|
data_type,
|
||
|
**kwargs,
|
||
|
)
|
||
|
return True
|
||
|
|
||
|
def add_texts(
|
||
|
self,
|
||
|
texts: Iterable[str],
|
||
|
metadatas: Optional[List[dict]] = None,
|
||
|
**kwargs: Any,
|
||
|
) -> List[str]:
|
||
|
"""Add texts data to an existing index."""
|
||
|
ids = []
|
||
|
keys = kwargs.get("keys", None)
|
||
|
use_hybrid_search = False
|
||
|
index = self.client.tvs_get_index(self.index_name)
|
||
|
if index is not None and index.get("lexical_algorithm") == "bm25":
|
||
|
use_hybrid_search = True
|
||
|
# Write data to tair
|
||
|
pipeline = self.client.pipeline(transaction=False)
|
||
|
embeddings = self.embedding_function.embed_documents(list(texts))
|
||
|
for i, text in enumerate(texts):
|
||
|
# Use provided key otherwise use default key
|
||
|
key = keys[i] if keys else _uuid_key()
|
||
|
metadata = metadatas[i] if metadatas else {}
|
||
|
if use_hybrid_search:
|
||
|
# tair use TEXT attr hybrid search
|
||
|
pipeline.tvs_hset(
|
||
|
self.index_name,
|
||
|
key,
|
||
|
embeddings[i],
|
||
|
False,
|
||
|
**{
|
||
|
"TEXT": text,
|
||
|
self.content_key: text,
|
||
|
self.metadata_key: json.dumps(metadata),
|
||
|
},
|
||
|
)
|
||
|
else:
|
||
|
pipeline.tvs_hset(
|
||
|
self.index_name,
|
||
|
key,
|
||
|
embeddings[i],
|
||
|
False,
|
||
|
**{
|
||
|
self.content_key: text,
|
||
|
self.metadata_key: json.dumps(metadata),
|
||
|
},
|
||
|
)
|
||
|
ids.append(key)
|
||
|
pipeline.execute()
|
||
|
return ids
|
||
|
|
||
|
def similarity_search(
|
||
|
self, query: str, k: int = 4, **kwargs: Any
|
||
|
) -> List[Document]:
|
||
|
"""
|
||
|
Returns the most similar indexed documents to the query text.
|
||
|
|
||
|
Args:
|
||
|
query (str): The query text for which to find similar documents.
|
||
|
k (int): The number of documents to return. Default is 4.
|
||
|
|
||
|
Returns:
|
||
|
List[Document]: A list of documents that are most similar to the query text.
|
||
|
"""
|
||
|
# Creates embedding vector from user query
|
||
|
embedding = self.embedding_function.embed_query(query)
|
||
|
|
||
|
keys_and_scores = self.client.tvs_knnsearch(
|
||
|
self.index_name, k, embedding, False, None, **kwargs
|
||
|
)
|
||
|
|
||
|
pipeline = self.client.pipeline(transaction=False)
|
||
|
for key, _ in keys_and_scores:
|
||
|
pipeline.tvs_hmget(
|
||
|
self.index_name, key, self.metadata_key, self.content_key
|
||
|
)
|
||
|
docs = pipeline.execute()
|
||
|
|
||
|
return [
|
||
|
Document(
|
||
|
page_content=d[1],
|
||
|
metadata=json.loads(d[0]),
|
||
|
)
|
||
|
for d in docs
|
||
|
]
|
||
|
|
||
|
@classmethod
|
||
|
def from_texts(
|
||
|
cls: Type[Tair],
|
||
|
texts: List[str],
|
||
|
embedding: Embeddings,
|
||
|
metadatas: Optional[List[dict]] = None,
|
||
|
index_name: str = "langchain",
|
||
|
content_key: str = "content",
|
||
|
metadata_key: str = "metadata",
|
||
|
**kwargs: Any,
|
||
|
) -> Tair:
|
||
|
try:
|
||
|
from tair import tairvector
|
||
|
except ImportError:
|
||
|
raise ValueError(
|
||
|
"Could not import tair python package. "
|
||
|
"Please install it with `pip install tair`."
|
||
|
)
|
||
|
url = get_from_dict_or_env(kwargs, "tair_url", "TAIR_URL")
|
||
|
if "tair_url" in kwargs:
|
||
|
kwargs.pop("tair_url")
|
||
|
|
||
|
distance_type = tairvector.DistanceMetric.InnerProduct
|
||
|
if "distance_type" in kwargs:
|
||
|
distance_type = kwargs.pop("distance_type")
|
||
|
index_type = tairvector.IndexType.HNSW
|
||
|
if "index_type" in kwargs:
|
||
|
index_type = kwargs.pop("index_type")
|
||
|
data_type = tairvector.DataType.Float32
|
||
|
if "data_type" in kwargs:
|
||
|
data_type = kwargs.pop("data_type")
|
||
|
index_params = {}
|
||
|
if "index_params" in kwargs:
|
||
|
index_params = kwargs.pop("index_params")
|
||
|
search_params = {}
|
||
|
if "search_params" in kwargs:
|
||
|
search_params = kwargs.pop("search_params")
|
||
|
|
||
|
keys = None
|
||
|
if "keys" in kwargs:
|
||
|
keys = kwargs.pop("keys")
|
||
|
try:
|
||
|
tair_vector_store = cls(
|
||
|
embedding,
|
||
|
url,
|
||
|
index_name,
|
||
|
content_key=content_key,
|
||
|
metadata_key=metadata_key,
|
||
|
search_params=search_params,
|
||
|
**kwargs,
|
||
|
)
|
||
|
except ValueError as e:
|
||
|
raise ValueError(f"tair failed to connect: {e}")
|
||
|
|
||
|
# Create embeddings for documents
|
||
|
embeddings = embedding.embed_documents(texts)
|
||
|
|
||
|
tair_vector_store.create_index_if_not_exist(
|
||
|
len(embeddings[0]),
|
||
|
distance_type,
|
||
|
index_type,
|
||
|
data_type,
|
||
|
**index_params,
|
||
|
)
|
||
|
|
||
|
tair_vector_store.add_texts(texts, metadatas, keys=keys)
|
||
|
return tair_vector_store
|
||
|
|
||
|
@classmethod
|
||
|
def from_documents(
|
||
|
cls,
|
||
|
documents: List[Document],
|
||
|
embedding: Embeddings,
|
||
|
metadatas: Optional[List[dict]] = None,
|
||
|
index_name: str = "langchain",
|
||
|
content_key: str = "content",
|
||
|
metadata_key: str = "metadata",
|
||
|
**kwargs: Any,
|
||
|
) -> Tair:
|
||
|
texts = [d.page_content for d in documents]
|
||
|
metadatas = [d.metadata for d in documents]
|
||
|
|
||
|
return cls.from_texts(
|
||
|
texts, embedding, metadatas, index_name, content_key, metadata_key, **kwargs
|
||
|
)
|
||
|
|
||
|
@staticmethod
|
||
|
def drop_index(
|
||
|
index_name: str = "langchain",
|
||
|
**kwargs: Any,
|
||
|
) -> bool:
|
||
|
"""
|
||
|
Drop an existing index.
|
||
|
|
||
|
Args:
|
||
|
index_name (str): Name of the index to drop.
|
||
|
|
||
|
Returns:
|
||
|
bool: True if the index is dropped successfully.
|
||
|
"""
|
||
|
try:
|
||
|
from tair import Tair as TairClient
|
||
|
except ImportError:
|
||
|
raise ValueError(
|
||
|
"Could not import tair python package. "
|
||
|
"Please install it with `pip install tair`."
|
||
|
)
|
||
|
url = get_from_dict_or_env(kwargs, "tair_url", "TAIR_URL")
|
||
|
|
||
|
try:
|
||
|
if "tair_url" in kwargs:
|
||
|
kwargs.pop("tair_url")
|
||
|
client = TairClient.from_url(url=url, **kwargs)
|
||
|
except ValueError as e:
|
||
|
raise ValueError(f"Tair connection error: {e}")
|
||
|
# delete index
|
||
|
ret = client.tvs_del_index(index_name)
|
||
|
if ret == 0:
|
||
|
# index not exist
|
||
|
logger.info("Index does not exist")
|
||
|
return False
|
||
|
return True
|
||
|
|
||
|
@classmethod
|
||
|
def from_existing_index(
|
||
|
cls,
|
||
|
embedding: Embeddings,
|
||
|
index_name: str = "langchain",
|
||
|
content_key: str = "content",
|
||
|
metadata_key: str = "metadata",
|
||
|
**kwargs: Any,
|
||
|
) -> Tair:
|
||
|
"""Connect to an existing Tair index."""
|
||
|
url = get_from_dict_or_env(kwargs, "tair_url", "TAIR_URL")
|
||
|
|
||
|
search_params = {}
|
||
|
if "search_params" in kwargs:
|
||
|
search_params = kwargs.pop("search_params")
|
||
|
|
||
|
return cls(
|
||
|
embedding,
|
||
|
url,
|
||
|
index_name,
|
||
|
content_key=content_key,
|
||
|
metadata_key=metadata_key,
|
||
|
search_params=search_params,
|
||
|
**kwargs,
|
||
|
)
|