mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Harrison/redis cache (#3766)
Co-authored-by: Tyler Hutcherson <tyler.hutcherson@redis.com>
This commit is contained in:
parent
b588446bf9
commit
be7a8e0824
79
docs/ecosystem/redis.md
Normal file
79
docs/ecosystem/redis.md
Normal file
@ -0,0 +1,79 @@
|
|||||||
|
# Redis
|
||||||
|
|
||||||
|
This page covers how to use the [Redis](https://redis.com) ecosystem within LangChain.
|
||||||
|
It is broken into two parts: installation and setup, and then references to specific Redis wrappers.
|
||||||
|
|
||||||
|
## Installation and Setup
|
||||||
|
- Install the Redis Python SDK with `pip install redis`
|
||||||
|
|
||||||
|
## Wrappers
|
||||||
|
|
||||||
|
### Cache
|
||||||
|
|
||||||
|
The Cache wrapper allows for [Redis](https://redis.io) to be used as a remote, low-latency, in-memory cache for LLM prompts and responses.
|
||||||
|
|
||||||
|
#### Standard Cache
|
||||||
|
The standard cache is the Redis bread & butter of use case in production for both [open source](https://redis.io) and [enterprise](https://redis.com) users globally.
|
||||||
|
|
||||||
|
To import this cache:
|
||||||
|
```python
|
||||||
|
from langchain.cache import RedisCache
|
||||||
|
```
|
||||||
|
|
||||||
|
To use this cache with your LLMs:
|
||||||
|
```python
|
||||||
|
import langchain
|
||||||
|
import redis
|
||||||
|
|
||||||
|
redis_client = redis.Redis.from_url(...)
|
||||||
|
langchain.llm_cache = RedisCache(redis_client)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Semantic Cache
|
||||||
|
Semantic caching allows users to retrieve cached prompts based on semantic similarity between the user input and previously cached results. Under the hood it blends Redis as both a cache and a vectorstore.
|
||||||
|
|
||||||
|
To import this cache:
|
||||||
|
```python
|
||||||
|
from langchain.cache import RedisSemanticCache
|
||||||
|
```
|
||||||
|
|
||||||
|
To use this cache with your LLMs:
|
||||||
|
```python
|
||||||
|
import langchain
|
||||||
|
import redis
|
||||||
|
|
||||||
|
# use any embedding provider...
|
||||||
|
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
|
||||||
|
|
||||||
|
redis_url = "redis://localhost:6379"
|
||||||
|
|
||||||
|
langchain.llm_cache = RedisSemanticCache(
|
||||||
|
embedding=FakeEmbeddings(),
|
||||||
|
redis_url=redis_url
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### VectorStore
|
||||||
|
|
||||||
|
The vectorstore wrapper turns Redis into a low-latency [vector database](https://redis.com/solutions/use-cases/vector-database/) for semantic search or LLM content retrieval.
|
||||||
|
|
||||||
|
To import this vectorstore:
|
||||||
|
```python
|
||||||
|
from langchain.vectorstores import Redis
|
||||||
|
```
|
||||||
|
|
||||||
|
For a more detailed walkthrough of the Redis vectorstore wrapper, see [this notebook](../modules/indexes/vectorstores/examples/redis.ipynb).
|
||||||
|
|
||||||
|
### Retriever
|
||||||
|
|
||||||
|
The Redis vector store retriever wrapper generalizes the vectorstore class to perform low-latency document retrieval. To create the retriever, simply call `.as_retriever()` on the base vectorstore class.
|
||||||
|
|
||||||
|
### Memory
|
||||||
|
Redis can be used to persist LLM conversations.
|
||||||
|
|
||||||
|
#### Vector Store Retriever Memory
|
||||||
|
|
||||||
|
For a more detailed walkthrough of the `VectorStoreRetrieverMemory` wrapper, see [this notebook](../modules/memory/types/vectorstore_retriever_memory.ipynb).
|
||||||
|
|
||||||
|
#### Chat Message History Memory
|
||||||
|
For a detailed example of Redis to cache conversation message history, see [this notebook](../modules/memory/examples/redis_chat_message_history.ipynb).
|
@ -41,7 +41,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 3,
|
"execution_count": 5,
|
||||||
"id": "f69f6283",
|
"id": "f69f6283",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@ -52,7 +52,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 4,
|
"execution_count": 6,
|
||||||
"id": "64005d1f",
|
"id": "64005d1f",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@ -60,8 +60,8 @@
|
|||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"CPU times: user 14.2 ms, sys: 4.9 ms, total: 19.1 ms\n",
|
"CPU times: user 26.1 ms, sys: 21.5 ms, total: 47.6 ms\n",
|
||||||
"Wall time: 1.1 s\n"
|
"Wall time: 1.68 s\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -70,7 +70,7 @@
|
|||||||
"'\\n\\nWhy did the chicken cross the road?\\n\\nTo get to the other side.'"
|
"'\\n\\nWhy did the chicken cross the road?\\n\\nTo get to the other side.'"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 4,
|
"execution_count": 6,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
@ -83,7 +83,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 5,
|
"execution_count": 7,
|
||||||
"id": "c8a1cb2b",
|
"id": "c8a1cb2b",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@ -91,8 +91,8 @@
|
|||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"CPU times: user 162 µs, sys: 7 µs, total: 169 µs\n",
|
"CPU times: user 238 µs, sys: 143 µs, total: 381 µs\n",
|
||||||
"Wall time: 175 µs\n"
|
"Wall time: 1.76 ms\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -101,7 +101,7 @@
|
|||||||
"'\\n\\nWhy did the chicken cross the road?\\n\\nTo get to the other side.'"
|
"'\\n\\nWhy did the chicken cross the road?\\n\\nTo get to the other side.'"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 5,
|
"execution_count": 7,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
@ -214,9 +214,18 @@
|
|||||||
"## Redis Cache"
|
"## Redis Cache"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "c5c9a4d5",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"### Standard Cache\n",
|
||||||
|
"Use [Redis](../../../../ecosystem/redis.md) to cache prompts and responses."
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 8,
|
||||||
"id": "39f6eb0b",
|
"id": "39f6eb0b",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@ -225,15 +234,35 @@
|
|||||||
"# (make sure your local Redis instance is running first before running this example)\n",
|
"# (make sure your local Redis instance is running first before running this example)\n",
|
||||||
"from redis import Redis\n",
|
"from redis import Redis\n",
|
||||||
"from langchain.cache import RedisCache\n",
|
"from langchain.cache import RedisCache\n",
|
||||||
|
"\n",
|
||||||
"langchain.llm_cache = RedisCache(redis_=Redis())"
|
"langchain.llm_cache = RedisCache(redis_=Redis())"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 9,
|
||||||
"id": "28920749",
|
"id": "28920749",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"CPU times: user 6.88 ms, sys: 8.75 ms, total: 15.6 ms\n",
|
||||||
|
"Wall time: 1.04 s\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"'\\n\\nWhy did the chicken cross the road?\\n\\nTo get to the other side!'"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 9,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"%%time\n",
|
"%%time\n",
|
||||||
"# The first time, it is not yet in cache, so it should take longer\n",
|
"# The first time, it is not yet in cache, so it should take longer\n",
|
||||||
@ -242,16 +271,124 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 14,
|
||||||
"id": "94bf9415",
|
"id": "94bf9415",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"CPU times: user 1.59 ms, sys: 610 µs, total: 2.2 ms\n",
|
||||||
|
"Wall time: 5.58 ms\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"'\\n\\nWhy did the chicken cross the road?\\n\\nTo get to the other side!'"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 14,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"%%time\n",
|
"%%time\n",
|
||||||
"# The second time it is, so it goes faster\n",
|
"# The second time it is, so it goes faster\n",
|
||||||
"llm(\"Tell me a joke\")"
|
"llm(\"Tell me a joke\")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "82be23f6",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"### Semantic Cache\n",
|
||||||
|
"Use [Redis](../../../../ecosystem/redis.md) to cache prompts and responses and evaluate hits based on semantic similarity."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 15,
|
||||||
|
"id": "64df3099",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from langchain.embeddings import OpenAIEmbeddings\n",
|
||||||
|
"from langchain.cache import RedisSemanticCache\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"langchain.llm_cache = RedisSemanticCache(\n",
|
||||||
|
" redis_url=\"redis://localhost:6379\",\n",
|
||||||
|
" embedding=OpenAIEmbeddings()\n",
|
||||||
|
")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 16,
|
||||||
|
"id": "8e91d3ac",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"CPU times: user 351 ms, sys: 156 ms, total: 507 ms\n",
|
||||||
|
"Wall time: 3.37 s\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"\"\\n\\nWhy don't scientists trust atoms?\\nBecause they make up everything.\""
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 16,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"%%time\n",
|
||||||
|
"# The first time, it is not yet in cache, so it should take longer\n",
|
||||||
|
"llm(\"Tell me a joke\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 27,
|
||||||
|
"id": "df856948",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"CPU times: user 6.25 ms, sys: 2.72 ms, total: 8.97 ms\n",
|
||||||
|
"Wall time: 262 ms\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"\"\\n\\nWhy don't scientists trust atoms?\\nBecause they make up everything.\""
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 27,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"%%time\n",
|
||||||
|
"# The second time, while not a direct hit, the question is semantically similar to the original question,\n",
|
||||||
|
"# so it uses the cached result!\n",
|
||||||
|
"llm(\"Tell me one joke\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"id": "684eab55",
|
"id": "684eab55",
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
"""Beta Feature: base interface for cache."""
|
"""Beta Feature: base interface for cache."""
|
||||||
|
import hashlib
|
||||||
import json
|
import json
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, cast
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, cast
|
||||||
@ -12,11 +13,18 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
from sqlalchemy.ext.declarative import declarative_base
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
|
|
||||||
|
from langchain.embeddings.base import Embeddings
|
||||||
from langchain.schema import Generation
|
from langchain.schema import Generation
|
||||||
|
from langchain.vectorstores.redis import Redis as RedisVectorstore
|
||||||
|
|
||||||
RETURN_VAL_TYPE = List[Generation]
|
RETURN_VAL_TYPE = List[Generation]
|
||||||
|
|
||||||
|
|
||||||
|
def _hash(_input: str) -> str:
|
||||||
|
"""Use a deterministic hashing approach."""
|
||||||
|
return hashlib.md5(_input.encode()).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
class BaseCache(ABC):
|
class BaseCache(ABC):
|
||||||
"""Base interface for cache."""
|
"""Base interface for cache."""
|
||||||
|
|
||||||
@ -117,6 +125,8 @@ class SQLiteCache(SQLAlchemyCache):
|
|||||||
class RedisCache(BaseCache):
|
class RedisCache(BaseCache):
|
||||||
"""Cache that uses Redis as a backend."""
|
"""Cache that uses Redis as a backend."""
|
||||||
|
|
||||||
|
# TODO - implement a TTL policy in Redis
|
||||||
|
|
||||||
def __init__(self, redis_: Any):
|
def __init__(self, redis_: Any):
|
||||||
"""Initialize by passing in Redis instance."""
|
"""Initialize by passing in Redis instance."""
|
||||||
try:
|
try:
|
||||||
@ -130,28 +140,30 @@ class RedisCache(BaseCache):
|
|||||||
raise ValueError("Please pass in Redis object.")
|
raise ValueError("Please pass in Redis object.")
|
||||||
self.redis = redis_
|
self.redis = redis_
|
||||||
|
|
||||||
def _key(self, prompt: str, llm_string: str, idx: int) -> str:
|
def _key(self, prompt: str, llm_string: str) -> str:
|
||||||
"""Compute key from prompt, llm_string, and idx."""
|
"""Compute key from prompt and llm_string"""
|
||||||
return str(hash(prompt + llm_string)) + "_" + str(idx)
|
return _hash(prompt + llm_string)
|
||||||
|
|
||||||
def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
|
def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
|
||||||
"""Look up based on prompt and llm_string."""
|
"""Look up based on prompt and llm_string."""
|
||||||
idx = 0
|
|
||||||
generations = []
|
generations = []
|
||||||
while self.redis.get(self._key(prompt, llm_string, idx)):
|
# Read from a Redis HASH
|
||||||
result = self.redis.get(self._key(prompt, llm_string, idx))
|
results = self.redis.hgetall(self._key(prompt, llm_string))
|
||||||
if not result:
|
if results:
|
||||||
break
|
for _, text in results.items():
|
||||||
elif isinstance(result, bytes):
|
generations.append(Generation(text=text))
|
||||||
result = result.decode()
|
|
||||||
generations.append(Generation(text=result))
|
|
||||||
idx += 1
|
|
||||||
return generations if generations else None
|
return generations if generations else None
|
||||||
|
|
||||||
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
|
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
|
||||||
"""Update cache based on prompt and llm_string."""
|
"""Update cache based on prompt and llm_string."""
|
||||||
for i, generation in enumerate(return_val):
|
# Write to a Redis HASH
|
||||||
self.redis.set(self._key(prompt, llm_string, i), generation.text)
|
key = self._key(prompt, llm_string)
|
||||||
|
self.redis.hset(
|
||||||
|
key,
|
||||||
|
mapping={
|
||||||
|
str(idx): generation.text for idx, generation in enumerate(return_val)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
def clear(self, **kwargs: Any) -> None:
|
def clear(self, **kwargs: Any) -> None:
|
||||||
"""Clear cache. If `asynchronous` is True, flush asynchronously."""
|
"""Clear cache. If `asynchronous` is True, flush asynchronously."""
|
||||||
@ -159,6 +171,106 @@ class RedisCache(BaseCache):
|
|||||||
self.redis.flushdb(asynchronous=asynchronous, **kwargs)
|
self.redis.flushdb(asynchronous=asynchronous, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class RedisSemanticCache(BaseCache):
|
||||||
|
"""Cache that uses Redis as a vector-store backend."""
|
||||||
|
|
||||||
|
# TODO - implement a TTL policy in Redis
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, redis_url: str, embedding: Embeddings, score_threshold: float = 0.2
|
||||||
|
):
|
||||||
|
"""Initialize by passing in the `init` GPTCache func
|
||||||
|
|
||||||
|
Args:
|
||||||
|
redis_url (str): URL to connect to Redis.
|
||||||
|
embedding (Embedding): Embedding provider for semantic encoding and search.
|
||||||
|
score_threshold (float, 0.2):
|
||||||
|
|
||||||
|
Example:
|
||||||
|
.. code-block:: python
|
||||||
|
import langchain
|
||||||
|
|
||||||
|
from langchain.cache import RedisSemanticCache
|
||||||
|
from langchain.embeddings import OpenAIEmbeddings
|
||||||
|
|
||||||
|
langchain.llm_cache = RedisSemanticCache(
|
||||||
|
redis_url="redis://localhost:6379",
|
||||||
|
embedding=OpenAIEmbeddings()
|
||||||
|
)
|
||||||
|
|
||||||
|
"""
|
||||||
|
self._cache_dict: Dict[str, RedisVectorstore] = {}
|
||||||
|
self.redis_url = redis_url
|
||||||
|
self.embedding = embedding
|
||||||
|
self.score_threshold = score_threshold
|
||||||
|
|
||||||
|
def _index_name(self, llm_string: str) -> str:
|
||||||
|
hashed_index = _hash(llm_string)
|
||||||
|
return f"cache:{hashed_index}"
|
||||||
|
|
||||||
|
def _get_llm_cache(self, llm_string: str) -> RedisVectorstore:
|
||||||
|
index_name = self._index_name(llm_string)
|
||||||
|
|
||||||
|
# return vectorstore client for the specific llm string
|
||||||
|
if index_name in self._cache_dict:
|
||||||
|
return self._cache_dict[index_name]
|
||||||
|
|
||||||
|
# create new vectorstore client for the specific llm string
|
||||||
|
try:
|
||||||
|
self._cache_dict[index_name] = RedisVectorstore.from_existing_index(
|
||||||
|
embedding=self.embedding,
|
||||||
|
index_name=index_name,
|
||||||
|
redis_url=self.redis_url,
|
||||||
|
)
|
||||||
|
except ValueError:
|
||||||
|
redis = RedisVectorstore(
|
||||||
|
embedding_function=self.embedding.embed_query,
|
||||||
|
index_name=index_name,
|
||||||
|
redis_url=self.redis_url,
|
||||||
|
)
|
||||||
|
_embedding = self.embedding.embed_query(text="test")
|
||||||
|
redis._create_index(dim=len(_embedding))
|
||||||
|
self._cache_dict[index_name] = redis
|
||||||
|
|
||||||
|
return self._cache_dict[index_name]
|
||||||
|
|
||||||
|
def clear(self, **kwargs: Any) -> None:
|
||||||
|
"""Clear semantic cache for a given llm_string."""
|
||||||
|
index_name = self._index_name(kwargs["llm_string"])
|
||||||
|
if index_name in self._cache_dict:
|
||||||
|
self._cache_dict[index_name].drop_index(
|
||||||
|
index_name=index_name, delete_documents=True, redis_url=self.redis_url
|
||||||
|
)
|
||||||
|
del self._cache_dict[index_name]
|
||||||
|
|
||||||
|
def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
|
||||||
|
"""Look up based on prompt and llm_string."""
|
||||||
|
llm_cache = self._get_llm_cache(llm_string)
|
||||||
|
generations = []
|
||||||
|
# Read from a Hash
|
||||||
|
results = llm_cache.similarity_search_limit_score(
|
||||||
|
query=prompt,
|
||||||
|
k=1,
|
||||||
|
score_threshold=self.score_threshold,
|
||||||
|
)
|
||||||
|
if results:
|
||||||
|
for document in results:
|
||||||
|
for text in document.metadata["return_val"]:
|
||||||
|
generations.append(Generation(text=text))
|
||||||
|
return generations if generations else None
|
||||||
|
|
||||||
|
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
|
||||||
|
"""Update cache based on prompt and llm_string."""
|
||||||
|
llm_cache = self._get_llm_cache(llm_string)
|
||||||
|
# Write to vectorstore
|
||||||
|
metadata = {
|
||||||
|
"llm_string": llm_string,
|
||||||
|
"prompt": prompt,
|
||||||
|
"return_val": [generation.text for generation in return_val],
|
||||||
|
}
|
||||||
|
llm_cache.add_texts(texts=[prompt], metadatas=[metadata])
|
||||||
|
|
||||||
|
|
||||||
class GPTCache(BaseCache):
|
class GPTCache(BaseCache):
|
||||||
"""Cache that uses GPTCache as a backend."""
|
"""Cache that uses GPTCache as a backend."""
|
||||||
|
|
||||||
|
@ -13,11 +13,13 @@ from langchain.vectorstores.myscale import MyScale, MyScaleSettings
|
|||||||
from langchain.vectorstores.opensearch_vector_search import OpenSearchVectorSearch
|
from langchain.vectorstores.opensearch_vector_search import OpenSearchVectorSearch
|
||||||
from langchain.vectorstores.pinecone import Pinecone
|
from langchain.vectorstores.pinecone import Pinecone
|
||||||
from langchain.vectorstores.qdrant import Qdrant
|
from langchain.vectorstores.qdrant import Qdrant
|
||||||
|
from langchain.vectorstores.redis import Redis
|
||||||
from langchain.vectorstores.supabase import SupabaseVectorStore
|
from langchain.vectorstores.supabase import SupabaseVectorStore
|
||||||
from langchain.vectorstores.weaviate import Weaviate
|
from langchain.vectorstores.weaviate import Weaviate
|
||||||
from langchain.vectorstores.zilliz import Zilliz
|
from langchain.vectorstores.zilliz import Zilliz
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"Redis",
|
||||||
"ElasticVectorSearch",
|
"ElasticVectorSearch",
|
||||||
"FAISS",
|
"FAISS",
|
||||||
"VectorStore",
|
"VectorStore",
|
||||||
|
@ -4,11 +4,21 @@ from __future__ import annotations
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, Type
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
|
Callable,
|
||||||
|
Dict,
|
||||||
|
Iterable,
|
||||||
|
List,
|
||||||
|
Mapping,
|
||||||
|
Optional,
|
||||||
|
Tuple,
|
||||||
|
Type,
|
||||||
|
)
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pydantic import BaseModel, root_validator
|
from pydantic import BaseModel, root_validator
|
||||||
from redis.client import Redis as RedisType
|
|
||||||
|
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
from langchain.embeddings.base import Embeddings
|
from langchain.embeddings.base import Embeddings
|
||||||
@ -18,23 +28,30 @@ from langchain.vectorstores.base import VectorStore
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from redis.client import Redis as RedisType
|
||||||
|
from redis.commands.search.query import Query
|
||||||
|
|
||||||
|
|
||||||
# required modules
|
# required modules
|
||||||
REDIS_REQUIRED_MODULES = [
|
REDIS_REQUIRED_MODULES = [
|
||||||
{"name": "search", "ver": 20400},
|
{"name": "search", "ver": 20400},
|
||||||
|
{"name": "searchlight", "ver": 20400},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def _check_redis_module_exist(client: RedisType, modules: List[dict]) -> None:
|
def _check_redis_module_exist(client: RedisType, required_modules: List[dict]) -> None:
|
||||||
"""Check if the correct Redis modules are installed."""
|
"""Check if the correct Redis modules are installed."""
|
||||||
installed_modules = client.module_list()
|
installed_modules = client.module_list()
|
||||||
installed_modules = {
|
installed_modules = {
|
||||||
module[b"name"].decode("utf-8"): module for module in installed_modules
|
module[b"name"].decode("utf-8"): module for module in installed_modules
|
||||||
}
|
}
|
||||||
for module in modules:
|
for module in required_modules:
|
||||||
if module["name"] not in installed_modules or int(
|
if module["name"] in installed_modules and int(
|
||||||
installed_modules[module["name"]][b"ver"]
|
installed_modules[module["name"]][b"ver"]
|
||||||
) < int(module["ver"]):
|
) >= int(module["ver"]):
|
||||||
|
return
|
||||||
|
# otherwise raise error
|
||||||
error_message = (
|
error_message = (
|
||||||
"You must add the RediSearch (>= 2.4) module from Redis Stack. "
|
"You must add the RediSearch (>= 2.4) module from Redis Stack. "
|
||||||
"Please refer to Redis Stack docs: https://redis.io/docs/stack/"
|
"Please refer to Redis Stack docs: https://redis.io/docs/stack/"
|
||||||
@ -65,6 +82,24 @@ def _redis_prefix(index_name: str) -> str:
|
|||||||
|
|
||||||
|
|
||||||
class Redis(VectorStore):
|
class Redis(VectorStore):
|
||||||
|
"""Wrapper around Redis vector database.
|
||||||
|
|
||||||
|
To use, you should have the ``redis`` python package installed.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from langchain.vectorstores import Redis
|
||||||
|
from langchain.embeddings import OpenAIEmbeddings
|
||||||
|
|
||||||
|
embeddings = OpenAIEmbeddings()
|
||||||
|
vectorstore = Redis(
|
||||||
|
redis_url="redis://username:password@localhost:6379"
|
||||||
|
index_name="my-index",
|
||||||
|
embedding_function=embeddings.embed_query,
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
redis_url: str,
|
redis_url: str,
|
||||||
@ -99,33 +134,92 @@ class Redis(VectorStore):
|
|||||||
self.metadata_key = metadata_key
|
self.metadata_key = metadata_key
|
||||||
self.vector_key = vector_key
|
self.vector_key = vector_key
|
||||||
|
|
||||||
|
def _create_index(self, dim: int = 1536) -> None:
|
||||||
|
try:
|
||||||
|
from redis.commands.search.field import TextField, VectorField
|
||||||
|
from redis.commands.search.indexDefinition import IndexDefinition, IndexType
|
||||||
|
except ImportError:
|
||||||
|
raise ValueError(
|
||||||
|
"Could not import redis python package. "
|
||||||
|
"Please install it with `pip install redis`."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if index exists
|
||||||
|
if not _check_index_exists(self.client, self.index_name):
|
||||||
|
# Constants
|
||||||
|
distance_metric = (
|
||||||
|
"COSINE" # distance metric for the vectors (ex. COSINE, IP, L2)
|
||||||
|
)
|
||||||
|
schema = (
|
||||||
|
TextField(name=self.content_key),
|
||||||
|
TextField(name=self.metadata_key),
|
||||||
|
VectorField(
|
||||||
|
self.vector_key,
|
||||||
|
"FLAT",
|
||||||
|
{
|
||||||
|
"TYPE": "FLOAT32",
|
||||||
|
"DIM": dim,
|
||||||
|
"DISTANCE_METRIC": distance_metric,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
prefix = _redis_prefix(self.index_name)
|
||||||
|
|
||||||
|
# Create Redis Index
|
||||||
|
self.client.ft(self.index_name).create_index(
|
||||||
|
fields=schema,
|
||||||
|
definition=IndexDefinition(prefix=[prefix], index_type=IndexType.HASH),
|
||||||
|
)
|
||||||
|
|
||||||
def add_texts(
|
def add_texts(
|
||||||
self,
|
self,
|
||||||
texts: Iterable[str],
|
texts: Iterable[str],
|
||||||
metadatas: Optional[List[dict]] = None,
|
metadatas: Optional[List[dict]] = None,
|
||||||
|
embeddings: Optional[List[List[float]]] = None,
|
||||||
|
keys: Optional[List[str]] = None,
|
||||||
|
batch_size: int = 1000,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
"""Add texts data to an existing index."""
|
"""Add more texts to the vectorstore.
|
||||||
prefix = _redis_prefix(self.index_name)
|
|
||||||
keys = kwargs.get("keys")
|
Args:
|
||||||
|
texts (Iterable[str]): Iterable of strings/text to add to the vectorstore.
|
||||||
|
metadatas (Optional[List[dict]], optional): Optional list of metadatas.
|
||||||
|
Defaults to None.
|
||||||
|
embeddings (Optional[List[List[float]]], optional): Optional pre-generated
|
||||||
|
embeddings. Defaults to None.
|
||||||
|
keys (Optional[List[str]], optional): Optional key values to use as ids.
|
||||||
|
Defaults to None.
|
||||||
|
batch_size (int, optional): Batch size to use for writes. Defaults to 1000.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[str]: List of ids added to the vectorstore
|
||||||
|
"""
|
||||||
ids = []
|
ids = []
|
||||||
|
prefix = _redis_prefix(self.index_name)
|
||||||
|
|
||||||
# Write data to redis
|
# Write data to redis
|
||||||
pipeline = self.client.pipeline(transaction=False)
|
pipeline = self.client.pipeline(transaction=False)
|
||||||
for i, text in enumerate(texts):
|
for i, text in enumerate(texts):
|
||||||
# Use provided key otherwise use default key
|
# Use provided values by default or fallback
|
||||||
key = keys[i] if keys else _redis_key(prefix)
|
key = keys[i] if keys else _redis_key(prefix)
|
||||||
metadata = metadatas[i] if metadatas else {}
|
metadata = metadatas[i] if metadatas else {}
|
||||||
|
embedding = embeddings[i] if embeddings else self.embedding_function(text)
|
||||||
pipeline.hset(
|
pipeline.hset(
|
||||||
key,
|
key,
|
||||||
mapping={
|
mapping={
|
||||||
self.content_key: text,
|
self.content_key: text,
|
||||||
self.vector_key: np.array(
|
self.vector_key: np.array(embedding, dtype=np.float32).tobytes(),
|
||||||
self.embedding_function(text), dtype=np.float32
|
|
||||||
).tobytes(),
|
|
||||||
self.metadata_key: json.dumps(metadata),
|
self.metadata_key: json.dumps(metadata),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
ids.append(key)
|
ids.append(key)
|
||||||
|
|
||||||
|
# Write batch
|
||||||
|
if i % batch_size == 0:
|
||||||
|
pipeline.execute()
|
||||||
|
|
||||||
|
# Cleanup final batch
|
||||||
pipeline.execute()
|
pipeline.execute()
|
||||||
return ids
|
return ids
|
||||||
|
|
||||||
@ -170,9 +264,30 @@ class Redis(VectorStore):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
docs_and_scores = self.similarity_search_with_score(query, k=k)
|
docs_and_scores = self.similarity_search_with_score(query, k=k)
|
||||||
|
|
||||||
return [doc for doc, score in docs_and_scores if score < score_threshold]
|
return [doc for doc, score in docs_and_scores if score < score_threshold]
|
||||||
|
|
||||||
|
def _prepare_query(self, k: int) -> Query:
|
||||||
|
try:
|
||||||
|
from redis.commands.search.query import Query
|
||||||
|
except ImportError:
|
||||||
|
raise ValueError(
|
||||||
|
"Could not import redis python package. "
|
||||||
|
"Please install it with `pip install redis`."
|
||||||
|
)
|
||||||
|
# Prepare the Query
|
||||||
|
hybrid_fields = "*"
|
||||||
|
base_query = (
|
||||||
|
f"{hybrid_fields}=>[KNN {k} @{self.vector_key} $vector AS vector_score]"
|
||||||
|
)
|
||||||
|
return_fields = [self.metadata_key, self.content_key, "vector_score"]
|
||||||
|
return (
|
||||||
|
Query(base_query)
|
||||||
|
.return_fields(*return_fields)
|
||||||
|
.sort_by("vector_score")
|
||||||
|
.paging(0, k)
|
||||||
|
.dialect(2)
|
||||||
|
)
|
||||||
|
|
||||||
def similarity_search_with_score(
|
def similarity_search_with_score(
|
||||||
self, query: str, k: int = 4
|
self, query: str, k: int = 4
|
||||||
) -> List[Tuple[Document, float]]:
|
) -> List[Tuple[Document, float]]:
|
||||||
@ -185,40 +300,22 @@ class Redis(VectorStore):
|
|||||||
Returns:
|
Returns:
|
||||||
List of Documents most similar to the query and score for each
|
List of Documents most similar to the query and score for each
|
||||||
"""
|
"""
|
||||||
try:
|
|
||||||
from redis.commands.search.query import Query
|
|
||||||
except ImportError:
|
|
||||||
raise ValueError(
|
|
||||||
"Could not import redis python package. "
|
|
||||||
"Please install it with `pip install redis`."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Creates embedding vector from user query
|
# Creates embedding vector from user query
|
||||||
embedding = self.embedding_function(query)
|
embedding = self.embedding_function(query)
|
||||||
|
|
||||||
# Prepare the Query
|
# Creates Redis query
|
||||||
return_fields = [self.metadata_key, self.content_key, "vector_score"]
|
redis_query = self._prepare_query(k)
|
||||||
vector_field = self.vector_key
|
|
||||||
hybrid_fields = "*"
|
|
||||||
base_query = (
|
|
||||||
f"{hybrid_fields}=>[KNN {k} @{vector_field} $vector AS vector_score]"
|
|
||||||
)
|
|
||||||
redis_query = (
|
|
||||||
Query(base_query)
|
|
||||||
.return_fields(*return_fields)
|
|
||||||
.sort_by("vector_score")
|
|
||||||
.paging(0, k)
|
|
||||||
.dialect(2)
|
|
||||||
)
|
|
||||||
params_dict: Mapping[str, str] = {
|
params_dict: Mapping[str, str] = {
|
||||||
"vector": np.array(embedding) # type: ignore
|
"vector": np.array(embedding) # type: ignore
|
||||||
.astype(dtype=np.float32)
|
.astype(dtype=np.float32)
|
||||||
.tobytes()
|
.tobytes()
|
||||||
}
|
}
|
||||||
|
|
||||||
# perform vector search
|
# Perform vector search
|
||||||
results = self.client.ft(self.index_name).search(redis_query, params_dict)
|
results = self.client.ft(self.index_name).search(redis_query, params_dict)
|
||||||
|
|
||||||
|
# Prepare document results
|
||||||
docs = [
|
docs = [
|
||||||
(
|
(
|
||||||
Document(
|
Document(
|
||||||
@ -243,15 +340,15 @@ class Redis(VectorStore):
|
|||||||
vector_key: str = "content_vector",
|
vector_key: str = "content_vector",
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Redis:
|
) -> Redis:
|
||||||
"""Construct RediSearch wrapper from raw documents.
|
"""Create a Redis vectorstore from raw documents.
|
||||||
This is a user-friendly interface that:
|
This is a user-friendly interface that:
|
||||||
1. Embeds documents.
|
1. Embeds documents.
|
||||||
2. Creates a new index for the embeddings in the RediSearch instance.
|
2. Creates a new index for the embeddings in Redis.
|
||||||
3. Adds the documents to the newly created RediSearch index.
|
3. Adds the documents to the newly created Redis index.
|
||||||
This is intended to be a quick way to get started.
|
This is intended to be a quick way to get started.
|
||||||
Example:
|
Example:
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
from langchain import RediSearch
|
from langchain.vectorstores import Redis
|
||||||
from langchain.embeddings import OpenAIEmbeddings
|
from langchain.embeddings import OpenAIEmbeddings
|
||||||
embeddings = OpenAIEmbeddings()
|
embeddings = OpenAIEmbeddings()
|
||||||
redisearch = RediSearch.from_texts(
|
redisearch = RediSearch.from_texts(
|
||||||
@ -261,84 +358,35 @@ class Redis(VectorStore):
|
|||||||
)
|
)
|
||||||
"""
|
"""
|
||||||
redis_url = get_from_dict_or_env(kwargs, "redis_url", "REDIS_URL")
|
redis_url = get_from_dict_or_env(kwargs, "redis_url", "REDIS_URL")
|
||||||
try:
|
|
||||||
import redis
|
|
||||||
from redis.commands.search.field import TextField, VectorField
|
|
||||||
from redis.commands.search.indexDefinition import IndexDefinition, IndexType
|
|
||||||
except ImportError:
|
|
||||||
raise ValueError(
|
|
||||||
"Could not import redis python package. "
|
|
||||||
"Please install it with `pip install redis`."
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
# We need to first remove redis_url from kwargs,
|
|
||||||
# otherwise passing it to Redis will result in an error.
|
|
||||||
if "redis_url" in kwargs:
|
if "redis_url" in kwargs:
|
||||||
kwargs.pop("redis_url")
|
kwargs.pop("redis_url")
|
||||||
client = redis.from_url(url=redis_url, **kwargs)
|
|
||||||
# check if redis has redisearch module installed
|
|
||||||
_check_redis_module_exist(client, REDIS_REQUIRED_MODULES)
|
|
||||||
except ValueError as e:
|
|
||||||
raise ValueError(f"Redis failed to connect: {e}")
|
|
||||||
|
|
||||||
# Create embeddings over documents
|
|
||||||
embeddings = embedding.embed_documents(texts)
|
|
||||||
|
|
||||||
# Name of the search index if not given
|
# Name of the search index if not given
|
||||||
if not index_name:
|
if not index_name:
|
||||||
index_name = uuid.uuid4().hex
|
index_name = uuid.uuid4().hex
|
||||||
prefix = _redis_prefix(index_name) # prefix for the document keys
|
|
||||||
|
|
||||||
# Check if index exists
|
# Create instance
|
||||||
if not _check_index_exists(client, index_name):
|
instance = cls(
|
||||||
# Constants
|
redis_url=redis_url,
|
||||||
dim = len(embeddings[0])
|
index_name=index_name,
|
||||||
distance_metric = (
|
embedding_function=embedding.embed_query,
|
||||||
"COSINE" # distance metric for the vectors (ex. COSINE, IP, L2)
|
|
||||||
)
|
|
||||||
schema = (
|
|
||||||
TextField(name=content_key),
|
|
||||||
TextField(name=metadata_key),
|
|
||||||
VectorField(
|
|
||||||
vector_key,
|
|
||||||
"FLAT",
|
|
||||||
{
|
|
||||||
"TYPE": "FLOAT32",
|
|
||||||
"DIM": dim,
|
|
||||||
"DISTANCE_METRIC": distance_metric,
|
|
||||||
},
|
|
||||||
),
|
|
||||||
)
|
|
||||||
# Create Redis Index
|
|
||||||
client.ft(index_name).create_index(
|
|
||||||
fields=schema,
|
|
||||||
definition=IndexDefinition(prefix=[prefix], index_type=IndexType.HASH),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Write data to Redis
|
|
||||||
pipeline = client.pipeline(transaction=False)
|
|
||||||
for i, text in enumerate(texts):
|
|
||||||
key = _redis_key(prefix)
|
|
||||||
metadata = metadatas[i] if metadatas else {}
|
|
||||||
pipeline.hset(
|
|
||||||
key,
|
|
||||||
mapping={
|
|
||||||
content_key: text,
|
|
||||||
vector_key: np.array(embeddings[i], dtype=np.float32).tobytes(),
|
|
||||||
metadata_key: json.dumps(metadata),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
pipeline.execute()
|
|
||||||
return cls(
|
|
||||||
redis_url,
|
|
||||||
index_name,
|
|
||||||
embedding.embed_query,
|
|
||||||
content_key=content_key,
|
content_key=content_key,
|
||||||
metadata_key=metadata_key,
|
metadata_key=metadata_key,
|
||||||
vector_key=vector_key,
|
vector_key=vector_key,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Create embeddings over documents
|
||||||
|
embeddings = embedding.embed_documents(texts)
|
||||||
|
|
||||||
|
# Create the search index
|
||||||
|
instance._create_index(dim=len(embeddings[0]))
|
||||||
|
|
||||||
|
# Add data to Redis
|
||||||
|
instance.add_texts(texts, metadatas, embeddings)
|
||||||
|
return instance
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def drop_index(
|
def drop_index(
|
||||||
index_name: str,
|
index_name: str,
|
||||||
|
55
tests/integration_tests/cache/test_redis_cache.py
vendored
Normal file
55
tests/integration_tests/cache/test_redis_cache.py
vendored
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
"""Test Redis cache functionality."""
|
||||||
|
import redis
|
||||||
|
|
||||||
|
import langchain
|
||||||
|
from langchain.cache import RedisCache, RedisSemanticCache
|
||||||
|
from langchain.schema import Generation, LLMResult
|
||||||
|
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
|
||||||
|
from tests.unit_tests.llms.fake_llm import FakeLLM
|
||||||
|
|
||||||
|
REDIS_TEST_URL = "redis://localhost:6379"
|
||||||
|
|
||||||
|
|
||||||
|
def test_redis_cache() -> None:
|
||||||
|
langchain.llm_cache = RedisCache(redis_=redis.Redis.from_url(REDIS_TEST_URL))
|
||||||
|
llm = FakeLLM()
|
||||||
|
params = llm.dict()
|
||||||
|
params["stop"] = None
|
||||||
|
llm_string = str(sorted([(k, v) for k, v in params.items()]))
|
||||||
|
langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
|
||||||
|
output = llm.generate(["foo"])
|
||||||
|
print(output)
|
||||||
|
expected_output = LLMResult(
|
||||||
|
generations=[[Generation(text="fizz")]],
|
||||||
|
llm_output={},
|
||||||
|
)
|
||||||
|
print(expected_output)
|
||||||
|
assert output == expected_output
|
||||||
|
langchain.llm_cache.redis.flushall()
|
||||||
|
|
||||||
|
|
||||||
|
def test_redis_semantic_cache() -> None:
|
||||||
|
langchain.llm_cache = RedisSemanticCache(
|
||||||
|
embedding=FakeEmbeddings(), redis_url=REDIS_TEST_URL, score_threshold=0.1
|
||||||
|
)
|
||||||
|
llm = FakeLLM()
|
||||||
|
params = llm.dict()
|
||||||
|
params["stop"] = None
|
||||||
|
llm_string = str(sorted([(k, v) for k, v in params.items()]))
|
||||||
|
langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
|
||||||
|
output = llm.generate(
|
||||||
|
["bar"]
|
||||||
|
) # foo and bar will have the same embedding produced by FakeEmbeddings
|
||||||
|
expected_output = LLMResult(
|
||||||
|
generations=[[Generation(text="fizz")]],
|
||||||
|
llm_output={},
|
||||||
|
)
|
||||||
|
assert output == expected_output
|
||||||
|
# clear the cache
|
||||||
|
langchain.llm_cache.clear(llm_string=llm_string)
|
||||||
|
output = llm.generate(
|
||||||
|
["bar"]
|
||||||
|
) # foo and bar will have the same embedding produced by FakeEmbeddings
|
||||||
|
# expect different output now without cached result
|
||||||
|
assert output != expected_output
|
||||||
|
langchain.llm_cache.clear(llm_string=llm_string)
|
@ -1,26 +1,60 @@
|
|||||||
"""Test Redis functionality."""
|
"""Test Redis functionality."""
|
||||||
|
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
from langchain.vectorstores.redis import Redis
|
from langchain.vectorstores.redis import Redis
|
||||||
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
|
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
|
||||||
|
|
||||||
|
TEST_INDEX_NAME = "test"
|
||||||
|
TEST_REDIS_URL = "redis://localhost:6379"
|
||||||
|
TEST_SINGLE_RESULT = [Document(page_content="foo")]
|
||||||
|
TEST_RESULT = [Document(page_content="foo"), Document(page_content="foo")]
|
||||||
|
|
||||||
|
|
||||||
|
def drop(index_name: str) -> bool:
|
||||||
|
return Redis.drop_index(
|
||||||
|
index_name=index_name, delete_documents=True, redis_url=TEST_REDIS_URL
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_redis() -> None:
|
def test_redis() -> None:
|
||||||
"""Test end to end construction and search."""
|
"""Test end to end construction and search."""
|
||||||
texts = ["foo", "bar", "baz"]
|
texts = ["foo", "bar", "baz"]
|
||||||
docsearch = Redis.from_texts(
|
docsearch = Redis.from_texts(texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL)
|
||||||
texts, FakeEmbeddings(), redis_url="redis://localhost:6379"
|
|
||||||
)
|
|
||||||
output = docsearch.similarity_search("foo", k=1)
|
output = docsearch.similarity_search("foo", k=1)
|
||||||
assert output == [Document(page_content="foo")]
|
assert output == TEST_SINGLE_RESULT
|
||||||
|
assert drop(docsearch.index_name)
|
||||||
|
|
||||||
|
|
||||||
def test_redis_new_vector() -> None:
|
def test_redis_new_vector() -> None:
|
||||||
"""Test adding a new document"""
|
"""Test adding a new document"""
|
||||||
texts = ["foo", "bar", "baz"]
|
texts = ["foo", "bar", "baz"]
|
||||||
docsearch = Redis.from_texts(
|
docsearch = Redis.from_texts(texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL)
|
||||||
texts, FakeEmbeddings(), redis_url="redis://localhost:6379"
|
docsearch.add_texts(["foo"])
|
||||||
|
output = docsearch.similarity_search("foo", k=2)
|
||||||
|
assert output == TEST_RESULT
|
||||||
|
assert drop(docsearch.index_name)
|
||||||
|
|
||||||
|
|
||||||
|
def test_redis_from_existing() -> None:
|
||||||
|
"""Test adding a new document"""
|
||||||
|
texts = ["foo", "bar", "baz"]
|
||||||
|
Redis.from_texts(
|
||||||
|
texts, FakeEmbeddings(), index_name=TEST_INDEX_NAME, redis_url=TEST_REDIS_URL
|
||||||
|
)
|
||||||
|
# Test creating from an existing
|
||||||
|
docsearch2 = Redis.from_existing_index(
|
||||||
|
FakeEmbeddings(), index_name=TEST_INDEX_NAME, redis_url=TEST_REDIS_URL
|
||||||
|
)
|
||||||
|
output = docsearch2.similarity_search("foo", k=1)
|
||||||
|
assert output == TEST_SINGLE_RESULT
|
||||||
|
|
||||||
|
|
||||||
|
def test_redis_add_texts_to_existing() -> None:
|
||||||
|
"""Test adding a new document"""
|
||||||
|
# Test creating from an existing
|
||||||
|
docsearch = Redis.from_existing_index(
|
||||||
|
FakeEmbeddings(), index_name=TEST_INDEX_NAME, redis_url=TEST_REDIS_URL
|
||||||
)
|
)
|
||||||
docsearch.add_texts(["foo"])
|
docsearch.add_texts(["foo"])
|
||||||
output = docsearch.similarity_search("foo", k=2)
|
output = docsearch.similarity_search("foo", k=2)
|
||||||
assert output == [Document(page_content="foo"), Document(page_content="foo")]
|
assert output == TEST_RESULT
|
||||||
|
assert drop(TEST_INDEX_NAME)
|
||||||
|
Loading…
Reference in New Issue
Block a user