Harrison/redis improvements (#2528)

Co-authored-by: Tyler Hutcherson <tyler.hutcherson@redis.com>

Parent: ec489599fd
Commit: a31c9511e8
Redis vectorstore example notebook:

@@ -1,32 +1,34 @@
 {
 "cells": [
 {
+"attachments": {},
 "cell_type": "markdown",
 "metadata": {},
 "source": [
 "# Redis\n",
 "\n",
-"This notebook shows how to use functionality related to the Redis database."
+"This notebook shows how to use functionality related to the [Redis vector database](https://redis.com/solutions/use-cases/vector-database/)."
 ]
 },
 {
 "cell_type": "code",
-"execution_count": 1,
+"execution_count": 3,
 "metadata": {},
 "outputs": [],
 "source": [
-"from langchain.embeddings.openai import OpenAIEmbeddings\n",
+"from langchain.embeddings import OpenAIEmbeddings\n",
 "from langchain.text_splitter import CharacterTextSplitter\n",
 "from langchain.vectorstores.redis import Redis"
 ]
 },
 {
 "cell_type": "code",
-"execution_count": 2,
+"execution_count": 4,
 "metadata": {},
 "outputs": [],
 "source": [
 "from langchain.document_loaders import TextLoader\n",
+"\n",
 "loader = TextLoader('../../../state_of_the_union.txt')\n",
 "documents = loader.load()\n",
 "text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n",
@@ -37,7 +39,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 3,
+"execution_count": 5,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -46,7 +48,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 4,
+"execution_count": 6,
 "metadata": {},
 "outputs": [
 {
@@ -55,7 +57,7 @@
 "'link'"
 ]
 },
-"execution_count": 4,
+"execution_count": 6,
 "metadata": {},
 "output_type": "execute_result"
 }
@@ -66,7 +68,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 5,
+"execution_count": 7,
 "metadata": {},
 "outputs": [
 {
@@ -91,14 +93,14 @@
 },
 {
 "cell_type": "code",
-"execution_count": 7,
+"execution_count": 8,
 "metadata": {},
 "outputs": [
 {
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"['doc:333eadf75bd74be393acafa8bca48669']\n"
+"['doc:link:d7d02e3faf1b40bbbe29a683ff75b280']\n"
 ]
 }
 ],
@@ -108,7 +110,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 8,
+"execution_count": 9,
 "metadata": {},
 "outputs": [
 {
@@ -127,11 +129,25 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
+"execution_count": 10,
 "metadata": {},
-"outputs": [],
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"Tonight. I call on the Senate to: Pass the Freedom to Vote Act. Pass the John Lewis Voting Rights Act. And while you’re at it, pass the Disclose Act so Americans can know who is funding our elections. \n",
+"\n",
+"Tonight, I’d like to honor someone who has dedicated his life to serve this country: Justice Stephen Breyer—an Army veteran, Constitutional scholar, and retiring Justice of the United States Supreme Court. Justice Breyer, thank you for your service. \n",
+"\n",
+"One of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court. \n",
+"\n",
+"And I did that 4 days ago, when I nominated Circuit Court of Appeals Judge Ketanji Brown Jackson. One of our nation’s top legal minds, who will continue Justice Breyer’s legacy of excellence.\n"
+]
+}
+],
 "source": [
-"#Query\n",
+"# Load from existing index\n",
 "rds = Redis.from_existing_index(embeddings, redis_url=\"redis://localhost:6379\", index_name='link')\n",
 "\n",
 "query = \"What did the president say about Ketanji Brown Jackson\"\n",
@@ -152,7 +168,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 6,
+"execution_count": 11,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -161,7 +177,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 7,
+"execution_count": 12,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -177,7 +193,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 5,
+"execution_count": 13,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -186,31 +202,13 @@
 },
 {
 "cell_type": "code",
-"execution_count": 6,
+"execution_count": null,
 "metadata": {},
-"outputs": [
-{
-"data": {
-"text/plain": [
-"[]"
-]
-},
-"execution_count": 6,
-"metadata": {},
-"output_type": "execute_result"
-}
-],
+"outputs": [],
 "source": [
 "# Here we can see it doesn't return any results because there are no relevant documents\n",
 "retriever.get_relevant_documents(\"where did ankush go to college?\")"
 ]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {},
-"outputs": [],
-"source": []
 }
 ],
 "metadata": {
@@ -229,7 +227,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.9.1"
+"version": "3.9.16"
 }
 },
 "nbformat": 4,
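For orientation, the updated notebook flow condenses to the sketch below. It assumes a local Redis Stack instance at redis://localhost:6379, an OPENAI_API_KEY in the environment, and the state_of_the_union.txt sample used throughout the docs; the splitting and from_texts wiring paraphrases unchanged notebook cells rather than quoting the diff.

from langchain.document_loaders import TextLoader
from langchain.embeddings import OpenAIEmbeddings
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores.redis import Redis

# Load and chunk the sample document, as in the notebook.
documents = TextLoader("../../../state_of_the_union.txt").load()
docs = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0).split_documents(documents)

embeddings = OpenAIEmbeddings()

# Build the "link" index; returned ids now look like "doc:link:<uuid>".
rds = Redis.from_texts(
    [d.page_content for d in docs],
    embeddings,
    metadatas=[d.metadata for d in docs],
    redis_url="redis://localhost:6379",
    index_name="link",
)

# Reconnect later without re-embedding the corpus, as the renamed
# "Load from existing index" cell demonstrates.
rds = Redis.from_existing_index(
    embeddings, redis_url="redis://localhost:6379", index_name="link"
)
query = "What did the president say about Ketanji Brown Jackson"
print(rds.similarity_search(query)[0].page_content)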
langchain/vectorstores/redis.py:

@@ -7,7 +7,7 @@ import uuid
 from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple
 
 import numpy as np
-from pydantic import BaseModel, Field, root_validator
+from pydantic import BaseModel, root_validator
 from redis.client import Redis as RedisType
 
 from langchain.docstore.document import Document
@@ -19,8 +19,47 @@ from langchain.vectorstores.base import VectorStore
 logger = logging.getLogger()
 
 
-def _check_redis_module_exist(client: RedisType, module: str) -> bool:
-    return module in [m["name"] for m in client.info().get("modules", {"name": ""})]
+# required modules
+REDIS_REQUIRED_MODULES = [
+    {"name": "search", "ver": 20400},
+]
+
+
+def _check_redis_module_exist(client: RedisType, modules: List[dict]) -> None:
+    """Check if the correct Redis modules are installed."""
+    installed_modules = client.info().get("modules", [])
+    installed_modules = {module["name"]: module for module in installed_modules}
+    for module in modules:
+        if module["name"] not in installed_modules or int(
+            installed_modules[module["name"]]["ver"]
+        ) < int(module["ver"]):
+            error_message = (
+                "You must add the RediSearch (>= 2.4) module from Redis Stack. "
+                "Please refer to Redis Stack docs: https://redis.io/docs/stack/"
+            )
+            logging.error(error_message)
+            raise ValueError(error_message)
+
+
+def _check_index_exists(client: RedisType, index_name: str) -> bool:
+    """Check if Redis index exists."""
+    try:
+        client.ft(index_name).info()
+    except:  # noqa: E722
+        logger.info("Index does not exist")
+        return False
+    logger.info("Index already exists")
+    return True
+
+
+def _redis_key(prefix: str) -> str:
+    """Redis key schema for a given prefix."""
+    return f"{prefix}:{uuid.uuid4().hex}"
+
+
+def _redis_prefix(index_name: str) -> str:
+    """Redis key prefix for a given index."""
+    return f"doc:{index_name}"
 
 
 class Redis(VectorStore):
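For reference, the stricter module check added above would be driven roughly like this (a sketch; the URL is illustrative, and the helper plus the REDIS_REQUIRED_MODULES constant are the ones introduced in this diff):

import redis

from langchain.vectorstores.redis import REDIS_REQUIRED_MODULES, _check_redis_module_exist

client = redis.from_url("redis://localhost:6379")

# client.info() lists loaded modules, e.g.
# {"modules": [{"name": "search", "ver": 20405, ...}, ...]}.
# The helper now raises ValueError unless "search" >= 20400 (RediSearch 2.4)
# is present, instead of the old name-only boolean check.
_check_redis_module_exist(client, REDIS_REQUIRED_MODULES)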
@@ -43,16 +82,12 @@ class Redis(VectorStore):
         self.embedding_function = embedding_function
         self.index_name = index_name
         try:
+            # connect to redis from url
             redis_client = redis.from_url(redis_url, **kwargs)
+            # check if redis has redisearch module installed
+            _check_redis_module_exist(redis_client, REDIS_REQUIRED_MODULES)
         except ValueError as e:
-            raise ValueError(f"Your redis connected error: {e}")
-
-        # check if redis add redisearch module
-        if not _check_redis_module_exist(redis_client, "search"):
-            raise ValueError(
-                "Could not use redis directly, you need to add search module"
-                "Please refer [RediSearch](https://redis.io/docs/stack/search/quick_start/)"  # noqa
-            )
+            raise ValueError(f"Redis failed to connect: {e}")
 
         self.client = redis_client
 
@@ -62,17 +97,17 @@ class Redis(VectorStore):
         metadatas: Optional[List[dict]] = None,
         **kwargs: Any,
     ) -> List[str]:
-        # `prefix`: Maybe in the future we can let the user choose the index_name.
-        prefix = "doc"  # prefix for the document keys
+        """Add texts data to an existing index."""
+        prefix = _redis_prefix(self.index_name)
         keys = kwargs.get("keys")
-
         ids = []
-        # Check if index exists
+        # Write data to redis
+        pipeline = self.client.pipeline(transaction=False)
         for i, text in enumerate(texts):
-            _key = keys[i] if keys else self.index_name
-            key = f"{prefix}:{_key}"
+            # Use provided key otherwise use default key
+            key = keys[i] if keys else _redis_key(prefix)
             metadata = metadatas[i] if metadatas else {}
-            self.client.hset(
+            pipeline.hset(
                 key,
                 mapping={
                     "content": text,
@@ -83,11 +118,22 @@ class Redis(VectorStore):
                 },
             )
             ids.append(key)
+        pipeline.execute()
         return ids
 
     def similarity_search(
         self, query: str, k: int = 4, **kwargs: Any
    ) -> List[Document]:
+        """
+        Returns the most similar indexed documents to the query text.
+
+        Args:
+            query (str): The query text for which to find similar documents.
+            k (int): The number of documents to return. Default is 4.
+
+        Returns:
+            List[Document]: A list of documents that are most similar to the query text.
+        """
         docs_and_scores = self.similarity_search_with_score(query, k=k)
         return [doc for doc, _ in docs_and_scores]
 
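A quick sketch of the key scheme behind the ids add_texts now returns, and of the pipelined write path (illustrative values; the real mapping also stores the embedding under content_vector):

import redis

from langchain.vectorstores.redis import _redis_key, _redis_prefix

prefix = _redis_prefix("link")  # -> "doc:link"
key = _redis_key(prefix)        # -> "doc:link:<uuid4 hex>", matching the
                                #    "doc:link:d7d02e3f..." id in the notebook output

# One HSET per text is buffered on a non-transactional pipeline and flushed
# in a single round trip instead of one call per document.
client = redis.from_url("redis://localhost:6379")
pipe = client.pipeline(transaction=False)
pipe.hset(key, mapping={"content": "example chunk", "metadata": "{}"})
pipe.execute()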
@@ -95,7 +141,8 @@ class Redis(VectorStore):
         self, query: str, k: int = 4, score_threshold: float = 0.2, **kwargs: Any
     ) -> List[Document]:
         """
-        Returns the most similar indexed documents to the query text.
+        Returns the most similar indexed documents to the query text within the
+        score_threshold range.
 
         Args:
             query (str): The query text for which to find similar documents.
@@ -217,55 +264,49 @@ class Redis(VectorStore):
             # otherwise passing it to Redis will result in an error.
             kwargs.pop("redis_url")
             client = redis.from_url(url=redis_url, **kwargs)
+            # check if redis has redisearch module installed
+            _check_redis_module_exist(client, REDIS_REQUIRED_MODULES)
         except ValueError as e:
-            raise ValueError(f"Your redis connected error: {e}")
+            raise ValueError(f"Redis failed to connect: {e}")
 
-        # check if redis add redisearch module
-        if not _check_redis_module_exist(client, "search"):
-            raise ValueError(
-                "Could not use redis directly, you need to add search module"
-                "Please refer [RediSearch](https://redis.io/docs/stack/search/quick_start/)"  # noqa
-            )
-
+        # Create embeddings over documents
         embeddings = embedding.embed_documents(texts)
-        dim = len(embeddings[0])
-        # Constants
-        vector_number = len(embeddings)  # initial number of vectors
-        # name of the search index if not given
+
+        # Name of the search index if not given
         if not index_name:
             index_name = uuid.uuid4().hex
-        prefix = f"doc:{index_name}"  # prefix for the document keys
+        prefix = _redis_prefix(index_name)  # prefix for the document keys
+
+        # Check if index exists
+        if not _check_index_exists(client, index_name):
+            # Constants
+            dim = len(embeddings[0])
             distance_metric = (
                 "COSINE"  # distance metric for the vectors (ex. COSINE, IP, L2)
             )
-        content = TextField(name="content")
-        metadata = TextField(name="metadata")
-        content_embedding = VectorField(
+            schema = (
+                TextField(name="content"),
+                TextField(name="metadata"),
+                VectorField(
                     "content_vector",
                     "FLAT",
                     {
                         "TYPE": "FLOAT32",
                         "DIM": dim,
                         "DISTANCE_METRIC": distance_metric,
-                "INITIAL_CAP": vector_number,
                     },
+                ),
             )
-        fields = [content, metadata, content_embedding]
-
-        # Check if index exists
-        try:
-            client.ft(index_name).info()
-            logger.info("Index already exists")
-        except:  # noqa
             # Create Redis Index
             client.ft(index_name).create_index(
-                fields=fields,
+                fields=schema,
                 definition=IndexDefinition(prefix=[prefix], index_type=IndexType.HASH),
             )
 
-        pipeline = client.pipeline()
+        # Write data to Redis
+        pipeline = client.pipeline(transaction=False)
         for i, text in enumerate(texts):
-            key = f"{prefix}:{i}"
+            key = _redis_key(prefix)
             metadata = metadatas[i] if metadatas else {}
             pipeline.hset(
                 key,
|
|||||||
delete_documents: bool,
|
delete_documents: bool,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Drop a Redis search index.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
index_name (str): Name of the index to drop.
|
||||||
|
delete_documents (bool): Whether to drop the associated documents.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: Whether or not the drop was successful.
|
||||||
|
"""
|
||||||
redis_url = get_from_dict_or_env(kwargs, "redis_url", "REDIS_URL")
|
redis_url = get_from_dict_or_env(kwargs, "redis_url", "REDIS_URL")
|
||||||
try:
|
try:
|
||||||
import redis
|
import redis
|
||||||
@ -306,7 +357,7 @@ class Redis(VectorStore):
|
|||||||
client.ft(index_name).dropindex(delete_documents)
|
client.ft(index_name).dropindex(delete_documents)
|
||||||
logger.info("Drop index")
|
logger.info("Drop index")
|
||||||
return True
|
return True
|
||||||
except: # noqa
|
except: # noqa: E722
|
||||||
# Index not exist
|
# Index not exist
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@ -317,6 +368,7 @@ class Redis(VectorStore):
|
|||||||
index_name: str,
|
index_name: str,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Redis:
|
) -> Redis:
|
||||||
|
"""Connect to an existing Redis index."""
|
||||||
redis_url = get_from_dict_or_env(kwargs, "redis_url", "REDIS_URL")
|
redis_url = get_from_dict_or_env(kwargs, "redis_url", "REDIS_URL")
|
||||||
try:
|
try:
|
||||||
import redis
|
import redis
|
||||||
@ -330,15 +382,14 @@ class Redis(VectorStore):
|
|||||||
# otherwise passing it to Redis will result in an error.
|
# otherwise passing it to Redis will result in an error.
|
||||||
kwargs.pop("redis_url")
|
kwargs.pop("redis_url")
|
||||||
client = redis.from_url(url=redis_url, **kwargs)
|
client = redis.from_url(url=redis_url, **kwargs)
|
||||||
except ValueError as e:
|
# check if redis has redisearch module installed
|
||||||
raise ValueError(f"Your redis connected error: {e}")
|
_check_redis_module_exist(client, REDIS_REQUIRED_MODULES)
|
||||||
|
# ensure that the index already exists
|
||||||
# check if redis add redisearch module
|
assert _check_index_exists(
|
||||||
if not _check_redis_module_exist(client, "search"):
|
client, index_name
|
||||||
raise ValueError(
|
), f"Index {index_name} does not exist"
|
||||||
"Could not use redis directly, you need to add search module"
|
except Exception as e:
|
||||||
"Please refer [RediSearch](https://redis.io/docs/stack/search/quick_start/)" # noqa
|
raise ValueError(f"Redis failed to connect: {e}")
|
||||||
)
|
|
||||||
|
|
||||||
return cls(redis_url, index_name, embedding.embed_query)
|
return cls(redis_url, index_name, embedding.embed_query)
|
||||||
|
|
||||||
@ -349,7 +400,8 @@ class Redis(VectorStore):
|
|||||||
class RedisVectorStoreRetriever(BaseRetriever, BaseModel):
|
class RedisVectorStoreRetriever(BaseRetriever, BaseModel):
|
||||||
vectorstore: Redis
|
vectorstore: Redis
|
||||||
search_type: str = "similarity"
|
search_type: str = "similarity"
|
||||||
search_kwargs: dict = Field(default_factory=dict)
|
k: int = 4
|
||||||
|
score_threshold: float = 0.4
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
"""Configuration for this pydantic object."""
|
"""Configuration for this pydantic object."""
|
||||||
@ -367,10 +419,10 @@ class RedisVectorStoreRetriever(BaseRetriever, BaseModel):
|
|||||||
|
|
||||||
def get_relevant_documents(self, query: str) -> List[Document]:
|
def get_relevant_documents(self, query: str) -> List[Document]:
|
||||||
if self.search_type == "similarity":
|
if self.search_type == "similarity":
|
||||||
docs = self.vectorstore.similarity_search(query, **self.search_kwargs)
|
docs = self.vectorstore.similarity_search(query, k=self.k)
|
||||||
elif self.search_type == "similarity_limit":
|
elif self.search_type == "similarity_limit":
|
||||||
docs = self.vectorstore.similarity_search_limit_score(
|
docs = self.vectorstore.similarity_search_limit_score(
|
||||||
query, **self.search_kwargs
|
query, k=self.k, score_threshold=self.score_threshold
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"search_type of {self.search_type} not allowed.")
|
raise ValueError(f"search_type of {self.search_type} not allowed.")
|
||||||
|
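Finally, the retriever now exposes k and score_threshold directly rather than an opaque search_kwargs dict. A sketch of both search types (assumes the "link" index from the notebook already exists; the retriever is constructed directly so the example does not depend on how as_retriever is wired):

from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores.redis import Redis, RedisVectorStoreRetriever

rds = Redis.from_existing_index(
    OpenAIEmbeddings(), redis_url="redis://localhost:6379", index_name="link"
)

# Plain similarity search, top-k only.
retriever = RedisVectorStoreRetriever(vectorstore=rds, search_type="similarity", k=4)
docs = retriever.get_relevant_documents(
    "What did the president say about Ketanji Brown Jackson"
)

# "similarity_limit" also filters hits by score_threshold, which is why the
# out-of-scope query at the end of the notebook now returns an empty list.
filtered = RedisVectorStoreRetriever(
    vectorstore=rds, search_type="similarity_limit", k=4, score_threshold=0.4
)
print(filtered.get_relevant_documents("where did ankush go to college?"))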