From be7a8e08241cdfbd713f2e4fec35977bed49d715 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Fri, 28 Apr 2023 20:47:18 -0700 Subject: [PATCH] Harrison/redis cache (#3766) Co-authored-by: Tyler Hutcherson --- docs/ecosystem/redis.md | 79 +++++ .../models/llms/examples/llm_caching.ipynb | 165 ++++++++++- langchain/cache.py | 140 ++++++++- langchain/vectorstores/__init__.py | 2 + langchain/vectorstores/redis.py | 274 ++++++++++-------- .../cache/test_redis_cache.py | 55 ++++ .../vectorstores/test_redis.py | 50 +++- 7 files changed, 616 insertions(+), 149 deletions(-) create mode 100644 docs/ecosystem/redis.md create mode 100644 tests/integration_tests/cache/test_redis_cache.py diff --git a/docs/ecosystem/redis.md b/docs/ecosystem/redis.md new file mode 100644 index 00000000..8a313707 --- /dev/null +++ b/docs/ecosystem/redis.md @@ -0,0 +1,79 @@ +# Redis + +This page covers how to use the [Redis](https://redis.com) ecosystem within LangChain. +It is broken into two parts: installation and setup, and then references to specific Redis wrappers. + +## Installation and Setup +- Install the Redis Python SDK with `pip install redis` + +## Wrappers + +### Cache + +The Cache wrapper allows for [Redis](https://redis.io) to be used as a remote, low-latency, in-memory cache for LLM prompts and responses. + +#### Standard Cache +The standard cache is the bread-and-butter Redis use case in production for both [open source](https://redis.io) and [enterprise](https://redis.com) users globally. + +To import this cache: +```python +from langchain.cache import RedisCache +``` + +To use this cache with your LLMs: +```python +import langchain +import redis + +redis_client = redis.Redis.from_url(...) +langchain.llm_cache = RedisCache(redis_client) +``` + +#### Semantic Cache +Semantic caching allows users to retrieve cached prompts based on semantic similarity between the user input and previously cached results. Under the hood, it uses Redis as both a cache and a vectorstore. + +To import this cache: +```python +from langchain.cache import RedisSemanticCache +``` + +To use this cache with your LLMs: +```python +import langchain +import redis + +# use any embedding provider... +from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings + +redis_url = "redis://localhost:6379" + +langchain.llm_cache = RedisSemanticCache( + embedding=FakeEmbeddings(), + redis_url=redis_url +) +``` + +### VectorStore + +The vectorstore wrapper turns Redis into a low-latency [vector database](https://redis.com/solutions/use-cases/vector-database/) for semantic search or LLM content retrieval. + +To import this vectorstore: +```python +from langchain.vectorstores import Redis +``` + +For a more detailed walkthrough of the Redis vectorstore wrapper, see [this notebook](../modules/indexes/vectorstores/examples/redis.ipynb). + +### Retriever + +The Redis vector store retriever wrapper generalizes the vectorstore class to perform low-latency document retrieval. To create the retriever, simply call `.as_retriever()` on the base vectorstore class. + +### Memory +Redis can be used to persist LLM conversations. + +#### Vector Store Retriever Memory + +For a more detailed walkthrough of the `VectorStoreRetrieverMemory` wrapper, see [this notebook](../modules/memory/types/vectorstore_retriever_memory.ipynb). + +#### Chat Message History Memory +For a detailed example of using Redis to cache conversation message history, see [this notebook](../modules/memory/examples/redis_chat_message_history.ipynb).
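+
+For example, a minimal sketch of persisting a conversation (assuming a local Redis instance at the default URL and the `RedisChatMessageHistory` class exported from `langchain.memory`; the session id below is an illustrative value):
+```python
+from langchain.memory import RedisChatMessageHistory
+
+# "my-session" and the URL are placeholder values for this sketch
+history = RedisChatMessageHistory(session_id="my-session", url="redis://localhost:6379/0")
+
+history.add_user_message("hi!")
+history.add_ai_message("whats up?")
+
+# messages are stored in Redis, so they survive process restarts
+history.messages
+```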
diff --git a/docs/modules/models/llms/examples/llm_caching.ipynb b/docs/modules/models/llms/examples/llm_caching.ipynb index a37adaad..8655c000 100644 --- a/docs/modules/models/llms/examples/llm_caching.ipynb +++ b/docs/modules/models/llms/examples/llm_caching.ipynb @@ -41,7 +41,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 5, "id": "f69f6283", "metadata": {}, "outputs": [], @@ -52,7 +52,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 6, "id": "64005d1f", "metadata": {}, "outputs": [ @@ -60,8 +60,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 14.2 ms, sys: 4.9 ms, total: 19.1 ms\n", - "Wall time: 1.1 s\n" + "CPU times: user 26.1 ms, sys: 21.5 ms, total: 47.6 ms\n", + "Wall time: 1.68 s\n" ] }, { @@ -70,7 +70,7 @@ "'\\n\\nWhy did the chicken cross the road?\\n\\nTo get to the other side.'" ] }, - "execution_count": 4, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -83,7 +83,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 7, "id": "c8a1cb2b", "metadata": {}, "outputs": [ @@ -91,8 +91,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 162 µs, sys: 7 µs, total: 169 µs\n", - "Wall time: 175 µs\n" + "CPU times: user 238 µs, sys: 143 µs, total: 381 µs\n", + "Wall time: 1.76 ms\n" ] }, { @@ -101,7 +101,7 @@ "'\\n\\nWhy did the chicken cross the road?\\n\\nTo get to the other side.'" ] }, - "execution_count": 5, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } @@ -214,9 +214,18 @@ "## Redis Cache" ] }, + { + "cell_type": "markdown", + "id": "c5c9a4d5", + "metadata": {}, + "source": [ + "### Standard Cache\n", + "Use [Redis](../../../../ecosystem/redis.md) to cache prompts and responses." 
+ ] + }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "id": "39f6eb0b", "metadata": {}, "outputs": [], @@ -225,15 +234,35 @@ "# (make sure your local Redis instance is running first before running this example)\n", "from redis import Redis\n", "from langchain.cache import RedisCache\n", + "\n", "langchain.llm_cache = RedisCache(redis_=Redis())" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "id": "28920749", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 6.88 ms, sys: 8.75 ms, total: 15.6 ms\n", + "Wall time: 1.04 s\n" + ] + }, + { + "data": { + "text/plain": [ + "'\\n\\nWhy did the chicken cross the road?\\n\\nTo get to the other side!'" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "%%time\n", "# The first time, it is not yet in cache, so it should take longer\n", @@ -242,16 +271,124 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 14, "id": "94bf9415", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 1.59 ms, sys: 610 µs, total: 2.2 ms\n", + "Wall time: 5.58 ms\n" + ] + }, + { + "data": { + "text/plain": [ + "'\\n\\nWhy did the chicken cross the road?\\n\\nTo get to the other side!'" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "%%time\n", "# The second time it is, so it goes faster\n", "llm(\"Tell me a joke\")" ] }, + { + "cell_type": "markdown", + "id": "82be23f6", + "metadata": {}, + "source": [ + "### Semantic Cache\n", + "Use [Redis](../../../../ecosystem/redis.md) to cache prompts and responses and evaluate hits based on semantic similarity." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "64df3099", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.embeddings import OpenAIEmbeddings\n", + "from langchain.cache import RedisSemanticCache\n", + "\n", + "\n", + "langchain.llm_cache = RedisSemanticCache(\n", + " redis_url=\"redis://localhost:6379\",\n", + " embedding=OpenAIEmbeddings()\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "8e91d3ac", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 351 ms, sys: 156 ms, total: 507 ms\n", + "Wall time: 3.37 s\n" + ] + }, + { + "data": { + "text/plain": [ + "\"\\n\\nWhy don't scientists trust atoms?\\nBecause they make up everything.\"" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%%time\n", + "# The first time, it is not yet in cache, so it should take longer\n", + "llm(\"Tell me a joke\")" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "df856948", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 6.25 ms, sys: 2.72 ms, total: 8.97 ms\n", + "Wall time: 262 ms\n" + ] + }, + { + "data": { + "text/plain": [ + "\"\\n\\nWhy don't scientists trust atoms?\\nBecause they make up everything.\"" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%%time\n", + "# The second time, while not a direct hit, the question is semantically similar to the original question,\n", + "# so it uses the cached result!\n", + "llm(\"Tell me one joke\")" + ] + }, { "cell_type": "markdown", "id": "684eab55", diff --git a/langchain/cache.py b/langchain/cache.py index 975dd6fa..1a33b678 100644 --- a/langchain/cache.py +++ b/langchain/cache.py @@ -1,4 +1,5 @@ """Beta Feature: base interface for cache.""" +import hashlib import json from abc import ABC, abstractmethod from typing import Any, Callable, Dict, List, Optional, Tuple, Type, cast @@ -12,11 +13,18 @@ try: except ImportError: from sqlalchemy.ext.declarative import declarative_base +from langchain.embeddings.base import Embeddings from langchain.schema import Generation +from langchain.vectorstores.redis import Redis as RedisVectorstore RETURN_VAL_TYPE = List[Generation] +def _hash(_input: str) -> str: + """Use a deterministic hashing approach.""" + return hashlib.md5(_input.encode()).hexdigest() + + class BaseCache(ABC): """Base interface for cache.""" @@ -117,6 +125,8 @@ class SQLiteCache(SQLAlchemyCache): class RedisCache(BaseCache): """Cache that uses Redis as a backend.""" + # TODO - implement a TTL policy in Redis + def __init__(self, redis_: Any): """Initialize by passing in Redis instance.""" try: @@ -130,28 +140,30 @@ class RedisCache(BaseCache): raise ValueError("Please pass in Redis object.") self.redis = redis_ - def _key(self, prompt: str, llm_string: str, idx: int) -> str: - """Compute key from prompt, llm_string, and idx.""" - return str(hash(prompt + llm_string)) + "_" + str(idx) + def _key(self, prompt: str, llm_string: str) -> str: + """Compute key from prompt and llm_string""" + return _hash(prompt + llm_string) def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: """Look up based on prompt and llm_string.""" - idx = 0 generations = [] - while self.redis.get(self._key(prompt, llm_string, idx)): - result = self.redis.get(self._key(prompt, llm_string, idx)) - if not result: - break 
- elif isinstance(result, bytes): - result = result.decode() - generations.append(Generation(text=result)) - idx += 1 + # Read from a Redis HASH + results = self.redis.hgetall(self._key(prompt, llm_string)) + if results: + for _, text in results.items(): + generations.append(Generation(text=text)) return generations if generations else None def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None: """Update cache based on prompt and llm_string.""" - for i, generation in enumerate(return_val): - self.redis.set(self._key(prompt, llm_string, i), generation.text) + # Write to a Redis HASH + key = self._key(prompt, llm_string) + self.redis.hset( + key, + mapping={ + str(idx): generation.text for idx, generation in enumerate(return_val) + }, + ) def clear(self, **kwargs: Any) -> None: """Clear cache. If `asynchronous` is True, flush asynchronously.""" @@ -159,6 +171,106 @@ self.redis.flushdb(asynchronous=asynchronous, **kwargs) +class RedisSemanticCache(BaseCache): + """Cache that uses Redis as a vector-store backend.""" + + # TODO - implement a TTL policy in Redis + + def __init__( + self, redis_url: str, embedding: Embeddings, score_threshold: float = 0.2 + ): + """Initialize by passing in a Redis connection URL and an embedding provider. + + Args: + redis_url (str): URL to connect to Redis. + embedding (Embeddings): Embedding provider for semantic encoding and search. + score_threshold (float, 0.2): Maximum vector distance between the prompt and a cached entry for the entry to count as a hit. + + Example: + .. code-block:: python + import langchain + + from langchain.cache import RedisSemanticCache + from langchain.embeddings import OpenAIEmbeddings + + langchain.llm_cache = RedisSemanticCache( + redis_url="redis://localhost:6379", + embedding=OpenAIEmbeddings() + ) + + """ + self._cache_dict: Dict[str, RedisVectorstore] = {} + self.redis_url = redis_url + self.embedding = embedding + self.score_threshold = score_threshold + + def _index_name(self, llm_string: str) -> str: + hashed_index = _hash(llm_string) + return f"cache:{hashed_index}" + + def _get_llm_cache(self, llm_string: str) -> RedisVectorstore: + index_name = self._index_name(llm_string) + + # return vectorstore client for the specific llm string + if index_name in self._cache_dict: + return self._cache_dict[index_name] + + # create new vectorstore client for the specific llm string + try: + self._cache_dict[index_name] = RedisVectorstore.from_existing_index( + embedding=self.embedding, + index_name=index_name, + redis_url=self.redis_url, + ) + except ValueError: + redis = RedisVectorstore( + embedding_function=self.embedding.embed_query, + index_name=index_name, + redis_url=self.redis_url, + ) + _embedding = self.embedding.embed_query(text="test") + redis._create_index(dim=len(_embedding)) + self._cache_dict[index_name] = redis + + return self._cache_dict[index_name] + + def clear(self, **kwargs: Any) -> None: + """Clear semantic cache for a given llm_string.""" + index_name = self._index_name(kwargs["llm_string"]) + if index_name in self._cache_dict: + self._cache_dict[index_name].drop_index( + index_name=index_name, delete_documents=True, redis_url=self.redis_url + ) + del self._cache_dict[index_name] + + def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: + """Look up based on prompt and llm_string.""" + llm_cache = self._get_llm_cache(llm_string) + generations = [] + # Search the cache vectorstore for semantically similar prompts + results = llm_cache.similarity_search_limit_score( + query=prompt, + k=1, + score_threshold=self.score_threshold, + ) + if results: + for document in results: + for text in 
document.metadata["return_val"]: + generations.append(Generation(text=text)) + return generations if generations else None + + def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None: + """Update cache based on prompt and llm_string.""" + llm_cache = self._get_llm_cache(llm_string) + # Write to vectorstore + metadata = { + "llm_string": llm_string, + "prompt": prompt, + "return_val": [generation.text for generation in return_val], + } + llm_cache.add_texts(texts=[prompt], metadatas=[metadata]) + + class GPTCache(BaseCache): """Cache that uses GPTCache as a backend.""" diff --git a/langchain/vectorstores/__init__.py b/langchain/vectorstores/__init__.py index 51ac88b5..973d4eda 100644 --- a/langchain/vectorstores/__init__.py +++ b/langchain/vectorstores/__init__.py @@ -13,11 +13,13 @@ from langchain.vectorstores.myscale import MyScale, MyScaleSettings from langchain.vectorstores.opensearch_vector_search import OpenSearchVectorSearch from langchain.vectorstores.pinecone import Pinecone from langchain.vectorstores.qdrant import Qdrant +from langchain.vectorstores.redis import Redis from langchain.vectorstores.supabase import SupabaseVectorStore from langchain.vectorstores.weaviate import Weaviate from langchain.vectorstores.zilliz import Zilliz __all__ = [ + "Redis", "ElasticVectorSearch", "FAISS", "VectorStore", diff --git a/langchain/vectorstores/redis.py b/langchain/vectorstores/redis.py index a015bba9..764d3955 100644 --- a/langchain/vectorstores/redis.py +++ b/langchain/vectorstores/redis.py @@ -4,11 +4,21 @@ from __future__ import annotations import json import logging import uuid -from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, Type +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Iterable, + List, + Mapping, + Optional, + Tuple, + Type, +) import numpy as np from pydantic import BaseModel, root_validator -from redis.client import Redis as RedisType from langchain.docstore.document import Document from langchain.embeddings.base import Embeddings @@ -18,29 +28,36 @@ from langchain.vectorstores.base import VectorStore logger = logging.getLogger(__name__) +if TYPE_CHECKING: + from redis.client import Redis as RedisType + from redis.commands.search.query import Query + # required modules REDIS_REQUIRED_MODULES = [ {"name": "search", "ver": 20400}, + {"name": "searchlight", "ver": 20400}, ] -def _check_redis_module_exist(client: RedisType, modules: List[dict]) -> None: +def _check_redis_module_exist(client: RedisType, required_modules: List[dict]) -> None: """Check if the correct Redis modules are installed.""" installed_modules = client.module_list() installed_modules = { module[b"name"].decode("utf-8"): module for module in installed_modules } - for module in modules: - if module["name"] not in installed_modules or int( + for module in required_modules: + if module["name"] in installed_modules and int( installed_modules[module["name"]][b"ver"] - ) < int(module["ver"]): - error_message = ( - "You must add the RediSearch (>= 2.4) module from Redis Stack. " - "Please refer to Redis Stack docs: https://redis.io/docs/stack/" - ) - logging.error(error_message) - raise ValueError(error_message) + ) >= int(module["ver"]): + return + # otherwise raise error + error_message = ( + "You must add the RediSearch (>= 2.4) module from Redis Stack. 
" + "Please refer to Redis Stack docs: https://redis.io/docs/stack/" + ) + logging.error(error_message) + raise ValueError(error_message) def _check_index_exists(client: RedisType, index_name: str) -> bool: @@ -65,6 +82,24 @@ def _redis_prefix(index_name: str) -> str: class Redis(VectorStore): + """Wrapper around Redis vector database. + + To use, you should have the ``redis`` python package installed. + + Example: + .. code-block:: python + + from langchain.vectorstores import Redis + from langchain.embeddings import OpenAIEmbeddings + + embeddings = OpenAIEmbeddings() + vectorstore = Redis( + redis_url="redis://username:password@localhost:6379" + index_name="my-index", + embedding_function=embeddings.embed_query, + ) + """ + def __init__( self, redis_url: str, @@ -99,33 +134,92 @@ class Redis(VectorStore): self.metadata_key = metadata_key self.vector_key = vector_key + def _create_index(self, dim: int = 1536) -> None: + try: + from redis.commands.search.field import TextField, VectorField + from redis.commands.search.indexDefinition import IndexDefinition, IndexType + except ImportError: + raise ValueError( + "Could not import redis python package. " + "Please install it with `pip install redis`." + ) + + # Check if index exists + if not _check_index_exists(self.client, self.index_name): + # Constants + distance_metric = ( + "COSINE" # distance metric for the vectors (ex. COSINE, IP, L2) + ) + schema = ( + TextField(name=self.content_key), + TextField(name=self.metadata_key), + VectorField( + self.vector_key, + "FLAT", + { + "TYPE": "FLOAT32", + "DIM": dim, + "DISTANCE_METRIC": distance_metric, + }, + ), + ) + prefix = _redis_prefix(self.index_name) + + # Create Redis Index + self.client.ft(self.index_name).create_index( + fields=schema, + definition=IndexDefinition(prefix=[prefix], index_type=IndexType.HASH), + ) + def add_texts( self, texts: Iterable[str], metadatas: Optional[List[dict]] = None, + embeddings: Optional[List[List[float]]] = None, + keys: Optional[List[str]] = None, + batch_size: int = 1000, **kwargs: Any, ) -> List[str]: - """Add texts data to an existing index.""" - prefix = _redis_prefix(self.index_name) - keys = kwargs.get("keys") + """Add more texts to the vectorstore. + + Args: + texts (Iterable[str]): Iterable of strings/text to add to the vectorstore. + metadatas (Optional[List[dict]], optional): Optional list of metadatas. + Defaults to None. + embeddings (Optional[List[List[float]]], optional): Optional pre-generated + embeddings. Defaults to None. + keys (Optional[List[str]], optional): Optional key values to use as ids. + Defaults to None. + batch_size (int, optional): Batch size to use for writes. Defaults to 1000. 
+ + Returns: + List[str]: List of ids added to the vectorstore + """ ids = [] + prefix = _redis_prefix(self.index_name) + # Write data to redis pipeline = self.client.pipeline(transaction=False) for i, text in enumerate(texts): - # Use provided key otherwise use default key + # Use provided values by default or fallback key = keys[i] if keys else _redis_key(prefix) metadata = metadatas[i] if metadatas else {} + embedding = embeddings[i] if embeddings else self.embedding_function(text) pipeline.hset( key, mapping={ self.content_key: text, - self.vector_key: np.array( - self.embedding_function(text), dtype=np.float32 - ).tobytes(), + self.vector_key: np.array(embedding, dtype=np.float32).tobytes(), self.metadata_key: json.dumps(metadata), }, ) ids.append(key) + + # Write batch + if i % batch_size == 0: + pipeline.execute() + + # Cleanup final batch pipeline.execute() return ids @@ -170,9 +264,30 @@ class Redis(VectorStore): """ docs_and_scores = self.similarity_search_with_score(query, k=k) - return [doc for doc, score in docs_and_scores if score < score_threshold] + def _prepare_query(self, k: int) -> Query: + try: + from redis.commands.search.query import Query + except ImportError: + raise ValueError( + "Could not import redis python package. " + "Please install it with `pip install redis`." + ) + # Prepare the Query + hybrid_fields = "*" + base_query = ( + f"{hybrid_fields}=>[KNN {k} @{self.vector_key} $vector AS vector_score]" + ) + return_fields = [self.metadata_key, self.content_key, "vector_score"] + return ( + Query(base_query) + .return_fields(*return_fields) + .sort_by("vector_score") + .paging(0, k) + .dialect(2) + ) + def similarity_search_with_score( self, query: str, k: int = 4 ) -> List[Tuple[Document, float]]: @@ -185,40 +300,22 @@ class Redis(VectorStore): Returns: List of Documents most similar to the query and score for each """ - try: - from redis.commands.search.query import Query - except ImportError: - raise ValueError( - "Could not import redis python package. " - "Please install it with `pip install redis`." - ) - # Creates embedding vector from user query embedding = self.embedding_function(query) - # Prepare the Query - return_fields = [self.metadata_key, self.content_key, "vector_score"] - vector_field = self.vector_key - hybrid_fields = "*" - base_query = ( - f"{hybrid_fields}=>[KNN {k} @{vector_field} $vector AS vector_score]" - ) - redis_query = ( - Query(base_query) - .return_fields(*return_fields) - .sort_by("vector_score") - .paging(0, k) - .dialect(2) - ) + # Creates Redis query + redis_query = self._prepare_query(k) + params_dict: Mapping[str, str] = { "vector": np.array(embedding) # type: ignore .astype(dtype=np.float32) .tobytes() } - # perform vector search + # Perform vector search results = self.client.ft(self.index_name).search(redis_query, params_dict) + # Prepare document results docs = [ ( Document( @@ -243,15 +340,15 @@ class Redis(VectorStore): vector_key: str = "content_vector", **kwargs: Any, ) -> Redis: - """Construct RediSearch wrapper from raw documents. + """Create a Redis vectorstore from raw documents. This is a user-friendly interface that: 1. Embeds documents. - 2. Creates a new index for the embeddings in the RediSearch instance. - 3. Adds the documents to the newly created RediSearch index. + 2. Creates a new index for the embeddings in Redis. + 3. Adds the documents to the newly created Redis index. This is intended to be a quick way to get started. Example: .. 
code-block:: python - from langchain import RediSearch + from langchain.vectorstores import Redis from langchain.embeddings import OpenAIEmbeddings embeddings = OpenAIEmbeddings() redisearch = RediSearch.from_texts( @@ -261,84 +358,35 @@ class Redis(VectorStore): ) """ redis_url = get_from_dict_or_env(kwargs, "redis_url", "REDIS_URL") - try: - import redis - from redis.commands.search.field import TextField, VectorField - from redis.commands.search.indexDefinition import IndexDefinition, IndexType - except ImportError: - raise ValueError( - "Could not import redis python package. " - "Please install it with `pip install redis`." - ) - try: - # We need to first remove redis_url from kwargs, - # otherwise passing it to Redis will result in an error. - if "redis_url" in kwargs: - kwargs.pop("redis_url") - client = redis.from_url(url=redis_url, **kwargs) - # check if redis has redisearch module installed - _check_redis_module_exist(client, REDIS_REQUIRED_MODULES) - except ValueError as e: - raise ValueError(f"Redis failed to connect: {e}") - # Create embeddings over documents - embeddings = embedding.embed_documents(texts) + if "redis_url" in kwargs: + kwargs.pop("redis_url") # Name of the search index if not given if not index_name: index_name = uuid.uuid4().hex - prefix = _redis_prefix(index_name) # prefix for the document keys - # Check if index exists - if not _check_index_exists(client, index_name): - # Constants - dim = len(embeddings[0]) - distance_metric = ( - "COSINE" # distance metric for the vectors (ex. COSINE, IP, L2) - ) - schema = ( - TextField(name=content_key), - TextField(name=metadata_key), - VectorField( - vector_key, - "FLAT", - { - "TYPE": "FLOAT32", - "DIM": dim, - "DISTANCE_METRIC": distance_metric, - }, - ), - ) - # Create Redis Index - client.ft(index_name).create_index( - fields=schema, - definition=IndexDefinition(prefix=[prefix], index_type=IndexType.HASH), - ) - - # Write data to Redis - pipeline = client.pipeline(transaction=False) - for i, text in enumerate(texts): - key = _redis_key(prefix) - metadata = metadatas[i] if metadatas else {} - pipeline.hset( - key, - mapping={ - content_key: text, - vector_key: np.array(embeddings[i], dtype=np.float32).tobytes(), - metadata_key: json.dumps(metadata), - }, - ) - pipeline.execute() - return cls( - redis_url, - index_name, - embedding.embed_query, + # Create instance + instance = cls( + redis_url=redis_url, + index_name=index_name, + embedding_function=embedding.embed_query, content_key=content_key, metadata_key=metadata_key, vector_key=vector_key, **kwargs, ) + # Create embeddings over documents + embeddings = embedding.embed_documents(texts) + + # Create the search index + instance._create_index(dim=len(embeddings[0])) + + # Add data to Redis + instance.add_texts(texts, metadatas, embeddings) + return instance + @staticmethod def drop_index( index_name: str, diff --git a/tests/integration_tests/cache/test_redis_cache.py b/tests/integration_tests/cache/test_redis_cache.py new file mode 100644 index 00000000..7ce18bf6 --- /dev/null +++ b/tests/integration_tests/cache/test_redis_cache.py @@ -0,0 +1,55 @@ +"""Test Redis cache functionality.""" +import redis + +import langchain +from langchain.cache import RedisCache, RedisSemanticCache +from langchain.schema import Generation, LLMResult +from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings +from tests.unit_tests.llms.fake_llm import FakeLLM + +REDIS_TEST_URL = "redis://localhost:6379" + + +def test_redis_cache() -> None: + langchain.llm_cache = 
RedisCache(redis_=redis.Redis.from_url(REDIS_TEST_URL)) + llm = FakeLLM() + params = llm.dict() + params["stop"] = None + llm_string = str(sorted([(k, v) for k, v in params.items()])) + langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")]) + output = llm.generate(["foo"]) + print(output) + expected_output = LLMResult( + generations=[[Generation(text="fizz")]], + llm_output={}, + ) + print(expected_output) + assert output == expected_output + langchain.llm_cache.redis.flushall() + + +def test_redis_semantic_cache() -> None: + langchain.llm_cache = RedisSemanticCache( + embedding=FakeEmbeddings(), redis_url=REDIS_TEST_URL, score_threshold=0.1 + ) + llm = FakeLLM() + params = llm.dict() + params["stop"] = None + llm_string = str(sorted([(k, v) for k, v in params.items()])) + langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")]) + output = llm.generate( + ["bar"] + ) # foo and bar will have the same embedding produced by FakeEmbeddings + expected_output = LLMResult( + generations=[[Generation(text="fizz")]], + llm_output={}, + ) + assert output == expected_output + # clear the cache + langchain.llm_cache.clear(llm_string=llm_string) + output = llm.generate( + ["bar"] + ) # foo and bar will have the same embedding produced by FakeEmbeddings + # expect different output now without cached result + assert output != expected_output + langchain.llm_cache.clear(llm_string=llm_string) diff --git a/tests/integration_tests/vectorstores/test_redis.py b/tests/integration_tests/vectorstores/test_redis.py index 15d5651a..cbbce989 100644 --- a/tests/integration_tests/vectorstores/test_redis.py +++ b/tests/integration_tests/vectorstores/test_redis.py @@ -1,26 +1,60 @@ """Test Redis functionality.""" - from langchain.docstore.document import Document from langchain.vectorstores.redis import Redis from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings +TEST_INDEX_NAME = "test" +TEST_REDIS_URL = "redis://localhost:6379" +TEST_SINGLE_RESULT = [Document(page_content="foo")] +TEST_RESULT = [Document(page_content="foo"), Document(page_content="foo")] + + +def drop(index_name: str) -> bool: + return Redis.drop_index( + index_name=index_name, delete_documents=True, redis_url=TEST_REDIS_URL + ) + def test_redis() -> None: """Test end to end construction and search.""" texts = ["foo", "bar", "baz"] - docsearch = Redis.from_texts( - texts, FakeEmbeddings(), redis_url="redis://localhost:6379" - ) + docsearch = Redis.from_texts(texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL) output = docsearch.similarity_search("foo", k=1) - assert output == [Document(page_content="foo")] + assert output == TEST_SINGLE_RESULT + assert drop(docsearch.index_name) def test_redis_new_vector() -> None: """Test adding a new document""" texts = ["foo", "bar", "baz"] - docsearch = Redis.from_texts( - texts, FakeEmbeddings(), redis_url="redis://localhost:6379" + docsearch = Redis.from_texts(texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL) + docsearch.add_texts(["foo"]) + output = docsearch.similarity_search("foo", k=2) + assert output == TEST_RESULT + assert drop(docsearch.index_name) + + +def test_redis_from_existing() -> None: + """Test adding a new document""" + texts = ["foo", "bar", "baz"] + Redis.from_texts( + texts, FakeEmbeddings(), index_name=TEST_INDEX_NAME, redis_url=TEST_REDIS_URL + ) + # Test creating from an existing + docsearch2 = Redis.from_existing_index( + FakeEmbeddings(), index_name=TEST_INDEX_NAME, redis_url=TEST_REDIS_URL + ) + output = 
docsearch2.similarity_search("foo", k=1) + assert output == TEST_SINGLE_RESULT + + +def test_redis_add_texts_to_existing() -> None: + """Test adding a new document""" + # Test creating from an existing + docsearch = Redis.from_existing_index( + FakeEmbeddings(), index_name=TEST_INDEX_NAME, redis_url=TEST_REDIS_URL ) docsearch.add_texts(["foo"]) output = docsearch.similarity_search("foo", k=2) - assert output == [Document(page_content="foo"), Document(page_content="foo")] + assert output == TEST_RESULT + assert drop(TEST_INDEX_NAME)