From baf350e32bdc85f61efde736eeb61c243b8437b2 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Sat, 15 Apr 2023 12:47:36 -0700 Subject: [PATCH] parametrize redis (#2946) --- langchain/vectorstores/redis.py | 33 ++++++++++++++++++++------------- 1 file changed, 20 insertions(+), 13 deletions(-) diff --git a/langchain/vectorstores/redis.py b/langchain/vectorstores/redis.py index 02a5d19d1a..30b926518b 100644 --- a/langchain/vectorstores/redis.py +++ b/langchain/vectorstores/redis.py @@ -70,6 +70,9 @@ class Redis(VectorStore): redis_url: str, index_name: str, embedding_function: Callable, + content_key: str = "content", + metadata_key: str = "metadata", + vector_key: str = "content_vector", **kwargs: Any, ): """Initialize with necessary components.""" @@ -92,6 +95,9 @@ class Redis(VectorStore): raise ValueError(f"Redis failed to connect: {e}") self.client = redis_client + self.content_key = content_key + self.metadata_key = metadata_key + self.vector_key = vector_key def add_texts( self, @@ -112,11 +118,11 @@ class Redis(VectorStore): pipeline.hset( key, mapping={ - "content": text, - "content_vector": np.array( + self.content_key: text, + self.vector_key: np.array( self.embedding_function(text), dtype=np.float32 ).tobytes(), - "metadata": json.dumps(metadata), + self.metadata_key: json.dumps(metadata), }, ) ids.append(key) @@ -191,8 +197,8 @@ class Redis(VectorStore): embedding = self.embedding_function(query) # Prepare the Query - return_fields = ["metadata", "content", "vector_score"] - vector_field = "content_vector" + return_fields = [self.metadata_key, self.content_key, "vector_score"] + vector_field = self.vector_key hybrid_fields = "*" base_query = ( f"{hybrid_fields}=>[KNN {k} @{vector_field} $vector AS vector_score]" @@ -232,6 +238,9 @@ class Redis(VectorStore): embedding: Embeddings, metadatas: Optional[List[dict]] = None, index_name: Optional[str] = None, + content_key: str = "content", + metadata_key: str = "metadata", + vector_key: str = "content_vector", **kwargs: Any, ) -> Redis: """Construct RediSearch wrapper from raw documents. @@ -287,10 +296,10 @@ class Redis(VectorStore): "COSINE" # distance metric for the vectors (ex. COSINE, IP, L2) ) schema = ( - TextField(name="content"), - TextField(name="metadata"), + TextField(name=content_key), + TextField(name=metadata_key), VectorField( - "content_vector", + vector_key, "FLAT", { "TYPE": "FLOAT32", @@ -313,11 +322,9 @@ class Redis(VectorStore): pipeline.hset( key, mapping={ - "content": text, - "content_vector": np.array( - embeddings[i], dtype=np.float32 - ).tobytes(), - "metadata": json.dumps(metadata), + content_key: text, + vector_key: np.array(embeddings[i], dtype=np.float32).tobytes(), + metadata_key: json.dumps(metadata), }, ) pipeline.execute()