Harrison/redis cache (#3766)

Co-authored-by: Tyler Hutcherson <tyler.hutcherson@redis.com>
Harrison Chase 2023-04-28 20:47:18 -07:00 committed by GitHub
parent b588446bf9
commit be7a8e0824
7 changed files with 616 additions and 149 deletions

docs/ecosystem/redis.md Normal file

@ -0,0 +1,79 @@
# Redis
This page covers how to use the [Redis](https://redis.com) ecosystem within LangChain.
It is broken into two parts: installation and setup, and then references to specific Redis wrappers.
## Installation and Setup
- Install the Redis Python SDK with `pip install redis`
## Wrappers
### Cache
The Cache wrapper allows [Redis](https://redis.io) to be used as a remote, low-latency, in-memory cache for LLM prompts and responses.
#### Standard Cache
The standard cache is the bread and butter use case for Redis in production for both [open source](https://redis.io) and [enterprise](https://redis.com) users globally.
To import this cache:
```python
from langchain.cache import RedisCache
```
To use this cache with your LLMs:
```python
import langchain
import redis
redis_client = redis.Redis.from_url(...)
langchain.llm_cache = RedisCache(redis_client)
```
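Continuing the snippet above, the first call for a given prompt populates the cache and an identical second call is served from Redis. A minimal sketch, assuming an OpenAI API key is configured in the environment:
```python
from langchain.llms import OpenAI

llm = OpenAI()

# First call: not cached yet, so the LLM is called and the result is stored in Redis
llm("Tell me a joke")

# Second call with the identical prompt: served from the Redis cache
llm("Tell me a joke")
```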
#### Semantic Cache
Semantic caching allows users to retrieve cached prompts based on semantic similarity between the user input and previously cached prompts. Under the hood, it uses Redis as both a cache and a vectorstore.
To import this cache:
```python
from langchain.cache import RedisSemanticCache
```
To use this cache with your LLMs:
```python
import langchain
import redis
# use any embedding provider...
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
redis_url = "redis://localhost:6379"
langchain.llm_cache = RedisSemanticCache(
embedding=FakeEmbeddings(),
redis_url=redis_url
)
```
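With the semantic cache in place, a prompt does not need to match a cached prompt exactly; a semantically similar prompt can return the cached result. A minimal sketch, again assuming an OpenAI API key and, for real use, a proper embedding provider such as `OpenAIEmbeddings`:
```python
from langchain.llms import OpenAI

llm = OpenAI()

# First call: no semantically similar entry yet, so the LLM is called and the result is cached
llm("Tell me a joke")

# A semantically similar (but not identical) prompt is served from the cache
llm("Tell me one joke")
```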
### VectorStore
The vectorstore wrapper turns Redis into a low-latency [vector database](https://redis.com/solutions/use-cases/vector-database/) for semantic search or LLM content retrieval.
To import this vectorstore:
```python
from langchain.vectorstores import Redis
```
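A minimal usage sketch, assuming a local Redis Stack instance with the RediSearch module and an OpenAI API key for embeddings:
```python
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import Redis

# Embed and index a few texts, creating a new search index in Redis
vectorstore = Redis.from_texts(
    ["foo", "bar", "baz"],
    OpenAIEmbeddings(),
    redis_url="redis://localhost:6379",
)

# Low-latency semantic search over the indexed texts
docs = vectorstore.similarity_search("foo", k=1)
```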
For a more detailed walkthrough of the Redis vectorstore wrapper, see [this notebook](../modules/indexes/vectorstores/examples/redis.ipynb).
### Retriever
The Redis vector store retriever wrapper generalizes the vectorstore class to perform low-latency document retrieval. To create the retriever, simply call `.as_retriever()` on the base vectorstore class.
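For example, reusing the hypothetical `vectorstore` object from the sketch above:
```python
retriever = vectorstore.as_retriever()
docs = retriever.get_relevant_documents("foo")
```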
### Memory
Redis can be used to persist LLM conversations.
#### Vector Store Retriever Memory
For a more detailed walkthrough of the `VectorStoreRetrieverMemory` wrapper, see [this notebook](../modules/memory/types/vectorstore_retriever_memory.ipynb).
#### Chat Message History Memory
For a detailed example of using Redis to cache conversation message history, see [this notebook](../modules/memory/examples/redis_chat_message_history.ipynb).
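A brief sketch of Redis-backed chat message history, assuming `RedisChatMessageHistory` is importable from `langchain.memory` and takes a session id and Redis URL as in the linked notebook:
```python
from langchain.memory import RedisChatMessageHistory

history = RedisChatMessageHistory(session_id="my-session", url="redis://localhost:6379/0")

# Messages are persisted to Redis under the given session id
history.add_user_message("hi!")
history.add_ai_message("whats up?")
print(history.messages)
```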


@ -41,7 +41,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": 5,
"id": "f69f6283", "id": "f69f6283",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -52,7 +52,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 4, "execution_count": 6,
"id": "64005d1f", "id": "64005d1f",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -60,8 +60,8 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"CPU times: user 14.2 ms, sys: 4.9 ms, total: 19.1 ms\n", "CPU times: user 26.1 ms, sys: 21.5 ms, total: 47.6 ms\n",
"Wall time: 1.1 s\n" "Wall time: 1.68 s\n"
] ]
}, },
{ {
@ -70,7 +70,7 @@
"'\\n\\nWhy did the chicken cross the road?\\n\\nTo get to the other side.'" "'\\n\\nWhy did the chicken cross the road?\\n\\nTo get to the other side.'"
] ]
}, },
"execution_count": 4, "execution_count": 6,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@ -83,7 +83,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 5, "execution_count": 7,
"id": "c8a1cb2b", "id": "c8a1cb2b",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -91,8 +91,8 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"CPU times: user 162 µs, sys: 7 µs, total: 169 µs\n", "CPU times: user 238 µs, sys: 143 µs, total: 381 µs\n",
"Wall time: 175 µs\n" "Wall time: 1.76 ms\n"
] ]
}, },
{ {
@ -101,7 +101,7 @@
"'\\n\\nWhy did the chicken cross the road?\\n\\nTo get to the other side.'" "'\\n\\nWhy did the chicken cross the road?\\n\\nTo get to the other side.'"
] ]
}, },
"execution_count": 5, "execution_count": 7,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@ -214,9 +214,18 @@
"## Redis Cache" "## Redis Cache"
] ]
}, },
{
"cell_type": "markdown",
"id": "c5c9a4d5",
"metadata": {},
"source": [
"### Standard Cache\n",
"Use [Redis](../../../../ecosystem/redis.md) to cache prompts and responses."
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 8,
"id": "39f6eb0b", "id": "39f6eb0b",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -225,15 +234,35 @@
"# (make sure your local Redis instance is running first before running this example)\n", "# (make sure your local Redis instance is running first before running this example)\n",
"from redis import Redis\n", "from redis import Redis\n",
"from langchain.cache import RedisCache\n", "from langchain.cache import RedisCache\n",
"\n",
"langchain.llm_cache = RedisCache(redis_=Redis())" "langchain.llm_cache = RedisCache(redis_=Redis())"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 9,
"id": "28920749", "id": "28920749",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 6.88 ms, sys: 8.75 ms, total: 15.6 ms\n",
"Wall time: 1.04 s\n"
]
},
{
"data": {
"text/plain": [
"'\\n\\nWhy did the chicken cross the road?\\n\\nTo get to the other side!'"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [ "source": [
"%%time\n", "%%time\n",
"# The first time, it is not yet in cache, so it should take longer\n", "# The first time, it is not yet in cache, so it should take longer\n",
@ -242,16 +271,124 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 14,
"id": "94bf9415", "id": "94bf9415",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 1.59 ms, sys: 610 µs, total: 2.2 ms\n",
"Wall time: 5.58 ms\n"
]
},
{
"data": {
"text/plain": [
"'\\n\\nWhy did the chicken cross the road?\\n\\nTo get to the other side!'"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [ "source": [
"%%time\n", "%%time\n",
"# The second time it is, so it goes faster\n", "# The second time it is, so it goes faster\n",
"llm(\"Tell me a joke\")" "llm(\"Tell me a joke\")"
] ]
}, },
{
"cell_type": "markdown",
"id": "82be23f6",
"metadata": {},
"source": [
"### Semantic Cache\n",
"Use [Redis](../../../../ecosystem/redis.md) to cache prompts and responses and evaluate hits based on semantic similarity."
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "64df3099",
"metadata": {},
"outputs": [],
"source": [
"from langchain.embeddings import OpenAIEmbeddings\n",
"from langchain.cache import RedisSemanticCache\n",
"\n",
"\n",
"langchain.llm_cache = RedisSemanticCache(\n",
" redis_url=\"redis://localhost:6379\",\n",
" embedding=OpenAIEmbeddings()\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "8e91d3ac",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 351 ms, sys: 156 ms, total: 507 ms\n",
"Wall time: 3.37 s\n"
]
},
{
"data": {
"text/plain": [
"\"\\n\\nWhy don't scientists trust atoms?\\nBecause they make up everything.\""
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"# The first time, it is not yet in cache, so it should take longer\n",
"llm(\"Tell me a joke\")"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "df856948",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 6.25 ms, sys: 2.72 ms, total: 8.97 ms\n",
"Wall time: 262 ms\n"
]
},
{
"data": {
"text/plain": [
"\"\\n\\nWhy don't scientists trust atoms?\\nBecause they make up everything.\""
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"# The second time, while not a direct hit, the question is semantically similar to the original question,\n",
"# so it uses the cached result!\n",
"llm(\"Tell me one joke\")"
]
},
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "684eab55", "id": "684eab55",


@ -1,4 +1,5 @@
"""Beta Feature: base interface for cache.""" """Beta Feature: base interface for cache."""
import hashlib
import json import json
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, cast from typing import Any, Callable, Dict, List, Optional, Tuple, Type, cast
@ -12,11 +13,18 @@ try:
except ImportError: except ImportError:
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
from langchain.embeddings.base import Embeddings
from langchain.schema import Generation from langchain.schema import Generation
from langchain.vectorstores.redis import Redis as RedisVectorstore
RETURN_VAL_TYPE = List[Generation] RETURN_VAL_TYPE = List[Generation]
def _hash(_input: str) -> str:
"""Use a deterministic hashing approach."""
return hashlib.md5(_input.encode()).hexdigest()
class BaseCache(ABC): class BaseCache(ABC):
"""Base interface for cache.""" """Base interface for cache."""
@ -117,6 +125,8 @@ class SQLiteCache(SQLAlchemyCache):
class RedisCache(BaseCache): class RedisCache(BaseCache):
"""Cache that uses Redis as a backend.""" """Cache that uses Redis as a backend."""
# TODO - implement a TTL policy in Redis
def __init__(self, redis_: Any): def __init__(self, redis_: Any):
"""Initialize by passing in Redis instance.""" """Initialize by passing in Redis instance."""
try: try:
@ -130,28 +140,30 @@ class RedisCache(BaseCache):
raise ValueError("Please pass in Redis object.") raise ValueError("Please pass in Redis object.")
self.redis = redis_ self.redis = redis_
def _key(self, prompt: str, llm_string: str, idx: int) -> str: def _key(self, prompt: str, llm_string: str) -> str:
"""Compute key from prompt, llm_string, and idx.""" """Compute key from prompt and llm_string"""
return str(hash(prompt + llm_string)) + "_" + str(idx) return _hash(prompt + llm_string)
def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
"""Look up based on prompt and llm_string.""" """Look up based on prompt and llm_string."""
idx = 0
generations = [] generations = []
while self.redis.get(self._key(prompt, llm_string, idx)): # Read from a Redis HASH
result = self.redis.get(self._key(prompt, llm_string, idx)) results = self.redis.hgetall(self._key(prompt, llm_string))
if not result: if results:
break for _, text in results.items():
elif isinstance(result, bytes): generations.append(Generation(text=text))
result = result.decode()
generations.append(Generation(text=result))
idx += 1
return generations if generations else None return generations if generations else None
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None: def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
"""Update cache based on prompt and llm_string.""" """Update cache based on prompt and llm_string."""
for i, generation in enumerate(return_val): # Write to a Redis HASH
self.redis.set(self._key(prompt, llm_string, i), generation.text) key = self._key(prompt, llm_string)
self.redis.hset(
key,
mapping={
str(idx): generation.text for idx, generation in enumerate(return_val)
},
)
def clear(self, **kwargs: Any) -> None: def clear(self, **kwargs: Any) -> None:
"""Clear cache. If `asynchronous` is True, flush asynchronously.""" """Clear cache. If `asynchronous` is True, flush asynchronously."""
@ -159,6 +171,106 @@ class RedisCache(BaseCache):
self.redis.flushdb(asynchronous=asynchronous, **kwargs) self.redis.flushdb(asynchronous=asynchronous, **kwargs)
class RedisSemanticCache(BaseCache):
"""Cache that uses Redis as a vector-store backend."""
# TODO - implement a TTL policy in Redis
def __init__(
self, redis_url: str, embedding: Embeddings, score_threshold: float = 0.2
):
"""Initialize by passing in the `init` GPTCache func
Args:
redis_url (str): URL to connect to Redis.
embedding (Embedding): Embedding provider for semantic encoding and search.
score_threshold (float, optional): Score threshold for semantic cache hits. Defaults to 0.2.
Example:
.. code-block:: python
import langchain
from langchain.cache import RedisSemanticCache
from langchain.embeddings import OpenAIEmbeddings
langchain.llm_cache = RedisSemanticCache(
redis_url="redis://localhost:6379",
embedding=OpenAIEmbeddings()
)
"""
self._cache_dict: Dict[str, RedisVectorstore] = {}
self.redis_url = redis_url
self.embedding = embedding
self.score_threshold = score_threshold
def _index_name(self, llm_string: str) -> str:
hashed_index = _hash(llm_string)
return f"cache:{hashed_index}"
def _get_llm_cache(self, llm_string: str) -> RedisVectorstore:
index_name = self._index_name(llm_string)
# return vectorstore client for the specific llm string
if index_name in self._cache_dict:
return self._cache_dict[index_name]
# create new vectorstore client for the specific llm string
try:
self._cache_dict[index_name] = RedisVectorstore.from_existing_index(
embedding=self.embedding,
index_name=index_name,
redis_url=self.redis_url,
)
except ValueError:
redis = RedisVectorstore(
embedding_function=self.embedding.embed_query,
index_name=index_name,
redis_url=self.redis_url,
)
_embedding = self.embedding.embed_query(text="test")
redis._create_index(dim=len(_embedding))
self._cache_dict[index_name] = redis
return self._cache_dict[index_name]
def clear(self, **kwargs: Any) -> None:
"""Clear semantic cache for a given llm_string."""
index_name = self._index_name(kwargs["llm_string"])
if index_name in self._cache_dict:
self._cache_dict[index_name].drop_index(
index_name=index_name, delete_documents=True, redis_url=self.redis_url
)
del self._cache_dict[index_name]
def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
"""Look up based on prompt and llm_string."""
llm_cache = self._get_llm_cache(llm_string)
generations = []
# Read from a Hash
results = llm_cache.similarity_search_limit_score(
query=prompt,
k=1,
score_threshold=self.score_threshold,
)
if results:
for document in results:
for text in document.metadata["return_val"]:
generations.append(Generation(text=text))
return generations if generations else None
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
"""Update cache based on prompt and llm_string."""
llm_cache = self._get_llm_cache(llm_string)
# Write to vectorstore
metadata = {
"llm_string": llm_string,
"prompt": prompt,
"return_val": [generation.text for generation in return_val],
}
llm_cache.add_texts(texts=[prompt], metadatas=[metadata])
class GPTCache(BaseCache): class GPTCache(BaseCache):
"""Cache that uses GPTCache as a backend.""" """Cache that uses GPTCache as a backend."""


@ -13,11 +13,13 @@ from langchain.vectorstores.myscale import MyScale, MyScaleSettings
from langchain.vectorstores.opensearch_vector_search import OpenSearchVectorSearch from langchain.vectorstores.opensearch_vector_search import OpenSearchVectorSearch
from langchain.vectorstores.pinecone import Pinecone from langchain.vectorstores.pinecone import Pinecone
from langchain.vectorstores.qdrant import Qdrant from langchain.vectorstores.qdrant import Qdrant
from langchain.vectorstores.redis import Redis
from langchain.vectorstores.supabase import SupabaseVectorStore from langchain.vectorstores.supabase import SupabaseVectorStore
from langchain.vectorstores.weaviate import Weaviate from langchain.vectorstores.weaviate import Weaviate
from langchain.vectorstores.zilliz import Zilliz from langchain.vectorstores.zilliz import Zilliz
__all__ = [ __all__ = [
"Redis",
"ElasticVectorSearch", "ElasticVectorSearch",
"FAISS", "FAISS",
"VectorStore", "VectorStore",


@ -4,11 +4,21 @@ from __future__ import annotations
import json import json
import logging import logging
import uuid import uuid
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, Type from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
List,
Mapping,
Optional,
Tuple,
Type,
)
import numpy as np import numpy as np
from pydantic import BaseModel, root_validator from pydantic import BaseModel, root_validator
from redis.client import Redis as RedisType
from langchain.docstore.document import Document from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings from langchain.embeddings.base import Embeddings
@ -18,23 +28,30 @@ from langchain.vectorstores.base import VectorStore
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from redis.client import Redis as RedisType
from redis.commands.search.query import Query
# required modules # required modules
REDIS_REQUIRED_MODULES = [ REDIS_REQUIRED_MODULES = [
{"name": "search", "ver": 20400}, {"name": "search", "ver": 20400},
{"name": "searchlight", "ver": 20400},
] ]
def _check_redis_module_exist(client: RedisType, modules: List[dict]) -> None: def _check_redis_module_exist(client: RedisType, required_modules: List[dict]) -> None:
"""Check if the correct Redis modules are installed.""" """Check if the correct Redis modules are installed."""
installed_modules = client.module_list() installed_modules = client.module_list()
installed_modules = { installed_modules = {
module[b"name"].decode("utf-8"): module for module in installed_modules module[b"name"].decode("utf-8"): module for module in installed_modules
} }
for module in modules: for module in required_modules:
if module["name"] not in installed_modules or int( if module["name"] in installed_modules and int(
installed_modules[module["name"]][b"ver"] installed_modules[module["name"]][b"ver"]
) < int(module["ver"]): ) >= int(module["ver"]):
return
# otherwise raise error
error_message = ( error_message = (
"You must add the RediSearch (>= 2.4) module from Redis Stack. " "You must add the RediSearch (>= 2.4) module from Redis Stack. "
"Please refer to Redis Stack docs: https://redis.io/docs/stack/" "Please refer to Redis Stack docs: https://redis.io/docs/stack/"
@ -65,6 +82,24 @@ def _redis_prefix(index_name: str) -> str:
class Redis(VectorStore): class Redis(VectorStore):
"""Wrapper around Redis vector database.
To use, you should have the ``redis`` python package installed.
Example:
.. code-block:: python
from langchain.vectorstores import Redis
from langchain.embeddings import OpenAIEmbeddings
embeddings = OpenAIEmbeddings()
vectorstore = Redis(
redis_url="redis://username:password@localhost:6379"
index_name="my-index",
embedding_function=embeddings.embed_query,
)
"""
def __init__( def __init__(
self, self,
redis_url: str, redis_url: str,
@ -99,33 +134,92 @@ class Redis(VectorStore):
self.metadata_key = metadata_key self.metadata_key = metadata_key
self.vector_key = vector_key self.vector_key = vector_key
def _create_index(self, dim: int = 1536) -> None:
try:
from redis.commands.search.field import TextField, VectorField
from redis.commands.search.indexDefinition import IndexDefinition, IndexType
except ImportError:
raise ValueError(
"Could not import redis python package. "
"Please install it with `pip install redis`."
)
# Check if index exists
if not _check_index_exists(self.client, self.index_name):
# Constants
distance_metric = (
"COSINE" # distance metric for the vectors (ex. COSINE, IP, L2)
)
schema = (
TextField(name=self.content_key),
TextField(name=self.metadata_key),
VectorField(
self.vector_key,
"FLAT",
{
"TYPE": "FLOAT32",
"DIM": dim,
"DISTANCE_METRIC": distance_metric,
},
),
)
prefix = _redis_prefix(self.index_name)
# Create Redis Index
self.client.ft(self.index_name).create_index(
fields=schema,
definition=IndexDefinition(prefix=[prefix], index_type=IndexType.HASH),
)
def add_texts( def add_texts(
self, self,
texts: Iterable[str], texts: Iterable[str],
metadatas: Optional[List[dict]] = None, metadatas: Optional[List[dict]] = None,
embeddings: Optional[List[List[float]]] = None,
keys: Optional[List[str]] = None,
batch_size: int = 1000,
**kwargs: Any, **kwargs: Any,
) -> List[str]: ) -> List[str]:
"""Add texts data to an existing index.""" """Add more texts to the vectorstore.
prefix = _redis_prefix(self.index_name)
keys = kwargs.get("keys") Args:
texts (Iterable[str]): Iterable of strings/text to add to the vectorstore.
metadatas (Optional[List[dict]], optional): Optional list of metadatas.
Defaults to None.
embeddings (Optional[List[List[float]]], optional): Optional pre-generated
embeddings. Defaults to None.
keys (Optional[List[str]], optional): Optional key values to use as ids.
Defaults to None.
batch_size (int, optional): Batch size to use for writes. Defaults to 1000.
Returns:
List[str]: List of ids added to the vectorstore
"""
ids = [] ids = []
prefix = _redis_prefix(self.index_name)
# Write data to redis # Write data to redis
pipeline = self.client.pipeline(transaction=False) pipeline = self.client.pipeline(transaction=False)
for i, text in enumerate(texts): for i, text in enumerate(texts):
# Use provided key otherwise use default key # Use provided values by default or fallback
key = keys[i] if keys else _redis_key(prefix) key = keys[i] if keys else _redis_key(prefix)
metadata = metadatas[i] if metadatas else {} metadata = metadatas[i] if metadatas else {}
embedding = embeddings[i] if embeddings else self.embedding_function(text)
pipeline.hset( pipeline.hset(
key, key,
mapping={ mapping={
self.content_key: text, self.content_key: text,
self.vector_key: np.array( self.vector_key: np.array(embedding, dtype=np.float32).tobytes(),
self.embedding_function(text), dtype=np.float32
).tobytes(),
self.metadata_key: json.dumps(metadata), self.metadata_key: json.dumps(metadata),
}, },
) )
ids.append(key) ids.append(key)
# Write batch
if i % batch_size == 0:
pipeline.execute()
# Cleanup final batch
pipeline.execute() pipeline.execute()
return ids return ids
@ -170,9 +264,30 @@ class Redis(VectorStore):
""" """
docs_and_scores = self.similarity_search_with_score(query, k=k) docs_and_scores = self.similarity_search_with_score(query, k=k)
return [doc for doc, score in docs_and_scores if score < score_threshold] return [doc for doc, score in docs_and_scores if score < score_threshold]
def _prepare_query(self, k: int) -> Query:
try:
from redis.commands.search.query import Query
except ImportError:
raise ValueError(
"Could not import redis python package. "
"Please install it with `pip install redis`."
)
# Prepare the Query
hybrid_fields = "*"
base_query = (
f"{hybrid_fields}=>[KNN {k} @{self.vector_key} $vector AS vector_score]"
)
return_fields = [self.metadata_key, self.content_key, "vector_score"]
return (
Query(base_query)
.return_fields(*return_fields)
.sort_by("vector_score")
.paging(0, k)
.dialect(2)
)
def similarity_search_with_score( def similarity_search_with_score(
self, query: str, k: int = 4 self, query: str, k: int = 4
) -> List[Tuple[Document, float]]: ) -> List[Tuple[Document, float]]:
@ -185,40 +300,22 @@ class Redis(VectorStore):
Returns: Returns:
List of Documents most similar to the query and score for each List of Documents most similar to the query and score for each
""" """
try:
from redis.commands.search.query import Query
except ImportError:
raise ValueError(
"Could not import redis python package. "
"Please install it with `pip install redis`."
)
# Creates embedding vector from user query # Creates embedding vector from user query
embedding = self.embedding_function(query) embedding = self.embedding_function(query)
# Prepare the Query # Creates Redis query
return_fields = [self.metadata_key, self.content_key, "vector_score"] redis_query = self._prepare_query(k)
vector_field = self.vector_key
hybrid_fields = "*"
base_query = (
f"{hybrid_fields}=>[KNN {k} @{vector_field} $vector AS vector_score]"
)
redis_query = (
Query(base_query)
.return_fields(*return_fields)
.sort_by("vector_score")
.paging(0, k)
.dialect(2)
)
params_dict: Mapping[str, str] = { params_dict: Mapping[str, str] = {
"vector": np.array(embedding) # type: ignore "vector": np.array(embedding) # type: ignore
.astype(dtype=np.float32) .astype(dtype=np.float32)
.tobytes() .tobytes()
} }
# perform vector search # Perform vector search
results = self.client.ft(self.index_name).search(redis_query, params_dict) results = self.client.ft(self.index_name).search(redis_query, params_dict)
# Prepare document results
docs = [ docs = [
( (
Document( Document(
@ -243,15 +340,15 @@ class Redis(VectorStore):
vector_key: str = "content_vector", vector_key: str = "content_vector",
**kwargs: Any, **kwargs: Any,
) -> Redis: ) -> Redis:
"""Construct RediSearch wrapper from raw documents. """Create a Redis vectorstore from raw documents.
This is a user-friendly interface that: This is a user-friendly interface that:
1. Embeds documents. 1. Embeds documents.
2. Creates a new index for the embeddings in the RediSearch instance. 2. Creates a new index for the embeddings in Redis.
3. Adds the documents to the newly created RediSearch index. 3. Adds the documents to the newly created Redis index.
This is intended to be a quick way to get started. This is intended to be a quick way to get started.
Example: Example:
.. code-block:: python .. code-block:: python
from langchain import RediSearch from langchain.vectorstores import Redis
from langchain.embeddings import OpenAIEmbeddings from langchain.embeddings import OpenAIEmbeddings
embeddings = OpenAIEmbeddings() embeddings = OpenAIEmbeddings()
redisearch = RediSearch.from_texts( redisearch = RediSearch.from_texts(
@ -261,84 +358,35 @@ class Redis(VectorStore):
) )
""" """
redis_url = get_from_dict_or_env(kwargs, "redis_url", "REDIS_URL") redis_url = get_from_dict_or_env(kwargs, "redis_url", "REDIS_URL")
try:
import redis
from redis.commands.search.field import TextField, VectorField
from redis.commands.search.indexDefinition import IndexDefinition, IndexType
except ImportError:
raise ValueError(
"Could not import redis python package. "
"Please install it with `pip install redis`."
)
try:
# We need to first remove redis_url from kwargs,
# otherwise passing it to Redis will result in an error.
if "redis_url" in kwargs: if "redis_url" in kwargs:
kwargs.pop("redis_url") kwargs.pop("redis_url")
client = redis.from_url(url=redis_url, **kwargs)
# check if redis has redisearch module installed
_check_redis_module_exist(client, REDIS_REQUIRED_MODULES)
except ValueError as e:
raise ValueError(f"Redis failed to connect: {e}")
# Create embeddings over documents
embeddings = embedding.embed_documents(texts)
# Name of the search index if not given # Name of the search index if not given
if not index_name: if not index_name:
index_name = uuid.uuid4().hex index_name = uuid.uuid4().hex
prefix = _redis_prefix(index_name) # prefix for the document keys
# Check if index exists # Create instance
if not _check_index_exists(client, index_name): instance = cls(
# Constants redis_url=redis_url,
dim = len(embeddings[0]) index_name=index_name,
distance_metric = ( embedding_function=embedding.embed_query,
"COSINE" # distance metric for the vectors (ex. COSINE, IP, L2)
)
schema = (
TextField(name=content_key),
TextField(name=metadata_key),
VectorField(
vector_key,
"FLAT",
{
"TYPE": "FLOAT32",
"DIM": dim,
"DISTANCE_METRIC": distance_metric,
},
),
)
# Create Redis Index
client.ft(index_name).create_index(
fields=schema,
definition=IndexDefinition(prefix=[prefix], index_type=IndexType.HASH),
)
# Write data to Redis
pipeline = client.pipeline(transaction=False)
for i, text in enumerate(texts):
key = _redis_key(prefix)
metadata = metadatas[i] if metadatas else {}
pipeline.hset(
key,
mapping={
content_key: text,
vector_key: np.array(embeddings[i], dtype=np.float32).tobytes(),
metadata_key: json.dumps(metadata),
},
)
pipeline.execute()
return cls(
redis_url,
index_name,
embedding.embed_query,
content_key=content_key, content_key=content_key,
metadata_key=metadata_key, metadata_key=metadata_key,
vector_key=vector_key, vector_key=vector_key,
**kwargs, **kwargs,
) )
# Create embeddings over documents
embeddings = embedding.embed_documents(texts)
# Create the search index
instance._create_index(dim=len(embeddings[0]))
# Add data to Redis
instance.add_texts(texts, metadatas, embeddings)
return instance
@staticmethod @staticmethod
def drop_index( def drop_index(
index_name: str, index_name: str,


@ -0,0 +1,55 @@
"""Test Redis cache functionality."""
import redis
import langchain
from langchain.cache import RedisCache, RedisSemanticCache
from langchain.schema import Generation, LLMResult
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
from tests.unit_tests.llms.fake_llm import FakeLLM
REDIS_TEST_URL = "redis://localhost:6379"
def test_redis_cache() -> None:
langchain.llm_cache = RedisCache(redis_=redis.Redis.from_url(REDIS_TEST_URL))
llm = FakeLLM()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
output = llm.generate(["foo"])
print(output)
expected_output = LLMResult(
generations=[[Generation(text="fizz")]],
llm_output={},
)
print(expected_output)
assert output == expected_output
langchain.llm_cache.redis.flushall()
def test_redis_semantic_cache() -> None:
langchain.llm_cache = RedisSemanticCache(
embedding=FakeEmbeddings(), redis_url=REDIS_TEST_URL, score_threshold=0.1
)
llm = FakeLLM()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
output = llm.generate(
["bar"]
) # foo and bar will have the same embedding produced by FakeEmbeddings
expected_output = LLMResult(
generations=[[Generation(text="fizz")]],
llm_output={},
)
assert output == expected_output
# clear the cache
langchain.llm_cache.clear(llm_string=llm_string)
output = llm.generate(
["bar"]
) # foo and bar will have the same embedding produced by FakeEmbeddings
# expect different output now without cached result
assert output != expected_output
langchain.llm_cache.clear(llm_string=llm_string)


@ -1,26 +1,60 @@
"""Test Redis functionality.""" """Test Redis functionality."""
from langchain.docstore.document import Document from langchain.docstore.document import Document
from langchain.vectorstores.redis import Redis from langchain.vectorstores.redis import Redis
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
TEST_INDEX_NAME = "test"
TEST_REDIS_URL = "redis://localhost:6379"
TEST_SINGLE_RESULT = [Document(page_content="foo")]
TEST_RESULT = [Document(page_content="foo"), Document(page_content="foo")]
def drop(index_name: str) -> bool:
return Redis.drop_index(
index_name=index_name, delete_documents=True, redis_url=TEST_REDIS_URL
)
def test_redis() -> None: def test_redis() -> None:
"""Test end to end construction and search.""" """Test end to end construction and search."""
texts = ["foo", "bar", "baz"] texts = ["foo", "bar", "baz"]
docsearch = Redis.from_texts( docsearch = Redis.from_texts(texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL)
texts, FakeEmbeddings(), redis_url="redis://localhost:6379"
)
output = docsearch.similarity_search("foo", k=1) output = docsearch.similarity_search("foo", k=1)
assert output == [Document(page_content="foo")] assert output == TEST_SINGLE_RESULT
assert drop(docsearch.index_name)
def test_redis_new_vector() -> None: def test_redis_new_vector() -> None:
"""Test adding a new document""" """Test adding a new document"""
texts = ["foo", "bar", "baz"] texts = ["foo", "bar", "baz"]
docsearch = Redis.from_texts( docsearch = Redis.from_texts(texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL)
texts, FakeEmbeddings(), redis_url="redis://localhost:6379" docsearch.add_texts(["foo"])
output = docsearch.similarity_search("foo", k=2)
assert output == TEST_RESULT
assert drop(docsearch.index_name)
def test_redis_from_existing() -> None:
"""Test adding a new document"""
texts = ["foo", "bar", "baz"]
Redis.from_texts(
texts, FakeEmbeddings(), index_name=TEST_INDEX_NAME, redis_url=TEST_REDIS_URL
)
# Test creating from an existing
docsearch2 = Redis.from_existing_index(
FakeEmbeddings(), index_name=TEST_INDEX_NAME, redis_url=TEST_REDIS_URL
)
output = docsearch2.similarity_search("foo", k=1)
assert output == TEST_SINGLE_RESULT
def test_redis_add_texts_to_existing() -> None:
"""Test adding a new document"""
# Test creating from an existing
docsearch = Redis.from_existing_index(
FakeEmbeddings(), index_name=TEST_INDEX_NAME, redis_url=TEST_REDIS_URL
) )
docsearch.add_texts(["foo"]) docsearch.add_texts(["foo"])
output = docsearch.similarity_search("foo", k=2) output = docsearch.similarity_search("foo", k=2)
assert output == [Document(page_content="foo"), Document(page_content="foo")] assert output == TEST_RESULT
assert drop(TEST_INDEX_NAME)