Harrison/gpt cache (#2744)

Co-authored-by: SimFG <bang.fu@zilliz.com>
Harrison Chase 1 year ago committed by GitHub
parent 425c437cd3
commit e49f1e628c

@@ -60,14 +60,14 @@
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 30.7 ms, sys: 18.6 ms, total: 49.3 ms\n",
"Wall time: 791 ms\n"
"CPU times: user 14.2 ms, sys: 4.9 ms, total: 19.1 ms\n",
"Wall time: 1.1 s\n"
]
},
{
"data": {
"text/plain": [
"\"\\n\\nWhy couldn't the bicycle stand up by itself? Because it was...two tired!\""
"'\\n\\nWhy did the chicken cross the road?\\n\\nTo get to the other side.'"
]
},
"execution_count": 4,
@@ -91,14 +91,14 @@
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 80 µs, sys: 0 ns, total: 80 µs\n",
"Wall time: 83.9 µs\n"
"CPU times: user 162 µs, sys: 7 µs, total: 169 µs\n",
"Wall time: 175 µs\n"
]
},
{
"data": {
"text/plain": [
"\"\\n\\nWhy couldn't the bicycle stand up by itself? Because it was...two tired!\""
"'\\n\\nWhy did the chicken cross the road?\\n\\nTo get to the other side.'"
]
},
"execution_count": 5,
@@ -252,6 +252,249 @@
"llm(\"Tell me a joke\")"
]
},
{
"cell_type": "markdown",
"id": "684eab55",
"metadata": {},
"source": [
"## GPTCache\n",
"\n",
"We can use [GPTCache](https://github.com/zilliztech/GPTCache) for exact match caching OR to cache results based on semantic similarity\n",
"\n",
"Let's first start with an example of exact match"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "14a82124",
"metadata": {},
"outputs": [],
"source": [
"import gptcache\n",
"from gptcache.processor.pre import get_prompt\n",
"from gptcache.manager.factory import get_data_manager\n",
"from langchain.cache import GPTCache\n",
"\n",
"# Avoid multiple caches using the same file, causing different llm model caches to affect each other\n",
"i = 0\n",
"file_prefix = \"data_map\"\n",
"\n",
"def init_gptcache_map(cache_obj: gptcache.Cache):\n",
" global i\n",
" cache_path = f'{file_prefix}_{i}.txt'\n",
" cache_obj.init(\n",
" pre_embedding_func=get_prompt,\n",
" data_manager=get_data_manager(data_path=cache_path),\n",
" )\n",
" i += 1\n",
"\n",
"langchain.llm_cache = GPTCache(init_gptcache_map)"
]
},
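The exact-match cache above is backed by one local file per model (`data_map_0.txt`, `data_map_1.txt`, ...). To guarantee a cold cache on every run, the init function can remove any existing file before initializing, the same way the integration test added later in this commit does. A minimal sketch of that variant (the `init_gptcache_map_fresh` name is just illustrative):

import os

import gptcache
import langchain
from gptcache.manager.factory import get_data_manager
from gptcache.processor.pre import get_prompt
from langchain.cache import GPTCache

i = 0
file_prefix = "data_map"


def init_gptcache_map_fresh(cache_obj: gptcache.Cache) -> None:
    """Like init_gptcache_map above, but start each run from an empty cache file."""
    global i
    cache_path = f"{file_prefix}_{i}.txt"
    if os.path.isfile(cache_path):
        os.remove(cache_path)  # drop the cache file left over from a previous run
    cache_obj.init(
        pre_embedding_func=get_prompt,
        data_manager=get_data_manager(data_path=cache_path),
    )
    i += 1


langchain.llm_cache = GPTCache(init_gptcache_map_fresh)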
{
"cell_type": "code",
"execution_count": 7,
"id": "9e4ecfd1",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 8.6 ms, sys: 3.82 ms, total: 12.4 ms\n",
"Wall time: 881 ms\n"
]
},
{
"data": {
"text/plain": [
"'\\n\\nWhy did the chicken cross the road?\\n\\nTo get to the other side.'"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"# The first time, it is not yet in cache, so it should take longer\n",
"llm(\"Tell me a joke\")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "c98bbe3b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 286 µs, sys: 21 µs, total: 307 µs\n",
"Wall time: 316 µs\n"
]
},
{
"data": {
"text/plain": [
"'\\n\\nWhy did the chicken cross the road?\\n\\nTo get to the other side.'"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"# The second time it is, so it goes faster\n",
"llm(\"Tell me a joke\")"
]
},
{
"cell_type": "markdown",
"id": "502b6076",
"metadata": {},
"source": [
"Let's now show an example of similarity caching"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "b3c663bb",
"metadata": {},
"outputs": [],
"source": [
"import gptcache\n",
"from gptcache.processor.pre import get_prompt\n",
"from gptcache.manager.factory import get_data_manager\n",
"from langchain.cache import GPTCache\n",
"from gptcache.manager import get_data_manager, CacheBase, VectorBase\n",
"from gptcache import Cache\n",
"from gptcache.embedding import Onnx\n",
"from gptcache.similarity_evaluation.distance import SearchDistanceEvaluation\n",
"\n",
"# Avoid multiple caches using the same file, causing different llm model caches to affect each other\n",
"i = 0\n",
"file_prefix = \"data_map\"\n",
"llm_cache = Cache()\n",
"\n",
"\n",
"def init_gptcache_map(cache_obj: gptcache.Cache):\n",
" global i\n",
" cache_path = f'{file_prefix}_{i}.txt'\n",
" onnx = Onnx()\n",
" cache_base = CacheBase('sqlite')\n",
" vector_base = VectorBase('faiss', dimension=onnx.dimension)\n",
" data_manager = get_data_manager(cache_base, vector_base, max_size=10, clean_size=2)\n",
" cache_obj.init(\n",
" pre_embedding_func=get_prompt,\n",
" embedding_func=onnx.to_embeddings,\n",
" data_manager=data_manager,\n",
" similarity_evaluation=SearchDistanceEvaluation(),\n",
" )\n",
" i += 1\n",
"\n",
"langchain.llm_cache = GPTCache(init_gptcache_map)"
]
},
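To see why the near-miss prompt further down ("Tell me joke") still gets a cache hit, it helps to look at the embeddings directly: the same ONNX model used above places paraphrases close together. This is only an illustration of the intuition; the actual hit decision is made by GPTCache's SearchDistanceEvaluation on the faiss search result, not by this ad-hoc computation:

import numpy as np
from gptcache.embedding import Onnx

onnx = Onnx()
original = np.asarray(onnx.to_embeddings("Tell me a joke"))
paraphrase = np.asarray(onnx.to_embeddings("Tell me joke"))
unrelated = np.asarray(onnx.to_embeddings("What is the capital of France?"))

# Paraphrases land close together in embedding space, unrelated prompts do not,
# which is what lets the faiss lookup surface the cached joke for the paraphrase.
print("paraphrase distance:", np.linalg.norm(original - paraphrase))
print("unrelated distance: ", np.linalg.norm(original - unrelated))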
{
"cell_type": "code",
"execution_count": 10,
"id": "8c273ced",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 1.01 s, sys: 153 ms, total: 1.16 s\n",
"Wall time: 2.49 s\n"
]
},
{
"data": {
"text/plain": [
"'\\n\\nWhy did the chicken cross the road?\\n\\nTo get to the other side.'"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"# The first time, it is not yet in cache, so it should take longer\n",
"llm(\"Tell me a joke\")"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "93e21a5f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 745 ms, sys: 13.2 ms, total: 758 ms\n",
"Wall time: 136 ms\n"
]
},
{
"data": {
"text/plain": [
"'\\n\\nWhy did the chicken cross the road?\\n\\nTo get to the other side.'"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"# This is an exact match, so it finds it in the cache\n",
"llm(\"Tell me a joke\")"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "c4bb024b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 737 ms, sys: 7.79 ms, total: 745 ms\n",
"Wall time: 135 ms\n"
]
},
{
"data": {
"text/plain": [
"'\\n\\nWhy did the chicken cross the road?\\n\\nTo get to the other side.'"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"# This is not an exact match, but semantically within distance so it hits!\n",
"llm(\"Tell me joke\")"
]
},
{
"cell_type": "markdown",
"id": "934943dc",

@@ -1,6 +1,7 @@
"""Beta Feature: base interface for cache."""
import json
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple
from sqlalchemy import Column, Integer, String, create_engine, select
from sqlalchemy.engine.base import Engine
@@ -137,3 +138,125 @@ class RedisCache(BaseCache):
"""Update cache based on prompt and llm_string."""
for i, generation in enumerate(return_val):
self.redis.set(self._key(prompt, llm_string, i), generation.text)
class GPTCache(BaseCache):
    """Cache that uses GPTCache as a backend."""

    def __init__(self, init_func: Callable[[Any], None]):
        """Initialize by passing in the `init` GPTCache func

        Args:
            init_func (Callable[[Any], None]): init `GPTCache` function

        Example:
        .. code-block:: python

            import gptcache
            from gptcache.processor.pre import get_prompt
            from gptcache.manager.factory import get_data_manager

            # Avoid multiple caches using the same file,
            # which would cause different LLM caches to interfere with each other
            i = 0
            file_prefix = "data_map"

            def init_gptcache_map(cache_obj: gptcache.Cache):
                global i
                cache_path = f'{file_prefix}_{i}.txt'
                cache_obj.init(
                    pre_embedding_func=get_prompt,
                    data_manager=get_data_manager(data_path=cache_path),
                )
                i += 1

            langchain.llm_cache = GPTCache(init_gptcache_map)
        """
        try:
            import gptcache  # noqa: F401
        except ImportError:
            raise ValueError(
                "Could not import gptcache python package. "
                "Please install it with `pip install gptcache`."
            )
        self.init_gptcache_func: Callable[[Any], None] = init_func
        self.gptcache_dict: Dict[str, Any] = {}

    @staticmethod
    def _update_cache_callback_none(*_: Any, **__: Any) -> None:
        """When updating cached data, do nothing.

        Because currently only cached queries are processed."""
        return None

    @staticmethod
    def _llm_handle_none(*_: Any, **__: Any) -> None:
        """Do nothing on a cache miss."""
        return None

    @staticmethod
    def _cache_data_converter(data: str) -> RETURN_VAL_TYPE:
        """Convert the `data` in the cache to the `RETURN_VAL_TYPE` data format."""
        return [Generation(**generation_dict) for generation_dict in json.loads(data)]

    def _get_gptcache(self, llm_string: str) -> Any:
        """Get a cache object.

        When the corresponding llm model cache does not exist, it will be created."""
        from gptcache import Cache

        _gptcache = self.gptcache_dict.get(llm_string, None)
        if _gptcache is None:
            _gptcache = Cache()
            self.init_gptcache_func(_gptcache)
            self.gptcache_dict[llm_string] = _gptcache
        return _gptcache

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Look up the cache data.

        First, retrieve the corresponding cache object using the `llm_string` parameter,
        and then retrieve the data from the cache based on the `prompt`.
        """
        from gptcache.adapter.adapter import adapt

        _gptcache = self.gptcache_dict.get(llm_string)
        if _gptcache is None:
            return None
        res = adapt(
            GPTCache._llm_handle_none,
            GPTCache._cache_data_converter,
            GPTCache._update_cache_callback_none,
            cache_obj=_gptcache,
            prompt=prompt,
        )
        return res

    @staticmethod
    def _update_cache_callback(
        llm_data: RETURN_VAL_TYPE, update_cache_func: Callable[[Any], None]
    ) -> None:
        """Save the `llm_data` to cache storage."""
        handled_data = json.dumps([generation.dict() for generation in llm_data])
        update_cache_func(handled_data)

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Update cache.

        First, retrieve the corresponding cache object using the `llm_string` parameter,
        and then store the `prompt` and `return_val` in the cache object.
        """
        from gptcache.adapter.adapter import adapt

        _gptcache = self._get_gptcache(llm_string)

        def llm_handle(*_: Any, **__: Any) -> RETURN_VAL_TYPE:
            return return_val

        return adapt(
            llm_handle,
            GPTCache._cache_data_converter,
            GPTCache._update_cache_callback,
            cache_obj=_gptcache,
            cache_skip=True,
            prompt=prompt,
        )
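For orientation, the contract this class fulfills can be exercised by hand, roughly mirroring the unit test added below. This is a simplified sketch, not the actual LLM-layer code: the `llm_string` value and the cache file name are made up, and `fake_model` stands in for a real LLM call.

import gptcache
import langchain
from gptcache.manager.factory import get_data_manager
from gptcache.processor.pre import get_prompt
from langchain.cache import GPTCache
from langchain.schema import Generation


def init_gptcache(cache_obj: gptcache.Cache) -> None:
    # simplest exact-match setup; "sketch_data_map.txt" is an illustrative file name
    cache_obj.init(
        pre_embedding_func=get_prompt,
        data_manager=get_data_manager(data_path="sketch_data_map.txt"),
    )


langchain.llm_cache = GPTCache(init_gptcache)

llm_string = "example-llm-params"  # normally a serialized form of the LLM's parameters


def fake_model(prompt: str) -> list:
    return [Generation(text="fizz")]  # stand-in for a real LLM call


# Miss -> call the model -> write back; the next lookup for the same prompt hits.
cached = langchain.llm_cache.lookup("Tell me a joke", llm_string)
if not cached:
    result = fake_model("Tell me a joke")
    langchain.llm_cache.update("Tell me a joke", llm_string, result)

assert langchain.llm_cache.lookup("Tell me a joke", llm_string) == [Generation(text="fizz")]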

poetry.lock (generated)

@@ -2062,6 +2062,23 @@ protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.1 || >4.21.1,<4
[package.extras]
grpc = ["grpcio (>=1.44.0,<2.0.0dev)"]
[[package]]
name = "gptcache"
version = "0.1.8"
description = "GPT Cache, a powerful caching library that can be used to speed up and lower the cost of chat applications that rely on the LLM service. GPT Cache works as a memcache for AIGC applications, similar to how Redis works for traditional applications."
category = "main"
optional = true
python-versions = ">=3.8.1"
files = [
{file = "gptcache-0.1.8-py3-none-any.whl", hash = "sha256:953662291819471e5461920c89367084f905237a8506f1a1605729f3e633f147"},
{file = "gptcache-0.1.8.tar.gz", hash = "sha256:23200cc0783776210cce85a588ae68222d522ce9456f74b7836945ebe8b15820"},
]
[package.dependencies]
cachetools = "*"
numpy = "*"
openai = "*"
[[package]]
name = "greenlet"
version = "2.0.1"
@@ -8994,4 +9011,4 @@ qdrant = ["qdrant-client"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
content-hash = "11bbe0042c3c1e56a5d1abbaab33185efad5dcdacb095fc91e91c382f2c9ebb7"
content-hash = "26b1bbfbc3a228b892b2466af3561b799238a6d379853d325dc3c798776df0d8"

@@ -59,6 +59,7 @@ psycopg2-binary = {version = "^2.9.5", optional = true}
#boto3 = {version = "^1.26.96", optional = true} # TODO: fix it, commented because the version failed with deeplake
pyowm = {version = "^3.3.0", optional = true}
async-timeout = {version = "^4.0.0", python = "<3.11"}
gptcache = {version = ">=0.1.7", optional = true}
[tool.poetry.group.docs.dependencies]
autodoc_pydantic = "^1.8.0"

@@ -0,0 +1 @@
"""All integration tests for Cache objects."""

@@ -0,0 +1,61 @@
import os
import pytest
import langchain
from langchain.cache import GPTCache
from langchain.schema import Generation, LLMResult
from tests.unit_tests.llms.fake_llm import FakeLLM

try:
    import gptcache  # noqa: F401

    gptcache_installed = True
except ImportError:
    gptcache_installed = False


@pytest.mark.skipif(not gptcache_installed, reason="gptcache not installed")
def test_gptcache_map_caching() -> None:
    """Test gptcache caching behavior."""
    from gptcache import Cache
    from gptcache.manager.factory import get_data_manager
    from gptcache.processor.pre import get_prompt

    i = 0
    file_prefix = "data_map"

    def init_gptcache_map(cache_obj: Cache) -> None:
        nonlocal i
        cache_path = f"{file_prefix}_{i}.txt"
        if os.path.isfile(cache_path):
            os.remove(cache_path)
        cache_obj.init(
            pre_embedding_func=get_prompt,
            data_manager=get_data_manager(data_path=cache_path),
        )
        i += 1

    langchain.llm_cache = GPTCache(init_gptcache_map)

    llm = FakeLLM()
    params = llm.dict()
    params["stop"] = None
    llm_string = str(sorted([(k, v) for k, v in params.items()]))
    langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
    output = llm.generate(["foo", "bar", "foo"])
    expected_cache_output = [Generation(text="foo")]
    cache_output = langchain.llm_cache.lookup("bar", llm_string)
    assert cache_output == expected_cache_output
    langchain.llm_cache = None
    expected_generations = [
        [Generation(text="fizz")],
        [Generation(text="foo")],
        [Generation(text="fizz")],
    ]
    expected_output = LLMResult(
        generations=expected_generations,
        llm_output=None,
    )
    assert output == expected_output
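
The counter idiom that appears in the notebook, the docstring, and this test exists because `GPTCache._get_gptcache` creates a separate `gptcache.Cache` per `llm_string`, so each model needs its own backing file to keep caches from sharing entries. A small sketch of that isolation (the file names and the `llm-A`/`llm-B` strings are illustrative):

import gptcache
import langchain
from gptcache.manager.factory import get_data_manager
from gptcache.processor.pre import get_prompt
from langchain.cache import GPTCache
from langchain.schema import Generation

n = 0


def init_per_model(cache_obj: gptcache.Cache) -> None:
    # called once per llm_string, so bump a counter to give every model its own file
    global n
    cache_obj.init(
        pre_embedding_func=get_prompt,
        data_manager=get_data_manager(data_path=f"isolation_demo_{n}.txt"),
    )
    n += 1


langchain.llm_cache = GPTCache(init_per_model)
langchain.llm_cache.update("foo", "llm-A", [Generation(text="from model A")])
langchain.llm_cache.update("foo", "llm-B", [Generation(text="from model B")])

# The same prompt resolves per model, because each llm_string has its own cache.
assert langchain.llm_cache.lookup("foo", "llm-A") == [Generation(text="from model A")]
assert langchain.llm_cache.lookup("foo", "llm-B") == [Generation(text="from model B")]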