parametrize redis (#2946)

This commit is contained in:
Harrison Chase 2023-04-15 12:47:36 -07:00 committed by GitHub
parent 36aa7f30e4
commit baf350e32b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -70,6 +70,9 @@ class Redis(VectorStore):
redis_url: str,
index_name: str,
embedding_function: Callable,
content_key: str = "content",
metadata_key: str = "metadata",
vector_key: str = "content_vector",
**kwargs: Any,
):
"""Initialize with necessary components."""
@ -92,6 +95,9 @@ class Redis(VectorStore):
raise ValueError(f"Redis failed to connect: {e}")
self.client = redis_client
self.content_key = content_key
self.metadata_key = metadata_key
self.vector_key = vector_key
def add_texts(
self,
@ -112,11 +118,11 @@ class Redis(VectorStore):
pipeline.hset(
key,
mapping={
"content": text,
"content_vector": np.array(
self.content_key: text,
self.vector_key: np.array(
self.embedding_function(text), dtype=np.float32
).tobytes(),
"metadata": json.dumps(metadata),
self.metadata_key: json.dumps(metadata),
},
)
ids.append(key)
@ -191,8 +197,8 @@ class Redis(VectorStore):
embedding = self.embedding_function(query)
# Prepare the Query
return_fields = ["metadata", "content", "vector_score"]
vector_field = "content_vector"
return_fields = [self.metadata_key, self.content_key, "vector_score"]
vector_field = self.vector_key
hybrid_fields = "*"
base_query = (
f"{hybrid_fields}=>[KNN {k} @{vector_field} $vector AS vector_score]"
@ -232,6 +238,9 @@ class Redis(VectorStore):
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
index_name: Optional[str] = None,
content_key: str = "content",
metadata_key: str = "metadata",
vector_key: str = "content_vector",
**kwargs: Any,
) -> Redis:
"""Construct RediSearch wrapper from raw documents.
@ -287,10 +296,10 @@ class Redis(VectorStore):
"COSINE" # distance metric for the vectors (ex. COSINE, IP, L2)
)
schema = (
TextField(name="content"),
TextField(name="metadata"),
TextField(name=content_key),
TextField(name=metadata_key),
VectorField(
"content_vector",
vector_key,
"FLAT",
{
"TYPE": "FLOAT32",
@ -313,11 +322,9 @@ class Redis(VectorStore):
pipeline.hset(
key,
mapping={
"content": text,
"content_vector": np.array(
embeddings[i], dtype=np.float32
).tobytes(),
"metadata": json.dumps(metadata),
content_key: text,
vector_key: np.array(embeddings[i], dtype=np.float32).tobytes(),
metadata_key: json.dumps(metadata),
},
)
pipeline.execute()