forked from Archives/langchain

Harrison/redis cache (#3766)
Co-authored-by: Tyler Hutcherson <tyler.hutcherson@redis.com>

parent b588446bf9
commit be7a8e0824

79 docs/ecosystem/redis.md Normal file
@ -0,0 +1,79 @@
# Redis

This page covers how to use the [Redis](https://redis.com) ecosystem within LangChain.
It is broken into two parts: installation and setup, and then references to specific Redis wrappers.

## Installation and Setup
- Install the Redis Python SDK with `pip install redis`

## Wrappers

### Cache

The Cache wrapper allows for [Redis](https://redis.io) to be used as a remote, low-latency, in-memory cache for LLM prompts and responses.

#### Standard Cache
The standard cache is the Redis bread and butter use case in production for both [open source](https://redis.io) and [enterprise](https://redis.com) users globally.

To import this cache:
```python
from langchain.cache import RedisCache
```

To use this cache with your LLMs:
```python
import langchain
import redis

redis_client = redis.Redis.from_url(...)
langchain.llm_cache = RedisCache(redis_client)
```
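
Once the cache is wired up, repeated LLM calls with the same prompt are served from Redis instead of being recomputed. A minimal sketch, assuming an OpenAI LLM is configured (the model name here is illustrative):

```python
from langchain.llms import OpenAI

llm = OpenAI(model_name="text-davinci-002")
llm("Tell me a joke")  # first call: computed by the LLM and written to Redis
llm("Tell me a joke")  # second call: identical prompt, returned from the cache
```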

#### Semantic Cache
Semantic caching allows users to retrieve cached prompts based on semantic similarity between the user input and previously cached results. Under the hood it blends Redis as both a cache and a vectorstore.

To import this cache:
```python
from langchain.cache import RedisSemanticCache
```

To use this cache with your LLMs:
```python
import langchain
import redis

# use any embedding provider...
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings

redis_url = "redis://localhost:6379"

langchain.llm_cache = RedisSemanticCache(
    embedding=FakeEmbeddings(),
    redis_url=redis_url
)
```
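
With the semantic cache configured, a prompt only needs to be semantically close to a previously cached one to produce a hit. A short sketch, reusing the `llm` object from the example above and mirroring the LLM caching notebook:

```python
llm("Tell me a joke")    # not yet cached: computed by the LLM and stored with its embedding
llm("Tell me one joke")  # not an exact match, but semantically similar, so the cached answer is returned
```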

### VectorStore

The vectorstore wrapper turns Redis into a low-latency [vector database](https://redis.com/solutions/use-cases/vector-database/) for semantic search or LLM content retrieval.

To import this vectorstore:
```python
from langchain.vectorstores import Redis
```
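
A minimal end-to-end sketch, assuming a Redis Stack instance at `redis://localhost:6379` and OpenAI credentials for the embedding provider:

```python
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import Redis

# Embed a few texts, index them in Redis, then run a similarity search
vectorstore = Redis.from_texts(
    texts=["foo", "bar", "baz"],
    embedding=OpenAIEmbeddings(),
    redis_url="redis://localhost:6379",
)
docs = vectorstore.similarity_search("foo", k=1)
```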

For a more detailed walkthrough of the Redis vectorstore wrapper, see [this notebook](../modules/indexes/vectorstores/examples/redis.ipynb).

### Retriever

The Redis vector store retriever wrapper generalizes the vectorstore class to perform low-latency document retrieval. To create the retriever, simply call `.as_retriever()` on the base vectorstore class.
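
For example, a brief sketch reusing the `vectorstore` object from the VectorStore section above:

```python
# Expose the Redis vectorstore through the standard retriever interface
retriever = vectorstore.as_retriever()
docs = retriever.get_relevant_documents("foo")
```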

### Memory
Redis can be used to persist LLM conversations.

#### Vector Store Retriever Memory

For a more detailed walkthrough of the `VectorStoreRetrieverMemory` wrapper, see [this notebook](../modules/memory/types/vectorstore_retriever_memory.ipynb).

#### Chat Message History Memory
For a detailed example of using Redis to cache conversation message history, see [this notebook](../modules/memory/examples/redis_chat_message_history.ipynb).
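
As a brief sketch of what that notebook covers (assuming the `RedisChatMessageHistory` class and a Redis server running locally; the session id and URL below are illustrative):

```python
from langchain.memory import RedisChatMessageHistory

# Store chat turns for a session in Redis
history = RedisChatMessageHistory(session_id="my-session", url="redis://localhost:6379/0")
history.add_user_message("hi!")
history.add_ai_message("whats up?")
history.messages  # read the persisted conversation back
```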

@ -41,7 +41,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 5,
"id": "f69f6283",
"metadata": {},
"outputs": [],
@ -52,7 +52,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 6,
"id": "64005d1f",
"metadata": {},
"outputs": [
@ -60,8 +60,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 14.2 ms, sys: 4.9 ms, total: 19.1 ms\n",
"Wall time: 1.1 s\n"
"CPU times: user 26.1 ms, sys: 21.5 ms, total: 47.6 ms\n",
"Wall time: 1.68 s\n"
]
},
{
@ -70,7 +70,7 @@
"'\\n\\nWhy did the chicken cross the road?\\n\\nTo get to the other side.'"
]
},
"execution_count": 4,
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
@ -83,7 +83,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 7,
"id": "c8a1cb2b",
"metadata": {},
"outputs": [
@ -91,8 +91,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 162 µs, sys: 7 µs, total: 169 µs\n",
"Wall time: 175 µs\n"
"CPU times: user 238 µs, sys: 143 µs, total: 381 µs\n",
"Wall time: 1.76 ms\n"
]
},
{
@ -101,7 +101,7 @@
"'\\n\\nWhy did the chicken cross the road?\\n\\nTo get to the other side.'"
]
},
"execution_count": 5,
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
@ -214,9 +214,18 @@
"## Redis Cache"
]
},
{
"cell_type": "markdown",
"id": "c5c9a4d5",
"metadata": {},
"source": [
"### Standard Cache\n",
"Use [Redis](../../../../ecosystem/redis.md) to cache prompts and responses."
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 8,
"id": "39f6eb0b",
"metadata": {},
"outputs": [],
@ -225,15 +234,35 @@
"# (make sure your local Redis instance is running first before running this example)\n",
"from redis import Redis\n",
"from langchain.cache import RedisCache\n",
"\n",
"langchain.llm_cache = RedisCache(redis_=Redis())"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 9,
"id": "28920749",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 6.88 ms, sys: 8.75 ms, total: 15.6 ms\n",
"Wall time: 1.04 s\n"
]
},
{
"data": {
"text/plain": [
"'\\n\\nWhy did the chicken cross the road?\\n\\nTo get to the other side!'"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"# The first time, it is not yet in cache, so it should take longer\n",
@ -242,16 +271,124 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 14,
"id": "94bf9415",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 1.59 ms, sys: 610 µs, total: 2.2 ms\n",
"Wall time: 5.58 ms\n"
]
},
{
"data": {
"text/plain": [
"'\\n\\nWhy did the chicken cross the road?\\n\\nTo get to the other side!'"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"# The second time it is, so it goes faster\n",
"llm(\"Tell me a joke\")"
]
},
{
"cell_type": "markdown",
"id": "82be23f6",
"metadata": {},
"source": [
"### Semantic Cache\n",
"Use [Redis](../../../../ecosystem/redis.md) to cache prompts and responses and evaluate hits based on semantic similarity."
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "64df3099",
"metadata": {},
"outputs": [],
"source": [
"from langchain.embeddings import OpenAIEmbeddings\n",
"from langchain.cache import RedisSemanticCache\n",
"\n",
"\n",
"langchain.llm_cache = RedisSemanticCache(\n",
"    redis_url=\"redis://localhost:6379\",\n",
"    embedding=OpenAIEmbeddings()\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "8e91d3ac",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 351 ms, sys: 156 ms, total: 507 ms\n",
"Wall time: 3.37 s\n"
]
},
{
"data": {
"text/plain": [
"\"\\n\\nWhy don't scientists trust atoms?\\nBecause they make up everything.\""
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"# The first time, it is not yet in cache, so it should take longer\n",
"llm(\"Tell me a joke\")"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "df856948",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 6.25 ms, sys: 2.72 ms, total: 8.97 ms\n",
"Wall time: 262 ms\n"
]
},
{
"data": {
"text/plain": [
"\"\\n\\nWhy don't scientists trust atoms?\\nBecause they make up everything.\""
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"# The second time, while not a direct hit, the question is semantically similar to the original question,\n",
"# so it uses the cached result!\n",
"llm(\"Tell me one joke\")"
]
},
{
"cell_type": "markdown",
"id": "684eab55",

@ -1,4 +1,5 @@
"""Beta Feature: base interface for cache."""
import hashlib
import json
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, cast
@ -12,11 +13,18 @@ try:
except ImportError:
from sqlalchemy.ext.declarative import declarative_base

from langchain.embeddings.base import Embeddings
from langchain.schema import Generation
from langchain.vectorstores.redis import Redis as RedisVectorstore

RETURN_VAL_TYPE = List[Generation]


def _hash(_input: str) -> str:
"""Use a deterministic hashing approach."""
return hashlib.md5(_input.encode()).hexdigest()


class BaseCache(ABC):
"""Base interface for cache."""

@ -117,6 +125,8 @@ class SQLiteCache(SQLAlchemyCache):
class RedisCache(BaseCache):
"""Cache that uses Redis as a backend."""

# TODO - implement a TTL policy in Redis

def __init__(self, redis_: Any):
"""Initialize by passing in Redis instance."""
try:
@ -130,28 +140,30 @@ class RedisCache(BaseCache):
raise ValueError("Please pass in Redis object.")
self.redis = redis_

def _key(self, prompt: str, llm_string: str, idx: int) -> str:
"""Compute key from prompt, llm_string, and idx."""
return str(hash(prompt + llm_string)) + "_" + str(idx)
def _key(self, prompt: str, llm_string: str) -> str:
"""Compute key from prompt and llm_string"""
return _hash(prompt + llm_string)

def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
"""Look up based on prompt and llm_string."""
idx = 0
generations = []
while self.redis.get(self._key(prompt, llm_string, idx)):
result = self.redis.get(self._key(prompt, llm_string, idx))
if not result:
break
elif isinstance(result, bytes):
result = result.decode()
generations.append(Generation(text=result))
idx += 1
# Read from a Redis HASH
results = self.redis.hgetall(self._key(prompt, llm_string))
if results:
for _, text in results.items():
generations.append(Generation(text=text))
return generations if generations else None

def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
"""Update cache based on prompt and llm_string."""
for i, generation in enumerate(return_val):
self.redis.set(self._key(prompt, llm_string, i), generation.text)
# Write to a Redis HASH
key = self._key(prompt, llm_string)
self.redis.hset(
key,
mapping={
str(idx): generation.text for idx, generation in enumerate(return_val)
},
)

def clear(self, **kwargs: Any) -> None:
"""Clear cache. If `asynchronous` is True, flush asynchronously."""
@ -159,6 +171,106 @@ class RedisCache(BaseCache):
self.redis.flushdb(asynchronous=asynchronous, **kwargs)


class RedisSemanticCache(BaseCache):
"""Cache that uses Redis as a vector-store backend."""

# TODO - implement a TTL policy in Redis

def __init__(
self, redis_url: str, embedding: Embeddings, score_threshold: float = 0.2
):
"""Initialize by passing in the `init` GPTCache func
|
||||
|
||||
Args:
|
||||
redis_url (str): URL to connect to Redis.
|
||||
embedding (Embedding): Embedding provider for semantic encoding and search.
|
||||
score_threshold (float, 0.2):

Example:
.. code-block:: python
import langchain

from langchain.cache import RedisSemanticCache
from langchain.embeddings import OpenAIEmbeddings

langchain.llm_cache = RedisSemanticCache(
redis_url="redis://localhost:6379",
embedding=OpenAIEmbeddings()
)

"""
self._cache_dict: Dict[str, RedisVectorstore] = {}
self.redis_url = redis_url
self.embedding = embedding
self.score_threshold = score_threshold

def _index_name(self, llm_string: str) -> str:
hashed_index = _hash(llm_string)
return f"cache:{hashed_index}"

def _get_llm_cache(self, llm_string: str) -> RedisVectorstore:
index_name = self._index_name(llm_string)

# return vectorstore client for the specific llm string
if index_name in self._cache_dict:
return self._cache_dict[index_name]

# create new vectorstore client for the specific llm string
try:
self._cache_dict[index_name] = RedisVectorstore.from_existing_index(
embedding=self.embedding,
index_name=index_name,
redis_url=self.redis_url,
)
except ValueError:
redis = RedisVectorstore(
embedding_function=self.embedding.embed_query,
index_name=index_name,
redis_url=self.redis_url,
)
_embedding = self.embedding.embed_query(text="test")
redis._create_index(dim=len(_embedding))
self._cache_dict[index_name] = redis

return self._cache_dict[index_name]

def clear(self, **kwargs: Any) -> None:
"""Clear semantic cache for a given llm_string."""
index_name = self._index_name(kwargs["llm_string"])
if index_name in self._cache_dict:
self._cache_dict[index_name].drop_index(
index_name=index_name, delete_documents=True, redis_url=self.redis_url
)
del self._cache_dict[index_name]

def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
"""Look up based on prompt and llm_string."""
llm_cache = self._get_llm_cache(llm_string)
generations = []
# Read from a Hash
results = llm_cache.similarity_search_limit_score(
query=prompt,
k=1,
score_threshold=self.score_threshold,
)
if results:
for document in results:
for text in document.metadata["return_val"]:
generations.append(Generation(text=text))
return generations if generations else None

def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
"""Update cache based on prompt and llm_string."""
llm_cache = self._get_llm_cache(llm_string)
# Write to vectorstore
metadata = {
"llm_string": llm_string,
"prompt": prompt,
"return_val": [generation.text for generation in return_val],
}
llm_cache.add_texts(texts=[prompt], metadatas=[metadata])


class GPTCache(BaseCache):
"""Cache that uses GPTCache as a backend."""

@ -13,11 +13,13 @@ from langchain.vectorstores.myscale import MyScale, MyScaleSettings
from langchain.vectorstores.opensearch_vector_search import OpenSearchVectorSearch
from langchain.vectorstores.pinecone import Pinecone
from langchain.vectorstores.qdrant import Qdrant
from langchain.vectorstores.redis import Redis
from langchain.vectorstores.supabase import SupabaseVectorStore
from langchain.vectorstores.weaviate import Weaviate
from langchain.vectorstores.zilliz import Zilliz

__all__ = [
"Redis",
"ElasticVectorSearch",
"FAISS",
"VectorStore",

@ -4,11 +4,21 @@ from __future__ import annotations
import json
import logging
import uuid
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, Type
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
List,
Mapping,
Optional,
Tuple,
Type,
)

import numpy as np
from pydantic import BaseModel, root_validator
from redis.client import Redis as RedisType

from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
@ -18,29 +28,36 @@ from langchain.vectorstores.base import VectorStore

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
from redis.client import Redis as RedisType
from redis.commands.search.query import Query


# required modules
REDIS_REQUIRED_MODULES = [
{"name": "search", "ver": 20400},
{"name": "searchlight", "ver": 20400},
]


def _check_redis_module_exist(client: RedisType, modules: List[dict]) -> None:
def _check_redis_module_exist(client: RedisType, required_modules: List[dict]) -> None:
"""Check if the correct Redis modules are installed."""
installed_modules = client.module_list()
installed_modules = {
module[b"name"].decode("utf-8"): module for module in installed_modules
}
for module in modules:
if module["name"] not in installed_modules or int(
for module in required_modules:
if module["name"] in installed_modules and int(
installed_modules[module["name"]][b"ver"]
) < int(module["ver"]):
error_message = (
"You must add the RediSearch (>= 2.4) module from Redis Stack. "
"Please refer to Redis Stack docs: https://redis.io/docs/stack/"
)
logging.error(error_message)
raise ValueError(error_message)
) >= int(module["ver"]):
return
# otherwise raise error
error_message = (
"You must add the RediSearch (>= 2.4) module from Redis Stack. "
"Please refer to Redis Stack docs: https://redis.io/docs/stack/"
)
logging.error(error_message)
raise ValueError(error_message)


def _check_index_exists(client: RedisType, index_name: str) -> bool:

@ -65,6 +82,24 @@ def _redis_prefix(index_name: str) -> str:


class Redis(VectorStore):
"""Wrapper around Redis vector database.

To use, you should have the ``redis`` python package installed.

Example:
.. code-block:: python

from langchain.vectorstores import Redis
from langchain.embeddings import OpenAIEmbeddings

embeddings = OpenAIEmbeddings()
vectorstore = Redis(
redis_url="redis://username:password@localhost:6379"
index_name="my-index",
embedding_function=embeddings.embed_query,
)
"""

def __init__(
self,
redis_url: str,
@ -99,33 +134,92 @@ class Redis(VectorStore):
self.metadata_key = metadata_key
self.vector_key = vector_key

def _create_index(self, dim: int = 1536) -> None:
try:
from redis.commands.search.field import TextField, VectorField
from redis.commands.search.indexDefinition import IndexDefinition, IndexType
except ImportError:
raise ValueError(
"Could not import redis python package. "
"Please install it with `pip install redis`."
)

# Check if index exists
if not _check_index_exists(self.client, self.index_name):
# Constants
distance_metric = (
"COSINE" # distance metric for the vectors (ex. COSINE, IP, L2)
)
schema = (
TextField(name=self.content_key),
TextField(name=self.metadata_key),
VectorField(
self.vector_key,
"FLAT",
{
"TYPE": "FLOAT32",
"DIM": dim,
"DISTANCE_METRIC": distance_metric,
},
),
)
prefix = _redis_prefix(self.index_name)

# Create Redis Index
self.client.ft(self.index_name).create_index(
fields=schema,
definition=IndexDefinition(prefix=[prefix], index_type=IndexType.HASH),
)

def add_texts(
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
embeddings: Optional[List[List[float]]] = None,
keys: Optional[List[str]] = None,
batch_size: int = 1000,
**kwargs: Any,
) -> List[str]:
"""Add texts data to an existing index."""
prefix = _redis_prefix(self.index_name)
keys = kwargs.get("keys")
"""Add more texts to the vectorstore.

Args:
texts (Iterable[str]): Iterable of strings/text to add to the vectorstore.
metadatas (Optional[List[dict]], optional): Optional list of metadatas.
Defaults to None.
embeddings (Optional[List[List[float]]], optional): Optional pre-generated
embeddings. Defaults to None.
keys (Optional[List[str]], optional): Optional key values to use as ids.
Defaults to None.
batch_size (int, optional): Batch size to use for writes. Defaults to 1000.

Returns:
List[str]: List of ids added to the vectorstore
"""
ids = []
prefix = _redis_prefix(self.index_name)

# Write data to redis
pipeline = self.client.pipeline(transaction=False)
for i, text in enumerate(texts):
# Use provided key otherwise use default key
# Use provided values by default or fallback
key = keys[i] if keys else _redis_key(prefix)
metadata = metadatas[i] if metadatas else {}
embedding = embeddings[i] if embeddings else self.embedding_function(text)
pipeline.hset(
key,
mapping={
self.content_key: text,
self.vector_key: np.array(
self.embedding_function(text), dtype=np.float32
).tobytes(),
self.vector_key: np.array(embedding, dtype=np.float32).tobytes(),
self.metadata_key: json.dumps(metadata),
},
)
ids.append(key)

# Write batch
if i % batch_size == 0:
pipeline.execute()

# Cleanup final batch
pipeline.execute()
return ids

@ -170,9 +264,30 @@ class Redis(VectorStore):

"""
docs_and_scores = self.similarity_search_with_score(query, k=k)

return [doc for doc, score in docs_and_scores if score < score_threshold]

def _prepare_query(self, k: int) -> Query:
try:
from redis.commands.search.query import Query
except ImportError:
raise ValueError(
"Could not import redis python package. "
"Please install it with `pip install redis`."
)
# Prepare the Query
hybrid_fields = "*"
base_query = (
f"{hybrid_fields}=>[KNN {k} @{self.vector_key} $vector AS vector_score]"
)
return_fields = [self.metadata_key, self.content_key, "vector_score"]
return (
Query(base_query)
.return_fields(*return_fields)
.sort_by("vector_score")
.paging(0, k)
.dialect(2)
)

def similarity_search_with_score(
self, query: str, k: int = 4
) -> List[Tuple[Document, float]]:
@ -185,40 +300,22 @@ class Redis(VectorStore):
Returns:
List of Documents most similar to the query and score for each
"""
try:
from redis.commands.search.query import Query
except ImportError:
raise ValueError(
"Could not import redis python package. "
"Please install it with `pip install redis`."
)

# Creates embedding vector from user query
embedding = self.embedding_function(query)

# Prepare the Query
return_fields = [self.metadata_key, self.content_key, "vector_score"]
vector_field = self.vector_key
hybrid_fields = "*"
base_query = (
f"{hybrid_fields}=>[KNN {k} @{vector_field} $vector AS vector_score]"
)
redis_query = (
Query(base_query)
.return_fields(*return_fields)
.sort_by("vector_score")
.paging(0, k)
.dialect(2)
)
# Creates Redis query
redis_query = self._prepare_query(k)

params_dict: Mapping[str, str] = {
"vector": np.array(embedding) # type: ignore
.astype(dtype=np.float32)
.tobytes()
}

# perform vector search
# Perform vector search
results = self.client.ft(self.index_name).search(redis_query, params_dict)

# Prepare document results
docs = [
(
Document(

@ -243,15 +340,15 @@ class Redis(VectorStore):
vector_key: str = "content_vector",
**kwargs: Any,
) -> Redis:
"""Construct RediSearch wrapper from raw documents.
"""Create a Redis vectorstore from raw documents.
This is a user-friendly interface that:
1. Embeds documents.
2. Creates a new index for the embeddings in the RediSearch instance.
3. Adds the documents to the newly created RediSearch index.
2. Creates a new index for the embeddings in Redis.
3. Adds the documents to the newly created Redis index.
This is intended to be a quick way to get started.
Example:
.. code-block:: python
from langchain import RediSearch
from langchain.vectorstores import Redis
from langchain.embeddings import OpenAIEmbeddings
embeddings = OpenAIEmbeddings()
redisearch = RediSearch.from_texts(
@ -261,84 +358,35 @@ class Redis(VectorStore):
)
"""
redis_url = get_from_dict_or_env(kwargs, "redis_url", "REDIS_URL")
try:
import redis
from redis.commands.search.field import TextField, VectorField
from redis.commands.search.indexDefinition import IndexDefinition, IndexType
except ImportError:
raise ValueError(
"Could not import redis python package. "
"Please install it with `pip install redis`."
)
try:
# We need to first remove redis_url from kwargs,
# otherwise passing it to Redis will result in an error.
if "redis_url" in kwargs:
kwargs.pop("redis_url")
client = redis.from_url(url=redis_url, **kwargs)
# check if redis has redisearch module installed
_check_redis_module_exist(client, REDIS_REQUIRED_MODULES)
except ValueError as e:
raise ValueError(f"Redis failed to connect: {e}")

# Create embeddings over documents
embeddings = embedding.embed_documents(texts)
if "redis_url" in kwargs:
kwargs.pop("redis_url")

# Name of the search index if not given
if not index_name:
index_name = uuid.uuid4().hex
prefix = _redis_prefix(index_name) # prefix for the document keys

# Check if index exists
if not _check_index_exists(client, index_name):
# Constants
dim = len(embeddings[0])
distance_metric = (
"COSINE" # distance metric for the vectors (ex. COSINE, IP, L2)
)
schema = (
TextField(name=content_key),
TextField(name=metadata_key),
VectorField(
vector_key,
"FLAT",
{
"TYPE": "FLOAT32",
"DIM": dim,
"DISTANCE_METRIC": distance_metric,
},
),
)
# Create Redis Index
client.ft(index_name).create_index(
fields=schema,
definition=IndexDefinition(prefix=[prefix], index_type=IndexType.HASH),
)

# Write data to Redis
pipeline = client.pipeline(transaction=False)
for i, text in enumerate(texts):
key = _redis_key(prefix)
metadata = metadatas[i] if metadatas else {}
pipeline.hset(
key,
mapping={
content_key: text,
vector_key: np.array(embeddings[i], dtype=np.float32).tobytes(),
metadata_key: json.dumps(metadata),
},
)
pipeline.execute()
return cls(
redis_url,
index_name,
embedding.embed_query,
# Create instance
instance = cls(
redis_url=redis_url,
index_name=index_name,
embedding_function=embedding.embed_query,
content_key=content_key,
metadata_key=metadata_key,
vector_key=vector_key,
**kwargs,
)

# Create embeddings over documents
embeddings = embedding.embed_documents(texts)

# Create the search index
instance._create_index(dim=len(embeddings[0]))

# Add data to Redis
instance.add_texts(texts, metadatas, embeddings)
return instance

@staticmethod
def drop_index(
index_name: str,

55 tests/integration_tests/cache/test_redis_cache.py vendored Normal file
@ -0,0 +1,55 @@
"""Test Redis cache functionality."""
import redis

import langchain
from langchain.cache import RedisCache, RedisSemanticCache
from langchain.schema import Generation, LLMResult
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
from tests.unit_tests.llms.fake_llm import FakeLLM

REDIS_TEST_URL = "redis://localhost:6379"


def test_redis_cache() -> None:
langchain.llm_cache = RedisCache(redis_=redis.Redis.from_url(REDIS_TEST_URL))
llm = FakeLLM()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
output = llm.generate(["foo"])
print(output)
expected_output = LLMResult(
generations=[[Generation(text="fizz")]],
llm_output={},
)
print(expected_output)
assert output == expected_output
langchain.llm_cache.redis.flushall()


def test_redis_semantic_cache() -> None:
langchain.llm_cache = RedisSemanticCache(
embedding=FakeEmbeddings(), redis_url=REDIS_TEST_URL, score_threshold=0.1
)
llm = FakeLLM()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
output = llm.generate(
["bar"]
) # foo and bar will have the same embedding produced by FakeEmbeddings
expected_output = LLMResult(
generations=[[Generation(text="fizz")]],
llm_output={},
)
assert output == expected_output
# clear the cache
langchain.llm_cache.clear(llm_string=llm_string)
output = llm.generate(
["bar"]
) # foo and bar will have the same embedding produced by FakeEmbeddings
# expect different output now without cached result
assert output != expected_output
langchain.llm_cache.clear(llm_string=llm_string)

@ -1,26 +1,60 @@
"""Test Redis functionality."""

from langchain.docstore.document import Document
from langchain.vectorstores.redis import Redis
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings

TEST_INDEX_NAME = "test"
TEST_REDIS_URL = "redis://localhost:6379"
TEST_SINGLE_RESULT = [Document(page_content="foo")]
TEST_RESULT = [Document(page_content="foo"), Document(page_content="foo")]


def drop(index_name: str) -> bool:
return Redis.drop_index(
index_name=index_name, delete_documents=True, redis_url=TEST_REDIS_URL
)


def test_redis() -> None:
"""Test end to end construction and search."""
texts = ["foo", "bar", "baz"]
docsearch = Redis.from_texts(
texts, FakeEmbeddings(), redis_url="redis://localhost:6379"
)
docsearch = Redis.from_texts(texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL)
output = docsearch.similarity_search("foo", k=1)
assert output == [Document(page_content="foo")]
assert output == TEST_SINGLE_RESULT
assert drop(docsearch.index_name)


def test_redis_new_vector() -> None:
"""Test adding a new document"""
texts = ["foo", "bar", "baz"]
docsearch = Redis.from_texts(
texts, FakeEmbeddings(), redis_url="redis://localhost:6379"
docsearch = Redis.from_texts(texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL)
docsearch.add_texts(["foo"])
output = docsearch.similarity_search("foo", k=2)
assert output == TEST_RESULT
assert drop(docsearch.index_name)


def test_redis_from_existing() -> None:
"""Test adding a new document"""
texts = ["foo", "bar", "baz"]
Redis.from_texts(
texts, FakeEmbeddings(), index_name=TEST_INDEX_NAME, redis_url=TEST_REDIS_URL
)
# Test creating from an existing
docsearch2 = Redis.from_existing_index(
FakeEmbeddings(), index_name=TEST_INDEX_NAME, redis_url=TEST_REDIS_URL
)
output = docsearch2.similarity_search("foo", k=1)
assert output == TEST_SINGLE_RESULT


def test_redis_add_texts_to_existing() -> None:
"""Test adding a new document"""
# Test creating from an existing
docsearch = Redis.from_existing_index(
FakeEmbeddings(), index_name=TEST_INDEX_NAME, redis_url=TEST_REDIS_URL
)
docsearch.add_texts(["foo"])
output = docsearch.similarity_search("foo", k=2)
assert output == [Document(page_content="foo"), Document(page_content="foo")]
assert output == TEST_RESULT
assert drop(TEST_INDEX_NAME)