mirror of https://github.com/hwchase17/langchain
Harrison/tair (#3770)
Co-authored-by: Seth Huang <848849+seth-hg@users.noreply.github.com>pull/3773/head
parent
502ba6a0be
commit
0c0f14407c
# Tair

This page covers how to use the Tair ecosystem within LangChain.

## Installation and Setup

Install Tair Python SDK with `pip install tair`.

## Wrappers

### VectorStore

There exists a wrapper around TairVector, allowing you to use it as a vectorstore,
whether for semantic search or example selection.

To import this vectorstore:

```python
from langchain.vectorstores import Tair
```

For a more detailed walkthrough of the Tair wrapper, see [this notebook](../modules/indexes/vectorstores/examples/tair.ipynb)
@ -0,0 +1,286 @@
|
||||
"""Wrapper around Tair Vector."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any, Iterable, List, Optional, Type
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _uuid_key() -> str:
|
||||
return uuid.uuid4().hex
|
||||
|
||||
|
||||
class Tair(VectorStore):
    """Wrapper around Tair Vector, exposing it as a LangChain vectorstore.

    Each entry is stored as a TairVector hash that keeps the raw text under
    ``content_key`` and its JSON-encoded metadata under ``metadata_key``.
    """

    def __init__(
        self,
        embedding_function: Embeddings,
        url: str,
        index_name: str,
        content_key: str = "content",
        metadata_key: str = "metadata",
        search_params: Optional[dict] = None,
        **kwargs: Any,
    ):
        """Initialize a connection to a Tair server.

        Args:
            embedding_function: Embeddings used for documents and queries.
            url: Tair connection URL, e.g. ``redis://localhost:6379``.
            index_name: Name of the TairVector index to use.
            content_key: Hash field that stores the document text.
            metadata_key: Hash field that stores JSON-encoded metadata.
            search_params: Optional parameters kept for similarity search.
            **kwargs: Extra arguments forwarded to ``TairClient.from_url``.

        Raises:
            ValueError: If the ``tair`` package is not installed or the
                connection cannot be established.
        """
        self.embedding_function = embedding_function
        self.index_name = index_name
        try:
            from tair import Tair as TairClient
        except ImportError:
            raise ValueError(
                "Could not import tair python package. "
                "Please install it with `pip install tair`."
            )
        try:
            # connect to tair from url
            client = TairClient.from_url(url, **kwargs)
        except ValueError as e:
            raise ValueError(f"Tair failed to connect: {e}")

        self.client = client
        self.content_key = content_key
        self.metadata_key = metadata_key
        self.search_params = search_params

    def create_index_if_not_exist(
        self,
        dim: int,
        distance_type: str,
        index_type: str,
        data_type: str,
        **kwargs: Any,
    ) -> bool:
        """Create the TairVector index if it does not already exist.

        Args:
            dim: Dimensionality of the stored vectors.
            distance_type: Distance metric identifier.
            index_type: Index algorithm identifier.
            data_type: Vector element type identifier.
            **kwargs: Extra index parameters forwarded to Tair.

        Returns:
            True if a new index was created, False if one already existed.
        """
        index = self.client.tvs_get_index(self.index_name)
        if index is not None:
            logger.info("Index already exists")
            return False
        self.client.tvs_create_index(
            self.index_name,
            dim,
            distance_type,
            index_type,
            data_type,
            **kwargs,
        )
        return True

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Add texts data to an existing index.

        Args:
            texts: Texts to embed and store.
            metadatas: Optional per-text metadata dicts (JSON-serialized).
            **kwargs: ``keys`` may supply explicit storage keys; otherwise a
                random hex key is generated per text.

        Returns:
            The storage keys of the added texts.
        """
        ids = []
        keys = kwargs.get("keys")
        # Materialize once: the original embedded list(texts) and then
        # iterated `texts` a second time, which silently stores nothing for
        # a one-shot iterable (e.g. a generator).
        texts = list(texts)
        # Batch all writes through a non-transactional pipeline.
        pipeline = self.client.pipeline(transaction=False)
        embeddings = self.embedding_function.embed_documents(texts)
        for i, text in enumerate(texts):
            # Use provided key otherwise use default key
            key = keys[i] if keys else _uuid_key()
            metadata = metadatas[i] if metadatas else {}
            pipeline.tvs_hset(
                self.index_name,
                key,
                embeddings[i],
                False,
                **{
                    self.content_key: text,
                    self.metadata_key: json.dumps(metadata),
                },
            )
            ids.append(key)
        pipeline.execute()
        return ids

    def similarity_search(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> List[Document]:
        """
        Returns the most similar indexed documents to the query text.

        Args:
            query (str): The query text for which to find similar documents.
            k (int): The number of documents to return. Default is 4.

        Returns:
            List[Document]: A list of documents that are most similar to the query text.
        """
        # Creates embedding vector from user query
        embedding = self.embedding_function.embed_query(query)

        # NOTE(review): self.search_params is stored in __init__ but never
        # forwarded here — confirm whether tvs_knnsearch should receive it.
        keys_and_scores = self.client.tvs_knnsearch(
            self.index_name, k, embedding, False, None, **kwargs
        )

        # Fetch metadata + content for all hits in one pipelined round trip.
        pipeline = self.client.pipeline(transaction=False)
        for key, _ in keys_and_scores:
            pipeline.tvs_hmget(
                self.index_name, key, self.metadata_key, self.content_key
            )
        docs = pipeline.execute()

        # Each row is [metadata_json, content] in the hmget field order above.
        return [
            Document(
                page_content=d[1],
                metadata=json.loads(d[0]),
            )
            for d in docs
        ]

    @classmethod
    def from_texts(
        cls: Type[Tair],
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        index_name: str = "langchain",
        content_key: str = "content",
        metadata_key: str = "metadata",
        **kwargs: Any,
    ) -> Tair:
        """Create a Tair vectorstore from raw texts.

        Connects using ``tair_url`` from kwargs or the ``TAIR_URL``
        environment variable, creates the index if needed, then embeds and
        stores the texts.

        Raises:
            ValueError: If the ``tair`` package is missing or the connection
                fails.
        """
        try:
            from tair import tairvector
        except ImportError:
            raise ValueError(
                "Could not import tair python package. "
                "Please install it with `pip install tair`."
            )
        url = get_from_dict_or_env(kwargs, "tair_url", "TAIR_URL")
        # Remove so it is not forwarded to TairClient.from_url.
        kwargs.pop("tair_url", None)

        # Index configuration with defaults, overridable via kwargs.
        # BUG FIX: the original popped the misspelled key "distance_typ",
        # so a caller-supplied "distance_type" raised KeyError and was
        # never honored.
        distance_type = kwargs.pop(
            "distance_type", tairvector.DistanceMetric.InnerProduct
        )
        index_type = kwargs.pop("index_type", tairvector.IndexType.HNSW)
        data_type = kwargs.pop("data_type", tairvector.DataType.Float32)
        index_params = kwargs.pop("index_params", {})
        search_params = kwargs.pop("search_params", {})

        keys = kwargs.pop("keys", None)
        try:
            tair_vector_store = cls(
                embedding,
                url,
                index_name,
                content_key=content_key,
                metadata_key=metadata_key,
                search_params=search_params,
                **kwargs,
            )
        except ValueError as e:
            raise ValueError(f"Tair failed to connect: {e}")

        # Create embeddings for documents
        # NOTE(review): add_texts re-embeds the same texts below; this first
        # pass is only used to discover the vector dimension.
        embeddings = embedding.embed_documents(texts)

        tair_vector_store.create_index_if_not_exist(
            len(embeddings[0]),
            distance_type,
            index_type,
            data_type,
            **index_params,
        )

        tair_vector_store.add_texts(texts, metadatas, keys=keys)
        return tair_vector_store

    @classmethod
    def from_documents(
        cls,
        documents: List[Document],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        index_name: str = "langchain",
        content_key: str = "content",
        metadata_key: str = "metadata",
        **kwargs: Any,
    ) -> Tair:
        """Create a Tair vectorstore from Documents.

        Note: the ``metadatas`` parameter is ignored — metadata is always
        taken from the documents themselves (kept for signature
        compatibility).
        """
        texts = [d.page_content for d in documents]
        metadatas = [d.metadata for d in documents]

        return cls.from_texts(
            texts, embedding, metadatas, index_name, content_key, metadata_key, **kwargs
        )

    @staticmethod
    def drop_index(
        index_name: str = "langchain",
        **kwargs: Any,
    ) -> bool:
        """
        Drop an existing index.

        Args:
            index_name (str): Name of the index to drop.

        Returns:
            bool: True if the index is dropped successfully.
        """
        try:
            from tair import Tair as TairClient
        except ImportError:
            raise ValueError(
                "Could not import tair python package. "
                "Please install it with `pip install tair`."
            )
        url = get_from_dict_or_env(kwargs, "tair_url", "TAIR_URL")

        try:
            # Remove so it is not forwarded to TairClient.from_url.
            kwargs.pop("tair_url", None)
            client = TairClient.from_url(url=url, **kwargs)
        except ValueError as e:
            raise ValueError(f"Tair connection error: {e}")
        # delete index
        ret = client.tvs_del_index(index_name)
        if ret == 0:
            # index not exist
            logger.info("Index does not exist")
            return False
        return True

    @classmethod
    def from_existing_index(
        cls,
        embedding: Embeddings,
        index_name: str = "langchain",
        content_key: str = "content",
        metadata_key: str = "metadata",
        **kwargs: Any,
    ) -> Tair:
        """Connect to an existing Tair index."""
        url = get_from_dict_or_env(kwargs, "tair_url", "TAIR_URL")
        # BUG FIX: the original left "tair_url" in kwargs, so it leaked into
        # TairClient.from_url via __init__ (unlike from_texts / drop_index).
        kwargs.pop("tair_url", None)

        search_params = kwargs.pop("search_params", {})

        return cls(
            embedding,
            url,
            index_name,
            content_key=content_key,
            metadata_key=metadata_key,
            search_params=search_params,
            **kwargs,
        )
|
@ -0,0 +1,15 @@
|
||||
"""Test tair functionality."""
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.vectorstores.tair import Tair
|
||||
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
|
||||
|
||||
|
||||
def test_tair() -> None:
    """Test end to end construction and search."""
    store = Tair.from_texts(
        ["foo", "bar", "baz"], FakeEmbeddings(), tair_url="redis://localhost:6379"
    )
    results = store.similarity_search("foo", k=1)
    assert results == [Document(page_content="foo")]
|
Loading…
Reference in New Issue