From e49f1e628cdfb71c75e4cbe67d8e05aaf8daffaa Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Wed, 12 Apr 2023 14:16:58 -0700 Subject: [PATCH] Harrison/gpt cache (#2744) Co-authored-by: SimFG --- .../models/llms/examples/llm_caching.ipynb | 255 +++++++++++++++++- langchain/cache.py | 125 ++++++++- poetry.lock | 19 +- pyproject.toml | 1 + tests/integration_tests/cache/__init__.py | 1 + .../integration_tests/cache/test_gptcache.py | 61 +++++ 6 files changed, 454 insertions(+), 8 deletions(-) create mode 100644 tests/integration_tests/cache/__init__.py create mode 100644 tests/integration_tests/cache/test_gptcache.py diff --git a/docs/modules/models/llms/examples/llm_caching.ipynb b/docs/modules/models/llms/examples/llm_caching.ipynb index e9ae75eb..8b65d7ba 100644 --- a/docs/modules/models/llms/examples/llm_caching.ipynb +++ b/docs/modules/models/llms/examples/llm_caching.ipynb @@ -60,14 +60,14 @@ "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 30.7 ms, sys: 18.6 ms, total: 49.3 ms\n", - "Wall time: 791 ms\n" + "CPU times: user 14.2 ms, sys: 4.9 ms, total: 19.1 ms\n", + "Wall time: 1.1 s\n" ] }, { "data": { "text/plain": [ - "\"\\n\\nWhy couldn't the bicycle stand up by itself? Because it was...two tired!\"" + "'\\n\\nWhy did the chicken cross the road?\\n\\nTo get to the other side.'" ] }, "execution_count": 4, @@ -91,14 +91,14 @@ "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 80 µs, sys: 0 ns, total: 80 µs\n", - "Wall time: 83.9 µs\n" + "CPU times: user 162 µs, sys: 7 µs, total: 169 µs\n", + "Wall time: 175 µs\n" ] }, { "data": { "text/plain": [ - "\"\\n\\nWhy couldn't the bicycle stand up by itself? Because it was...two tired!\"" + "'\\n\\nWhy did the chicken cross the road?\\n\\nTo get to the other side.'" ] }, "execution_count": 5, @@ -252,6 +252,249 @@ "llm(\"Tell me a joke\")" ] }, + { + "cell_type": "markdown", + "id": "684eab55", + "metadata": {}, + "source": [ + "## GPTCache\n", + "\n", + "We can use [GPTCache](https://github.com/zilliztech/GPTCache) for exact match caching OR to cache results based on semantic similarity\n", + "\n", + "Let's first start with an example of exact match" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "14a82124", + "metadata": {}, + "outputs": [], + "source": [ + "import gptcache\n", + "from gptcache.processor.pre import get_prompt\n", + "from gptcache.manager.factory import get_data_manager\n", + "from langchain.cache import GPTCache\n", + "\n", + "# Avoid multiple caches using the same file, causing different llm model caches to affect each other\n", + "i = 0\n", + "file_prefix = \"data_map\"\n", + "\n", + "def init_gptcache_map(cache_obj: gptcache.Cache):\n", + " global i\n", + " cache_path = f'{file_prefix}_{i}.txt'\n", + " cache_obj.init(\n", + " pre_embedding_func=get_prompt,\n", + " data_manager=get_data_manager(data_path=cache_path),\n", + " )\n", + " i += 1\n", + "\n", + "langchain.llm_cache = GPTCache(init_gptcache_map)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "9e4ecfd1", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 8.6 ms, sys: 3.82 ms, total: 12.4 ms\n", + "Wall time: 881 ms\n" + ] + }, + { + "data": { + "text/plain": [ + "'\\n\\nWhy did the chicken cross the road?\\n\\nTo get to the other side.'" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%%time\n", + "# The first time, it is not yet in cache, so it 
should take longer\n", + "llm(\"Tell me a joke\")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "c98bbe3b", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 286 µs, sys: 21 µs, total: 307 µs\n", + "Wall time: 316 µs\n" + ] + }, + { + "data": { + "text/plain": [ + "'\\n\\nWhy did the chicken cross the road?\\n\\nTo get to the other side.'" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%%time\n", + "# The second time it is, so it goes faster\n", + "llm(\"Tell me a joke\")" + ] + }, + { + "cell_type": "markdown", + "id": "502b6076", + "metadata": {}, + "source": [ + "Let's now show an example of similarity caching" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "b3c663bb", + "metadata": {}, + "outputs": [], + "source": [ + "import gptcache\n", + "from gptcache.processor.pre import get_prompt\n", + "from gptcache.manager.factory import get_data_manager\n", + "from langchain.cache import GPTCache\n", + "from gptcache.manager import get_data_manager, CacheBase, VectorBase\n", + "from gptcache import Cache\n", + "from gptcache.embedding import Onnx\n", + "from gptcache.similarity_evaluation.distance import SearchDistanceEvaluation\n", + "\n", + "# Avoid multiple caches using the same file, causing different llm model caches to affect each other\n", + "i = 0\n", + "file_prefix = \"data_map\"\n", + "llm_cache = Cache()\n", + "\n", + "\n", + "def init_gptcache_map(cache_obj: gptcache.Cache):\n", + " global i\n", + " cache_path = f'{file_prefix}_{i}.txt'\n", + " onnx = Onnx()\n", + " cache_base = CacheBase('sqlite')\n", + " vector_base = VectorBase('faiss', dimension=onnx.dimension)\n", + " data_manager = get_data_manager(cache_base, vector_base, max_size=10, clean_size=2)\n", + " cache_obj.init(\n", + " pre_embedding_func=get_prompt,\n", + " embedding_func=onnx.to_embeddings,\n", + " data_manager=data_manager,\n", + " similarity_evaluation=SearchDistanceEvaluation(),\n", + " )\n", + " i += 1\n", + "\n", + "langchain.llm_cache = GPTCache(init_gptcache_map)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "8c273ced", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 1.01 s, sys: 153 ms, total: 1.16 s\n", + "Wall time: 2.49 s\n" + ] + }, + { + "data": { + "text/plain": [ + "'\\n\\nWhy did the chicken cross the road?\\n\\nTo get to the other side.'" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%%time\n", + "# The first time, it is not yet in cache, so it should take longer\n", + "llm(\"Tell me a joke\")" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "93e21a5f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 745 ms, sys: 13.2 ms, total: 758 ms\n", + "Wall time: 136 ms\n" + ] + }, + { + "data": { + "text/plain": [ + "'\\n\\nWhy did the chicken cross the road?\\n\\nTo get to the other side.'" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%%time\n", + "# This is an exact match, so it finds it in the cache\n", + "llm(\"Tell me a joke\")" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "c4bb024b", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 
737 ms, sys: 7.79 ms, total: 745 ms\n", + "Wall time: 135 ms\n" + ] + }, + { + "data": { + "text/plain": [ + "'\\n\\nWhy did the chicken cross the road?\\n\\nTo get to the other side.'" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%%time\n", + "# This is not an exact match, but semantically within distance so it hits!\n", + "llm(\"Tell me joke\")" + ] + }, { "cell_type": "markdown", "id": "934943dc", diff --git a/langchain/cache.py b/langchain/cache.py index 3d7149d9..6f388827 100644 --- a/langchain/cache.py +++ b/langchain/cache.py @@ -1,6 +1,7 @@ """Beta Feature: base interface for cache.""" +import json from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple from sqlalchemy import Column, Integer, String, create_engine, select from sqlalchemy.engine.base import Engine @@ -137,3 +138,125 @@ class RedisCache(BaseCache): """Update cache based on prompt and llm_string.""" for i, generation in enumerate(return_val): self.redis.set(self._key(prompt, llm_string, i), generation.text) + + +class GPTCache(BaseCache): + """Cache that uses GPTCache as a backend.""" + + def __init__(self, init_func: Callable[[Any], None]): + """Initialize by passing in the `init` GPTCache func + + Args: + init_func (Callable[[Any], None]): init `GPTCache` function + + Example: + .. code-block:: python + + import gptcache + from gptcache.processor.pre import get_prompt + from gptcache.manager.factory import get_data_manager + + # Avoid multiple caches using the same file, + causing different llm model caches to affect each other + i = 0 + file_prefix = "data_map" + + def init_gptcache_map(cache_obj: gptcache.Cache): + nonlocal i + cache_path = f'{file_prefix}_{i}.txt' + cache_obj.init( + pre_embedding_func=get_prompt, + data_manager=get_data_manager(data_path=cache_path), + ) + i += 1 + + langchain.llm_cache = GPTCache(init_gptcache_map) + + """ + try: + import gptcache # noqa: F401 + except ImportError: + raise ValueError( + "Could not import gptcache python package. " + "Please install it with `pip install gptcache`." + ) + self.init_gptcache_func: Callable[[Any], None] = init_func + self.gptcache_dict: Dict[str, Any] = {} + + @staticmethod + def _update_cache_callback_none(*_: Any, **__: Any) -> None: + """When updating cached data, do nothing. + + Because currently only cached queries are processed.""" + return None + + @staticmethod + def _llm_handle_none(*_: Any, **__: Any) -> None: + """Do nothing on a cache miss""" + return None + + @staticmethod + def _cache_data_converter(data: str) -> RETURN_VAL_TYPE: + """Convert the `data` in the cache to the `RETURN_VAL_TYPE` data format.""" + return [Generation(**generation_dict) for generation_dict in json.loads(data)] + + def _get_gptcache(self, llm_string: str) -> Any: + """Get a cache object. + + When the corresponding llm model cache does not exist, it will be created.""" + from gptcache import Cache + + _gptcache = self.gptcache_dict.get(llm_string, None) + if _gptcache is None: + _gptcache = Cache() + self.init_gptcache_func(_gptcache) + self.gptcache_dict[llm_string] = _gptcache + return _gptcache + + def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: + """Look up the cache data. + First, retrieve the corresponding cache object using the `llm_string` parameter, + and then retrieve the data from the cache based on the `prompt`. 
+ """ + from gptcache.adapter.adapter import adapt + + _gptcache = self.gptcache_dict.get(llm_string) + if _gptcache is None: + return None + res = adapt( + GPTCache._llm_handle_none, + GPTCache._cache_data_converter, + GPTCache._update_cache_callback_none, + cache_obj=_gptcache, + prompt=prompt, + ) + return res + + @staticmethod + def _update_cache_callback( + llm_data: RETURN_VAL_TYPE, update_cache_func: Callable[[Any], None] + ) -> None: + """Save the `llm_data` to cache storage""" + handled_data = json.dumps([generation.dict() for generation in llm_data]) + update_cache_func(handled_data) + + def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None: + """Update cache. + First, retrieve the corresponding cache object using the `llm_string` parameter, + and then store the `prompt` and `return_val` in the cache object. + """ + from gptcache.adapter.adapter import adapt + + _gptcache = self._get_gptcache(llm_string) + + def llm_handle(*_: Any, **__: Any) -> RETURN_VAL_TYPE: + return return_val + + return adapt( + llm_handle, + GPTCache._cache_data_converter, + GPTCache._update_cache_callback, + cache_obj=_gptcache, + cache_skip=True, + prompt=prompt, + ) diff --git a/poetry.lock b/poetry.lock index 8e2bdb52..197c8d5b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2062,6 +2062,23 @@ protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.1 || >4.21.1,<4 [package.extras] grpc = ["grpcio (>=1.44.0,<2.0.0dev)"] +[[package]] +name = "gptcache" +version = "0.1.8" +description = "GPT Cache, a powerful caching library that can be used to speed up and lower the cost of chat applications that rely on the LLM service. GPT Cache works as a memcache for AIGC applications, similar to how Redis works for traditional applications." 
+category = "main" +optional = true +python-versions = ">=3.8.1" +files = [ + {file = "gptcache-0.1.8-py3-none-any.whl", hash = "sha256:953662291819471e5461920c89367084f905237a8506f1a1605729f3e633f147"}, + {file = "gptcache-0.1.8.tar.gz", hash = "sha256:23200cc0783776210cce85a588ae68222d522ce9456f74b7836945ebe8b15820"}, +] + +[package.dependencies] +cachetools = "*" +numpy = "*" +openai = "*" + [[package]] name = "greenlet" version = "2.0.1" @@ -8994,4 +9011,4 @@ qdrant = ["qdrant-client"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "11bbe0042c3c1e56a5d1abbaab33185efad5dcdacb095fc91e91c382f2c9ebb7" +content-hash = "26b1bbfbc3a228b892b2466af3561b799238a6d379853d325dc3c798776df0d8" diff --git a/pyproject.toml b/pyproject.toml index 060ec2bd..4fb03e2c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,6 +59,7 @@ psycopg2-binary = {version = "^2.9.5", optional = true} #boto3 = {version = "^1.26.96", optional = true} # TODO: fix it, commented because the version failed with deeplake pyowm = {version = "^3.3.0", optional = true} async-timeout = {version = "^4.0.0", python = "<3.11"} +gptcache = {version = ">=0.1.7", optional = true} [tool.poetry.group.docs.dependencies] autodoc_pydantic = "^1.8.0" diff --git a/tests/integration_tests/cache/__init__.py b/tests/integration_tests/cache/__init__.py new file mode 100644 index 00000000..f75c193f --- /dev/null +++ b/tests/integration_tests/cache/__init__.py @@ -0,0 +1 @@ +"""All integration tests for Cache objects.""" diff --git a/tests/integration_tests/cache/test_gptcache.py b/tests/integration_tests/cache/test_gptcache.py new file mode 100644 index 00000000..8a7f6cdb --- /dev/null +++ b/tests/integration_tests/cache/test_gptcache.py @@ -0,0 +1,61 @@ +import os + +import pytest + +import langchain +from langchain.cache import GPTCache +from langchain.schema import Generation, LLMResult +from tests.unit_tests.llms.fake_llm import FakeLLM + +try: + import gptcache # noqa: F401 + + gptcache_installed = True +except ImportError: + gptcache_installed = False + + +@pytest.mark.skipif(not gptcache_installed, reason="gptcache not installed") +def test_gptcache_map_caching() -> None: + """Test gptcache caching behavior.""" + + from gptcache import Cache + from gptcache.manager.factory import get_data_manager + from gptcache.processor.pre import get_prompt + + i = 0 + file_prefix = "data_map" + + def init_gptcache_map(cache_obj: Cache) -> None: + nonlocal i + cache_path = f"{file_prefix}_{i}.txt" + if os.path.isfile(cache_path): + os.remove(cache_path) + cache_obj.init( + pre_embedding_func=get_prompt, + data_manager=get_data_manager(data_path=cache_path), + ) + i += 1 + + langchain.llm_cache = GPTCache(init_gptcache_map) + + llm = FakeLLM() + params = llm.dict() + params["stop"] = None + llm_string = str(sorted([(k, v) for k, v in params.items()])) + langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")]) + output = llm.generate(["foo", "bar", "foo"]) + expected_cache_output = [Generation(text="foo")] + cache_output = langchain.llm_cache.lookup("bar", llm_string) + assert cache_output == expected_cache_output + langchain.llm_cache = None + expected_generations = [ + [Generation(text="fizz")], + [Generation(text="foo")], + [Generation(text="fizz")], + ] + expected_output = LLMResult( + generations=expected_generations, + llm_output=None, + ) + assert output == expected_output