parametrize redis (#2946)

This commit is contained in:
Harrison Chase 2023-04-15 12:47:36 -07:00 committed by GitHub
parent 36aa7f30e4
commit baf350e32b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -70,6 +70,9 @@ class Redis(VectorStore):
redis_url: str, redis_url: str,
index_name: str, index_name: str,
embedding_function: Callable, embedding_function: Callable,
content_key: str = "content",
metadata_key: str = "metadata",
vector_key: str = "content_vector",
**kwargs: Any, **kwargs: Any,
): ):
"""Initialize with necessary components.""" """Initialize with necessary components."""
@ -92,6 +95,9 @@ class Redis(VectorStore):
raise ValueError(f"Redis failed to connect: {e}") raise ValueError(f"Redis failed to connect: {e}")
self.client = redis_client self.client = redis_client
self.content_key = content_key
self.metadata_key = metadata_key
self.vector_key = vector_key
def add_texts( def add_texts(
self, self,
@ -112,11 +118,11 @@ class Redis(VectorStore):
pipeline.hset( pipeline.hset(
key, key,
mapping={ mapping={
"content": text, self.content_key: text,
"content_vector": np.array( self.vector_key: np.array(
self.embedding_function(text), dtype=np.float32 self.embedding_function(text), dtype=np.float32
).tobytes(), ).tobytes(),
"metadata": json.dumps(metadata), self.metadata_key: json.dumps(metadata),
}, },
) )
ids.append(key) ids.append(key)
@ -191,8 +197,8 @@ class Redis(VectorStore):
embedding = self.embedding_function(query) embedding = self.embedding_function(query)
# Prepare the Query # Prepare the Query
return_fields = ["metadata", "content", "vector_score"] return_fields = [self.metadata_key, self.content_key, "vector_score"]
vector_field = "content_vector" vector_field = self.vector_key
hybrid_fields = "*" hybrid_fields = "*"
base_query = ( base_query = (
f"{hybrid_fields}=>[KNN {k} @{vector_field} $vector AS vector_score]" f"{hybrid_fields}=>[KNN {k} @{vector_field} $vector AS vector_score]"
@ -232,6 +238,9 @@ class Redis(VectorStore):
embedding: Embeddings, embedding: Embeddings,
metadatas: Optional[List[dict]] = None, metadatas: Optional[List[dict]] = None,
index_name: Optional[str] = None, index_name: Optional[str] = None,
content_key: str = "content",
metadata_key: str = "metadata",
vector_key: str = "content_vector",
**kwargs: Any, **kwargs: Any,
) -> Redis: ) -> Redis:
"""Construct RediSearch wrapper from raw documents. """Construct RediSearch wrapper from raw documents.
@ -287,10 +296,10 @@ class Redis(VectorStore):
"COSINE" # distance metric for the vectors (ex. COSINE, IP, L2) "COSINE" # distance metric for the vectors (ex. COSINE, IP, L2)
) )
schema = ( schema = (
TextField(name="content"), TextField(name=content_key),
TextField(name="metadata"), TextField(name=metadata_key),
VectorField( VectorField(
"content_vector", vector_key,
"FLAT", "FLAT",
{ {
"TYPE": "FLOAT32", "TYPE": "FLOAT32",
@ -313,11 +322,9 @@ class Redis(VectorStore):
pipeline.hset( pipeline.hset(
key, key,
mapping={ mapping={
"content": text, content_key: text,
"content_vector": np.array( vector_key: np.array(embeddings[i], dtype=np.float32).tobytes(),
embeddings[i], dtype=np.float32 metadata_key: json.dumps(metadata),
).tobytes(),
"metadata": json.dumps(metadata),
}, },
) )
pipeline.execute() pipeline.execute()