Harrison/redis cache (#3766)

Co-authored-by: Tyler Hutcherson <tyler.hutcherson@redis.com>
Harrison Chase 2023-04-28 20:47:18 -07:00 committed by GitHub
parent b588446bf9
commit be7a8e0824
7 changed files with 616 additions and 149 deletions

docs/ecosystem/redis.md Normal file

@ -0,0 +1,79 @@
# Redis
This page covers how to use the [Redis](https://redis.com) ecosystem within LangChain.
It is broken into two parts: installation and setup, and then references to specific Redis wrappers.
## Installation and Setup
- Install the Redis Python SDK with `pip install redis`
## Wrappers
### Cache
The Cache wrapper allows for [Redis](https://redis.io) to be used as a remote, low-latency, in-memory cache for LLM prompts and responses.
#### Standard Cache
The standard cache is the bread-and-butter Redis use case in production for both [open source](https://redis.io) and [enterprise](https://redis.com) users globally.
To import this cache:
```python
from langchain.cache import RedisCache
```
To use this cache with your LLMs:
```python
import langchain
import redis
redis_client = redis.Redis.from_url(...)
langchain.llm_cache = RedisCache(redis_client)
```
#### Semantic Cache
Semantic caching allows users to retrieve cached prompts based on semantic similarity between the user input and previously cached results. Under the hood it blends Redis as both a cache and a vectorstore.
To import this cache:
```python
from langchain.cache import RedisSemanticCache
```
To use this cache with your LLMs:
```python
import langchain
import redis
# use any embedding provider...
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
redis_url = "redis://localhost:6379"
langchain.llm_cache = RedisSemanticCache(
embedding=FakeEmbeddings(),
redis_url=redis_url
)
```
### VectorStore
The vectorstore wrapper turns Redis into a low-latency [vector database](https://redis.com/solutions/use-cases/vector-database/) for semantic search or LLM content retrieval.
To import this vectorstore:
```python
from langchain.vectorstores import Redis
```
For a more detailed walkthrough of the Redis vectorstore wrapper, see [this notebook](../modules/indexes/vectorstores/examples/redis.ipynb).
### Retriever
The Redis vector store retriever wrapper generalizes the vectorstore class to perform low-latency document retrieval. To create the retriever, simply call `.as_retriever()` on the base vectorstore class.
### Memory
Redis can be used to persist LLM conversations.
#### Vector Store Retriever Memory
For a more detailed walkthrough of the `VectorStoreRetrieverMemory` wrapper, see [this notebook](../modules/memory/types/vectorstore_retriever_memory.ipynb).
#### Chat Message History Memory
For a detailed example of using Redis to cache conversation message history, see [this notebook](../modules/memory/examples/redis_chat_message_history.ipynb).


@ -41,7 +41,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 5,
"id": "f69f6283",
"metadata": {},
"outputs": [],
@ -52,7 +52,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 6,
"id": "64005d1f",
"metadata": {},
"outputs": [
@ -60,8 +60,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 14.2 ms, sys: 4.9 ms, total: 19.1 ms\n",
"Wall time: 1.1 s\n"
"CPU times: user 26.1 ms, sys: 21.5 ms, total: 47.6 ms\n",
"Wall time: 1.68 s\n"
]
},
{
@ -70,7 +70,7 @@
"'\\n\\nWhy did the chicken cross the road?\\n\\nTo get to the other side.'"
]
},
"execution_count": 4,
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
@ -83,7 +83,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 7,
"id": "c8a1cb2b",
"metadata": {},
"outputs": [
@ -91,8 +91,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 162 µs, sys: 7 µs, total: 169 µs\n",
"Wall time: 175 µs\n"
"CPU times: user 238 µs, sys: 143 µs, total: 381 µs\n",
"Wall time: 1.76 ms\n"
]
},
{
@ -101,7 +101,7 @@
"'\\n\\nWhy did the chicken cross the road?\\n\\nTo get to the other side.'"
]
},
"execution_count": 5,
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
@ -214,9 +214,18 @@
"## Redis Cache"
]
},
{
"cell_type": "markdown",
"id": "c5c9a4d5",
"metadata": {},
"source": [
"### Standard Cache\n",
"Use [Redis](../../../../ecosystem/redis.md) to cache prompts and responses."
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 8,
"id": "39f6eb0b",
"metadata": {},
"outputs": [],
@ -225,15 +234,35 @@
"# (make sure your local Redis instance is running first before running this example)\n",
"from redis import Redis\n",
"from langchain.cache import RedisCache\n",
"\n",
"langchain.llm_cache = RedisCache(redis_=Redis())"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 9,
"id": "28920749",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 6.88 ms, sys: 8.75 ms, total: 15.6 ms\n",
"Wall time: 1.04 s\n"
]
},
{
"data": {
"text/plain": [
"'\\n\\nWhy did the chicken cross the road?\\n\\nTo get to the other side!'"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"# The first time, it is not yet in cache, so it should take longer\n",
@ -242,16 +271,124 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 14,
"id": "94bf9415",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 1.59 ms, sys: 610 µs, total: 2.2 ms\n",
"Wall time: 5.58 ms\n"
]
},
{
"data": {
"text/plain": [
"'\\n\\nWhy did the chicken cross the road?\\n\\nTo get to the other side!'"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"# The second time it is, so it goes faster\n",
"llm(\"Tell me a joke\")"
]
},
{
"cell_type": "markdown",
"id": "82be23f6",
"metadata": {},
"source": [
"### Semantic Cache\n",
"Use [Redis](../../../../ecosystem/redis.md) to cache prompts and responses and evaluate hits based on semantic similarity."
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "64df3099",
"metadata": {},
"outputs": [],
"source": [
"from langchain.embeddings import OpenAIEmbeddings\n",
"from langchain.cache import RedisSemanticCache\n",
"\n",
"\n",
"langchain.llm_cache = RedisSemanticCache(\n",
" redis_url=\"redis://localhost:6379\",\n",
" embedding=OpenAIEmbeddings()\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "8e91d3ac",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 351 ms, sys: 156 ms, total: 507 ms\n",
"Wall time: 3.37 s\n"
]
},
{
"data": {
"text/plain": [
"\"\\n\\nWhy don't scientists trust atoms?\\nBecause they make up everything.\""
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"# The first time, it is not yet in cache, so it should take longer\n",
"llm(\"Tell me a joke\")"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "df856948",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 6.25 ms, sys: 2.72 ms, total: 8.97 ms\n",
"Wall time: 262 ms\n"
]
},
{
"data": {
"text/plain": [
"\"\\n\\nWhy don't scientists trust atoms?\\nBecause they make up everything.\""
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"# The second time, while not a direct hit, the question is semantically similar to the original question,\n",
"# so it uses the cached result!\n",
"llm(\"Tell me one joke\")"
]
},
{
"cell_type": "markdown",
"id": "684eab55",


@ -1,4 +1,5 @@
"""Beta Feature: base interface for cache."""
import hashlib
import json
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, cast
@ -12,11 +13,18 @@ try:
except ImportError:
from sqlalchemy.ext.declarative import declarative_base
from langchain.embeddings.base import Embeddings
from langchain.schema import Generation
from langchain.vectorstores.redis import Redis as RedisVectorstore
RETURN_VAL_TYPE = List[Generation]
def _hash(_input: str) -> str:
"""Use a deterministic hashing approach."""
return hashlib.md5(_input.encode()).hexdigest()
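The move from the built-in `hash()` to an MD5 digest matters because Python randomizes string hashes per process unless `PYTHONHASHSEED` is fixed, so a key written by one process might not be found by another. A minimal sketch of the idea using only `hashlib`; `cache_key` is a hypothetical name standing in for `RedisCache._key`, and the parameter string is made up:

```python
import hashlib


def cache_key(prompt: str, llm_string: str) -> str:
    """Return a hex digest that is stable across processes and restarts."""
    return hashlib.md5((prompt + llm_string).encode()).hexdigest()


key = cache_key("Tell me a joke", "[('model_name', 'example')]")
# The digest depends only on the input bytes, so recomputing it gives the same key.
assert key == cache_key("Tell me a joke", "[('model_name', 'example')]")
assert len(key) == 32  # an MD5 hex digest is always 32 characters
```

MD5 is used here only as a fast, deterministic fingerprint, not for any security property.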
class BaseCache(ABC):
"""Base interface for cache."""
@ -117,6 +125,8 @@ class SQLiteCache(SQLAlchemyCache):
class RedisCache(BaseCache):
"""Cache that uses Redis as a backend."""
# TODO - implement a TTL policy in Redis
def __init__(self, redis_: Any):
"""Initialize by passing in Redis instance."""
try:
@ -130,28 +140,30 @@ class RedisCache(BaseCache):
raise ValueError("Please pass in Redis object.")
self.redis = redis_
def _key(self, prompt: str, llm_string: str, idx: int) -> str:
"""Compute key from prompt, llm_string, and idx."""
return str(hash(prompt + llm_string)) + "_" + str(idx)
def _key(self, prompt: str, llm_string: str) -> str:
"""Compute key from prompt and llm_string"""
return _hash(prompt + llm_string)
def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
"""Look up based on prompt and llm_string."""
idx = 0
generations = []
while self.redis.get(self._key(prompt, llm_string, idx)):
result = self.redis.get(self._key(prompt, llm_string, idx))
if not result:
break
elif isinstance(result, bytes):
result = result.decode()
generations.append(Generation(text=result))
idx += 1
# Read from a Redis HASH
results = self.redis.hgetall(self._key(prompt, llm_string))
if results:
for _, text in results.items():
generations.append(Generation(text=text))
return generations if generations else None
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
"""Update cache based on prompt and llm_string."""
for i, generation in enumerate(return_val):
self.redis.set(self._key(prompt, llm_string, i), generation.text)
# Write to a Redis HASH
key = self._key(prompt, llm_string)
self.redis.hset(
key,
mapping={
str(idx): generation.text for idx, generation in enumerate(return_val)
},
)
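The rewritten `update` and `lookup` keep all generations for one prompt under a single Redis HASH, with the generation index as the field name. The sketch below mimics that layout with a plain dict instead of a Redis server; `FakeHashStore` is a hypothetical stand-in that only imitates the shape of `hset(key, mapping=...)` and `hgetall(key)`:

```python
from typing import Dict, List, Optional


class FakeHashStore:
    """Hypothetical in-memory stand-in for the two Redis HASH calls used by the cache."""

    def __init__(self) -> None:
        self._data: Dict[str, Dict[str, str]] = {}

    def hset(self, key: str, mapping: Dict[str, str]) -> None:
        # Like Redis HSET, merge the given fields into the hash stored at key.
        self._data.setdefault(key, {}).update(mapping)

    def hgetall(self, key: str) -> Dict[str, str]:
        # Like Redis HGETALL, a missing key yields an empty mapping.
        return dict(self._data.get(key, {}))


def update(store: FakeHashStore, key: str, texts: List[str]) -> None:
    store.hset(key, mapping={str(idx): text for idx, text in enumerate(texts)})


def lookup(store: FakeHashStore, key: str) -> Optional[List[str]]:
    results = store.hgetall(key)
    return list(results.values()) if results else None


store = FakeHashStore()
update(store, "k1", ["first answer", "second answer"])
print(lookup(store, "k1"))       # ['first answer', 'second answer']
print(lookup(store, "missing"))  # None
```

One `HGETALL` round trip replaces the earlier loop of `GET` calls. Because `HSET` merges fields, a later update with fewer generations would leave older fields in place; whether that matters depends on how the cache is used.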
def clear(self, **kwargs: Any) -> None:
"""Clear cache. If `asynchronous` is True, flush asynchronously."""
@ -159,6 +171,106 @@ class RedisCache(BaseCache):
self.redis.flushdb(asynchronous=asynchronous, **kwargs)
class RedisSemanticCache(BaseCache):
"""Cache that uses Redis as a vector-store backend."""
# TODO - implement a TTL policy in Redis
def __init__(
self, redis_url: str, embedding: Embeddings, score_threshold: float = 0.2
):
"""Initialize the semantic cache with a Redis URL and an embedding provider.
Args:
redis_url (str): URL to connect to Redis.
embedding (Embeddings): Embedding provider for semantic encoding and search.
score_threshold (float, 0.2): Maximum vector distance for a cached prompt to count as a hit.
Example:
.. code-block:: python
import langchain
from langchain.cache import RedisSemanticCache
from langchain.embeddings import OpenAIEmbeddings
langchain.llm_cache = RedisSemanticCache(
redis_url="redis://localhost:6379",
embedding=OpenAIEmbeddings()
)
"""
self._cache_dict: Dict[str, RedisVectorstore] = {}
self.redis_url = redis_url
self.embedding = embedding
self.score_threshold = score_threshold
def _index_name(self, llm_string: str) -> str:
hashed_index = _hash(llm_string)
return f"cache:{hashed_index}"
def _get_llm_cache(self, llm_string: str) -> RedisVectorstore:
index_name = self._index_name(llm_string)
# return vectorstore client for the specific llm string
if index_name in self._cache_dict:
return self._cache_dict[index_name]
# create new vectorstore client for the specific llm string
try:
self._cache_dict[index_name] = RedisVectorstore.from_existing_index(
embedding=self.embedding,
index_name=index_name,
redis_url=self.redis_url,
)
except ValueError:
redis = RedisVectorstore(
embedding_function=self.embedding.embed_query,
index_name=index_name,
redis_url=self.redis_url,
)
_embedding = self.embedding.embed_query(text="test")
redis._create_index(dim=len(_embedding))
self._cache_dict[index_name] = redis
return self._cache_dict[index_name]
def clear(self, **kwargs: Any) -> None:
"""Clear semantic cache for a given llm_string."""
index_name = self._index_name(kwargs["llm_string"])
if index_name in self._cache_dict:
self._cache_dict[index_name].drop_index(
index_name=index_name, delete_documents=True, redis_url=self.redis_url
)
del self._cache_dict[index_name]
def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
"""Look up based on prompt and llm_string."""
llm_cache = self._get_llm_cache(llm_string)
generations = []
# Read from a Hash
results = llm_cache.similarity_search_limit_score(
query=prompt,
k=1,
score_threshold=self.score_threshold,
)
if results:
for document in results:
for text in document.metadata["return_val"]:
generations.append(Generation(text=text))
return generations if generations else None
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
"""Update cache based on prompt and llm_string."""
llm_cache = self._get_llm_cache(llm_string)
# Write to vectorstore
metadata = {
"llm_string": llm_string,
"prompt": prompt,
"return_val": [generation.text for generation in return_val],
}
llm_cache.add_texts(texts=[prompt], metadatas=[metadata])
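The semantic `lookup` accepts a cached entry only when its vector distance to the new prompt falls below `score_threshold`. A rough illustration with cosine distance in plain Python; the vectors and the `is_cache_hit` helper are made up for this example, and in the actual change the distance is computed by the Redis search index, not in Python:

```python
import math
from typing import Sequence


def cosine_distance(a: Sequence[float], b: Sequence[float]) -> float:
    """Return 1 - cosine similarity; 0 means the vectors point the same way."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / (norm_a * norm_b)


def is_cache_hit(
    query_vec: Sequence[float],
    cached_vec: Sequence[float],
    score_threshold: float = 0.2,
) -> bool:
    # Mirrors a `score < score_threshold` filter: smaller distance is a closer match.
    return cosine_distance(query_vec, cached_vec) < score_threshold


cached = [1.0, 0.0]
print(is_cache_hit([0.9, 0.1], cached))  # True: nearly the same direction
print(is_cache_hit([0.0, 1.0], cached))  # False: orthogonal, distance 1.0
```

A lower threshold makes the cache stricter (fewer, closer hits); a higher one trades accuracy for hit rate.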
class GPTCache(BaseCache):
"""Cache that uses GPTCache as a backend."""


@ -13,11 +13,13 @@ from langchain.vectorstores.myscale import MyScale, MyScaleSettings
from langchain.vectorstores.opensearch_vector_search import OpenSearchVectorSearch
from langchain.vectorstores.pinecone import Pinecone
from langchain.vectorstores.qdrant import Qdrant
from langchain.vectorstores.redis import Redis
from langchain.vectorstores.supabase import SupabaseVectorStore
from langchain.vectorstores.weaviate import Weaviate
from langchain.vectorstores.zilliz import Zilliz
__all__ = [
"Redis",
"ElasticVectorSearch",
"FAISS",
"VectorStore",


@ -4,11 +4,21 @@ from __future__ import annotations
import json
import logging
import uuid
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, Type
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
List,
Mapping,
Optional,
Tuple,
Type,
)
import numpy as np
from pydantic import BaseModel, root_validator
from redis.client import Redis as RedisType
from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
@ -18,29 +28,36 @@ from langchain.vectorstores.base import VectorStore
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from redis.client import Redis as RedisType
from redis.commands.search.query import Query
# required modules
REDIS_REQUIRED_MODULES = [
{"name": "search", "ver": 20400},
{"name": "searchlight", "ver": 20400},
]
def _check_redis_module_exist(client: RedisType, modules: List[dict]) -> None:
def _check_redis_module_exist(client: RedisType, required_modules: List[dict]) -> None:
"""Check if the correct Redis modules are installed."""
installed_modules = client.module_list()
installed_modules = {
module[b"name"].decode("utf-8"): module for module in installed_modules
}
for module in modules:
if module["name"] not in installed_modules or int(
for module in required_modules:
if module["name"] in installed_modules and int(
installed_modules[module["name"]][b"ver"]
) < int(module["ver"]):
error_message = (
"You must add the RediSearch (>= 2.4) module from Redis Stack. "
"Please refer to Redis Stack docs: https://redis.io/docs/stack/"
)
logging.error(error_message)
raise ValueError(error_message)
) >= int(module["ver"]):
return
# otherwise raise error
error_message = (
"You must add the RediSearch (>= 2.4) module from Redis Stack. "
"Please refer to Redis Stack docs: https://redis.io/docs/stack/"
)
logging.error(error_message)
raise ValueError(error_message)
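The rewritten check passes as soon as any one of the required modules (`search` or `searchlight`) is installed at a sufficient version, and raises only if none qualifies. A simplified sketch over plain dicts; `has_required_module` is a hypothetical helper, and the installed modules are given here as a name-to-version mapping instead of the byte-keyed dicts the diff decodes from `module_list()`:

```python
from typing import Dict, List

REQUIRED = [
    {"name": "search", "ver": 20400},
    {"name": "searchlight", "ver": 20400},
]


def has_required_module(installed: Dict[str, int], required: List[dict]) -> bool:
    """Return True if at least one required module is present at a new enough version."""
    return any(
        module["name"] in installed and installed[module["name"]] >= module["ver"]
        for module in required
    )


print(has_required_module({"search": 20600}, REQUIRED))       # True
print(has_required_module({"searchlight": 20400}, REQUIRED))  # True
print(has_required_module({"search": 20200}, REQUIRED))       # False: too old
print(has_required_module({"ReJSON": 20400}, REQUIRED))       # False: wrong module
```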
def _check_index_exists(client: RedisType, index_name: str) -> bool:
@ -65,6 +82,24 @@ def _redis_prefix(index_name: str) -> str:
class Redis(VectorStore):
"""Wrapper around Redis vector database.
To use, you should have the ``redis`` python package installed.
Example:
.. code-block:: python
from langchain.vectorstores import Redis
from langchain.embeddings import OpenAIEmbeddings
embeddings = OpenAIEmbeddings()
vectorstore = Redis(
redis_url="redis://username:password@localhost:6379",
index_name="my-index",
embedding_function=embeddings.embed_query,
)
"""
def __init__(
self,
redis_url: str,
@ -99,33 +134,92 @@ class Redis(VectorStore):
self.metadata_key = metadata_key
self.vector_key = vector_key
def _create_index(self, dim: int = 1536) -> None:
try:
from redis.commands.search.field import TextField, VectorField
from redis.commands.search.indexDefinition import IndexDefinition, IndexType
except ImportError:
raise ValueError(
"Could not import redis python package. "
"Please install it with `pip install redis`."
)
# Check if index exists
if not _check_index_exists(self.client, self.index_name):
# Constants
distance_metric = (
"COSINE" # distance metric for the vectors (ex. COSINE, IP, L2)
)
schema = (
TextField(name=self.content_key),
TextField(name=self.metadata_key),
VectorField(
self.vector_key,
"FLAT",
{
"TYPE": "FLOAT32",
"DIM": dim,
"DISTANCE_METRIC": distance_metric,
},
),
)
prefix = _redis_prefix(self.index_name)
# Create Redis Index
self.client.ft(self.index_name).create_index(
fields=schema,
definition=IndexDefinition(prefix=[prefix], index_type=IndexType.HASH),
)
def add_texts(
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
embeddings: Optional[List[List[float]]] = None,
keys: Optional[List[str]] = None,
batch_size: int = 1000,
**kwargs: Any,
) -> List[str]:
"""Add texts data to an existing index."""
prefix = _redis_prefix(self.index_name)
keys = kwargs.get("keys")
"""Add more texts to the vectorstore.
Args:
texts (Iterable[str]): Iterable of strings/text to add to the vectorstore.
metadatas (Optional[List[dict]], optional): Optional list of metadatas.
Defaults to None.
embeddings (Optional[List[List[float]]], optional): Optional pre-generated
embeddings. Defaults to None.
keys (Optional[List[str]], optional): Optional key values to use as ids.
Defaults to None.
batch_size (int, optional): Batch size to use for writes. Defaults to 1000.
Returns:
List[str]: List of ids added to the vectorstore
"""
ids = []
prefix = _redis_prefix(self.index_name)
# Write data to redis
pipeline = self.client.pipeline(transaction=False)
for i, text in enumerate(texts):
# Use provided key otherwise use default key
# Use provided values by default or fallback
key = keys[i] if keys else _redis_key(prefix)
metadata = metadatas[i] if metadatas else {}
embedding = embeddings[i] if embeddings else self.embedding_function(text)
pipeline.hset(
key,
mapping={
self.content_key: text,
self.vector_key: np.array(
self.embedding_function(text), dtype=np.float32
).tobytes(),
self.vector_key: np.array(embedding, dtype=np.float32).tobytes(),
self.metadata_key: json.dumps(metadata),
},
)
ids.append(key)
# Write batch
if i % batch_size == 0:
pipeline.execute()
# Cleanup final batch
pipeline.execute()
return ids
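`add_texts` queues writes on a pipeline, flushes whenever the loop index is a multiple of `batch_size`, and flushes once more at the end for the remainder. The sketch below reproduces that control flow with a hypothetical `FakePipeline` that only records how many writes each flush carried; no Redis connection is involved:

```python
from typing import List


class FakePipeline:
    """Hypothetical stand-in for a Redis pipeline that records flush sizes."""

    def __init__(self) -> None:
        self.queued: List[str] = []
        self.flushed: List[int] = []

    def hset(self, key: str) -> None:
        self.queued.append(key)

    def execute(self) -> None:
        self.flushed.append(len(self.queued))
        self.queued.clear()


def write_in_batches(keys: List[str], batch_size: int) -> List[int]:
    pipeline = FakePipeline()
    for i, key in enumerate(keys):
        pipeline.hset(key)
        # Same condition as in the diff: flush whenever the index is a multiple of batch_size.
        if i % batch_size == 0:
            pipeline.execute()
    pipeline.execute()  # flush whatever is left
    return pipeline.flushed


print(write_in_batches(["a", "b", "c", "d", "e"], batch_size=2))  # [1, 2, 2, 0]
```

Because index 0 is a multiple of every batch size, the first flush carries a single item; the total number of writes is unchanged.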
@ -170,9 +264,30 @@ class Redis(VectorStore):
"""
docs_and_scores = self.similarity_search_with_score(query, k=k)
return [doc for doc, score in docs_and_scores if score < score_threshold]
def _prepare_query(self, k: int) -> Query:
try:
from redis.commands.search.query import Query
except ImportError:
raise ValueError(
"Could not import redis python package. "
"Please install it with `pip install redis`."
)
# Prepare the Query
hybrid_fields = "*"
base_query = (
f"{hybrid_fields}=>[KNN {k} @{self.vector_key} $vector AS vector_score]"
)
return_fields = [self.metadata_key, self.content_key, "vector_score"]
return (
Query(base_query)
.return_fields(*return_fields)
.sort_by("vector_score")
.paging(0, k)
.dialect(2)
)
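`_prepare_query` wraps a KNN query string in a `Query` object. The string itself can be reproduced with plain formatting, which makes its parts easier to see; `knn_query_string` is a hypothetical helper, and the default `vector_key` is taken from the `content_vector` default that appears elsewhere in this diff:

```python
def knn_query_string(k: int, vector_key: str = "content_vector") -> str:
    """Build the query text: match every document, then take the k nearest vectors."""
    hybrid_fields = "*"  # no pre-filter; all indexed documents are candidates
    return f"{hybrid_fields}=>[KNN {k} @{vector_key} $vector AS vector_score]"


print(knn_query_string(4))
# *=>[KNN 4 @content_vector $vector AS vector_score]
```

`$vector` is a placeholder that is filled at search time from the parameters dict with the query embedding's float32 bytes, and the distance is exposed under the alias `vector_score`, which the query then sorts by.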
def similarity_search_with_score(
self, query: str, k: int = 4
) -> List[Tuple[Document, float]]:
@ -185,40 +300,22 @@ class Redis(VectorStore):
Returns:
List of Documents most similar to the query and score for each
"""
try:
from redis.commands.search.query import Query
except ImportError:
raise ValueError(
"Could not import redis python package. "
"Please install it with `pip install redis`."
)
# Creates embedding vector from user query
embedding = self.embedding_function(query)
# Prepare the Query
return_fields = [self.metadata_key, self.content_key, "vector_score"]
vector_field = self.vector_key
hybrid_fields = "*"
base_query = (
f"{hybrid_fields}=>[KNN {k} @{vector_field} $vector AS vector_score]"
)
redis_query = (
Query(base_query)
.return_fields(*return_fields)
.sort_by("vector_score")
.paging(0, k)
.dialect(2)
)
# Creates Redis query
redis_query = self._prepare_query(k)
params_dict: Mapping[str, str] = {
"vector": np.array(embedding) # type: ignore
.astype(dtype=np.float32)
.tobytes()
}
# perform vector search
# Perform vector search
results = self.client.ft(self.index_name).search(redis_query, params_dict)
# Prepare document results
docs = [
(
Document(
@ -243,15 +340,15 @@ class Redis(VectorStore):
vector_key: str = "content_vector",
**kwargs: Any,
) -> Redis:
"""Construct RediSearch wrapper from raw documents.
"""Create a Redis vectorstore from raw documents.
This is a user-friendly interface that:
1. Embeds documents.
2. Creates a new index for the embeddings in the RediSearch instance.
3. Adds the documents to the newly created RediSearch index.
2. Creates a new index for the embeddings in Redis.
3. Adds the documents to the newly created Redis index.
This is intended to be a quick way to get started.
Example:
.. code-block:: python
from langchain import RediSearch
from langchain.vectorstores import Redis
from langchain.embeddings import OpenAIEmbeddings
embeddings = OpenAIEmbeddings()
redisearch = Redis.from_texts(
@ -261,84 +358,35 @@ class Redis(VectorStore):
)
"""
redis_url = get_from_dict_or_env(kwargs, "redis_url", "REDIS_URL")
try:
import redis
from redis.commands.search.field import TextField, VectorField
from redis.commands.search.indexDefinition import IndexDefinition, IndexType
except ImportError:
raise ValueError(
"Could not import redis python package. "
"Please install it with `pip install redis`."
)
try:
# We need to first remove redis_url from kwargs,
# otherwise passing it to Redis will result in an error.
if "redis_url" in kwargs:
kwargs.pop("redis_url")
client = redis.from_url(url=redis_url, **kwargs)
# check if redis has redisearch module installed
_check_redis_module_exist(client, REDIS_REQUIRED_MODULES)
except ValueError as e:
raise ValueError(f"Redis failed to connect: {e}")
# Create embeddings over documents
embeddings = embedding.embed_documents(texts)
if "redis_url" in kwargs:
kwargs.pop("redis_url")
# Name of the search index if not given
if not index_name:
index_name = uuid.uuid4().hex
prefix = _redis_prefix(index_name) # prefix for the document keys
# Check if index exists
if not _check_index_exists(client, index_name):
# Constants
dim = len(embeddings[0])
distance_metric = (
"COSINE" # distance metric for the vectors (ex. COSINE, IP, L2)
)
schema = (
TextField(name=content_key),
TextField(name=metadata_key),
VectorField(
vector_key,
"FLAT",
{
"TYPE": "FLOAT32",
"DIM": dim,
"DISTANCE_METRIC": distance_metric,
},
),
)
# Create Redis Index
client.ft(index_name).create_index(
fields=schema,
definition=IndexDefinition(prefix=[prefix], index_type=IndexType.HASH),
)
# Write data to Redis
pipeline = client.pipeline(transaction=False)
for i, text in enumerate(texts):
key = _redis_key(prefix)
metadata = metadatas[i] if metadatas else {}
pipeline.hset(
key,
mapping={
content_key: text,
vector_key: np.array(embeddings[i], dtype=np.float32).tobytes(),
metadata_key: json.dumps(metadata),
},
)
pipeline.execute()
return cls(
redis_url,
index_name,
embedding.embed_query,
# Create instance
instance = cls(
redis_url=redis_url,
index_name=index_name,
embedding_function=embedding.embed_query,
content_key=content_key,
metadata_key=metadata_key,
vector_key=vector_key,
**kwargs,
)
# Create embeddings over documents
embeddings = embedding.embed_documents(texts)
# Create the search index
instance._create_index(dim=len(embeddings[0]))
# Add data to Redis
instance.add_texts(texts, metadatas, embeddings)
return instance
@staticmethod
def drop_index(
index_name: str,


@ -0,0 +1,55 @@
"""Test Redis cache functionality."""
import redis
import langchain
from langchain.cache import RedisCache, RedisSemanticCache
from langchain.schema import Generation, LLMResult
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
from tests.unit_tests.llms.fake_llm import FakeLLM
REDIS_TEST_URL = "redis://localhost:6379"
def test_redis_cache() -> None:
langchain.llm_cache = RedisCache(redis_=redis.Redis.from_url(REDIS_TEST_URL))
llm = FakeLLM()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
output = llm.generate(["foo"])
print(output)
expected_output = LLMResult(
generations=[[Generation(text="fizz")]],
llm_output={},
)
print(expected_output)
assert output == expected_output
langchain.llm_cache.redis.flushall()
def test_redis_semantic_cache() -> None:
langchain.llm_cache = RedisSemanticCache(
embedding=FakeEmbeddings(), redis_url=REDIS_TEST_URL, score_threshold=0.1
)
llm = FakeLLM()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
output = llm.generate(
["bar"]
) # foo and bar will have the same embedding produced by FakeEmbeddings
expected_output = LLMResult(
generations=[[Generation(text="fizz")]],
llm_output={},
)
assert output == expected_output
# clear the cache
langchain.llm_cache.clear(llm_string=llm_string)
output = llm.generate(
["bar"]
) # foo and bar will have the same embedding produced by FakeEmbeddings
# expect different output now without cached result
assert output != expected_output
langchain.llm_cache.clear(llm_string=llm_string)
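Both tests build `llm_string` by sorting the LLM's parameters before stringifying them, so the resulting cache key does not depend on dict ordering. A small standalone sketch of that step; `make_llm_string` is a hypothetical name and the parameter values are invented:

```python
from typing import Any, Dict


def make_llm_string(params: Dict[str, Any]) -> str:
    """Serialize parameters in sorted key order so equal settings give equal strings."""
    return str(sorted([(k, v) for k, v in params.items()]))


a = make_llm_string({"temperature": 0.0, "model": "fake", "stop": None})
b = make_llm_string({"stop": None, "model": "fake", "temperature": 0.0})
print(a)       # [('model', 'fake'), ('stop', None), ('temperature', 0.0)]
print(a == b)  # True
```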


@ -1,26 +1,60 @@
"""Test Redis functionality."""
from langchain.docstore.document import Document
from langchain.vectorstores.redis import Redis
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
TEST_INDEX_NAME = "test"
TEST_REDIS_URL = "redis://localhost:6379"
TEST_SINGLE_RESULT = [Document(page_content="foo")]
TEST_RESULT = [Document(page_content="foo"), Document(page_content="foo")]
def drop(index_name: str) -> bool:
return Redis.drop_index(
index_name=index_name, delete_documents=True, redis_url=TEST_REDIS_URL
)
def test_redis() -> None:
"""Test end to end construction and search."""
texts = ["foo", "bar", "baz"]
docsearch = Redis.from_texts(
texts, FakeEmbeddings(), redis_url="redis://localhost:6379"
)
docsearch = Redis.from_texts(texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL)
output = docsearch.similarity_search("foo", k=1)
assert output == [Document(page_content="foo")]
assert output == TEST_SINGLE_RESULT
assert drop(docsearch.index_name)
def test_redis_new_vector() -> None:
"""Test adding a new document"""
texts = ["foo", "bar", "baz"]
docsearch = Redis.from_texts(
texts, FakeEmbeddings(), redis_url="redis://localhost:6379"
docsearch = Redis.from_texts(texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL)
docsearch.add_texts(["foo"])
output = docsearch.similarity_search("foo", k=2)
assert output == TEST_RESULT
assert drop(docsearch.index_name)
def test_redis_from_existing() -> None:
"""Test creating a vectorstore from an existing index."""
texts = ["foo", "bar", "baz"]
Redis.from_texts(
texts, FakeEmbeddings(), index_name=TEST_INDEX_NAME, redis_url=TEST_REDIS_URL
)
# Test creating from an existing index
docsearch2 = Redis.from_existing_index(
FakeEmbeddings(), index_name=TEST_INDEX_NAME, redis_url=TEST_REDIS_URL
)
output = docsearch2.similarity_search("foo", k=1)
assert output == TEST_SINGLE_RESULT
def test_redis_add_texts_to_existing() -> None:
"""Test adding a new document"""
# Test creating from an existing index
docsearch = Redis.from_existing_index(
FakeEmbeddings(), index_name=TEST_INDEX_NAME, redis_url=TEST_REDIS_URL
)
docsearch.add_texts(["foo"])
output = docsearch.similarity_search("foo", k=2)
assert output == [Document(page_content="foo"), Document(page_content="foo")]
assert output == TEST_RESULT
assert drop(TEST_INDEX_NAME)