From 7bcf238a1acf40aef21a5a198cf0e62d76f93c15 Mon Sep 17 00:00:00 2001
From: SimFG
Date: Fri, 12 May 2023 07:15:23 +0800
Subject: [PATCH] Optimize the initialization method of GPTCache (#4522)

Optimize the initialization method of GPTCache, so that users can use
GPTCache more quickly.
---
 .../models/llms/examples/llm_caching.ipynb      | 50 ++++-----------
 langchain/cache.py                              | 62 +++++++++++--------
 .../integration_tests/cache/test_gptcache.py   | 20 +++++-
 3 files changed, 66 insertions(+), 66 deletions(-)

diff --git a/docs/modules/models/llms/examples/llm_caching.ipynb b/docs/modules/models/llms/examples/llm_caching.ipynb
index 8655c000..cec16090 100644
--- a/docs/modules/models/llms/examples/llm_caching.ipynb
+++ b/docs/modules/models/llms/examples/llm_caching.ipynb
@@ -408,25 +408,20 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "import gptcache\n",
+    "from gptcache import Cache\n",
+    "from gptcache.manager.factory import manager_factory\n",
     "from gptcache.processor.pre import get_prompt\n",
-    "from gptcache.manager.factory import get_data_manager\n",
     "from langchain.cache import GPTCache\n",
     "\n",
     "# Avoid multiple caches using the same file, causing different llm model caches to affect each other\n",
-    "i = 0\n",
-    "file_prefix = \"data_map\"\n",
     "\n",
-    "def init_gptcache_map(cache_obj: gptcache.Cache):\n",
-    "    global i\n",
-    "    cache_path = f'{file_prefix}_{i}.txt'\n",
+    "def init_gptcache(cache_obj: Cache, llm: str):\n",
     "    cache_obj.init(\n",
     "        pre_embedding_func=get_prompt,\n",
-    "        data_manager=get_data_manager(data_path=cache_path),\n",
+    "        data_manager=manager_factory(manager=\"map\", data_dir=f\"map_cache_{llm}\"),\n",
     "    )\n",
-    "    i += 1\n",
     "\n",
-    "langchain.llm_cache = GPTCache(init_gptcache_map)"
+    "langchain.llm_cache = GPTCache(init_gptcache)"
    ]
   },
   {
@@ -506,37 +501,16 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "import gptcache\n",
-    "from gptcache.processor.pre import get_prompt\n",
-    "from gptcache.manager.factory import get_data_manager\n",
-    "from langchain.cache import GPTCache\n",
-    "from gptcache.manager import get_data_manager, CacheBase, VectorBase\n",
     "from gptcache import Cache\n",
-    "from gptcache.embedding import Onnx\n",
-    "from gptcache.similarity_evaluation.distance import SearchDistanceEvaluation\n",
+    "from gptcache.adapter.api import init_similar_cache\n",
+    "from langchain.cache import GPTCache\n",
     "\n",
     "# Avoid multiple caches using the same file, causing different llm model caches to affect each other\n",
-    "i = 0\n",
-    "file_prefix = \"data_map\"\n",
-    "llm_cache = Cache()\n",
     "\n",
+    "def init_gptcache(cache_obj: Cache, llm: str):\n",
+    "    init_similar_cache(cache_obj=cache_obj, data_dir=f\"similar_cache_{llm}\")\n",
     "\n",
-    "def init_gptcache_map(cache_obj: gptcache.Cache):\n",
-    "    global i\n",
-    "    cache_path = f'{file_prefix}_{i}.txt'\n",
-    "    onnx = Onnx()\n",
-    "    cache_base = CacheBase('sqlite')\n",
-    "    vector_base = VectorBase('faiss', dimension=onnx.dimension)\n",
-    "    data_manager = get_data_manager(cache_base, vector_base, max_size=10, clean_size=2)\n",
-    "    cache_obj.init(\n",
-    "        pre_embedding_func=get_prompt,\n",
-    "        embedding_func=onnx.to_embeddings,\n",
-    "        data_manager=data_manager,\n",
-    "        similarity_evaluation=SearchDistanceEvaluation(),\n",
-    "    )\n",
-    "    i += 1\n",
-    "\n",
-    "langchain.llm_cache = GPTCache(init_gptcache_map)"
+    "langchain.llm_cache = GPTCache(init_gptcache)"
    ]
   },
   {
@@ -929,7 +903,7 @@
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "Python 3 (ipykernel)",
+   "display_name": "Python 3",
    "language": "python",
    "name": "python3"
"python3" }, @@ -943,7 +917,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.1" + "version": "3.8.8" } }, "nbformat": 4, diff --git a/langchain/cache.py b/langchain/cache.py index ce1e7306..3d89233d 100644 --- a/langchain/cache.py +++ b/langchain/cache.py @@ -1,8 +1,9 @@ """Beta Feature: base interface for cache.""" import hashlib +import inspect import json from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, List, Optional, Tuple, Type, cast +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast from sqlalchemy import Column, Integer, String, create_engine, select from sqlalchemy.engine.base import Engine @@ -274,7 +275,12 @@ class RedisSemanticCache(BaseCache): class GPTCache(BaseCache): """Cache that uses GPTCache as a backend.""" - def __init__(self, init_func: Optional[Callable[[Any], None]] = None): + def __init__( + self, + init_func: Union[ + Callable[[Any, str], None], Callable[[Any], None], None + ] = None, + ): """Initialize by passing in init function (default: `None`). Args: @@ -291,19 +297,17 @@ class GPTCache(BaseCache): # Avoid multiple caches using the same file, causing different llm model caches to affect each other - i = 0 - file_prefix = "data_map" - def init_gptcache_map(cache_obj: gptcache.Cache): - nonlocal i - cache_path = f'{file_prefix}_{i}.txt' + def init_gptcache(cache_obj: gptcache.Cache, llm str): cache_obj.init( pre_embedding_func=get_prompt, - data_manager=get_data_manager(data_path=cache_path), + data_manager=manager_factory( + manager="map", + data_dir=f"map_cache_{llm}" + ), ) - i += 1 - langchain.llm_cache = GPTCache(init_gptcache_map) + langchain.llm_cache = GPTCache(init_gptcache) """ try: @@ -314,30 +318,38 @@ class GPTCache(BaseCache): "Please install it with `pip install gptcache`." ) - self.init_gptcache_func: Optional[Callable[[Any], None]] = init_func + self.init_gptcache_func: Union[ + Callable[[Any, str], None], Callable[[Any], None], None + ] = init_func self.gptcache_dict: Dict[str, Any] = {} - def _get_gptcache(self, llm_string: str) -> Any: - """Get a cache object. - - When the corresponding llm model cache does not exist, it will be created.""" + def _new_gptcache(self, llm_string: str) -> Any: + """New gptcache object""" from gptcache import Cache from gptcache.manager.factory import get_data_manager from gptcache.processor.pre import get_prompt - _gptcache = self.gptcache_dict.get(llm_string, None) - if _gptcache is None: - _gptcache = Cache() - if self.init_gptcache_func is not None: - self.init_gptcache_func(_gptcache) + _gptcache = Cache() + if self.init_gptcache_func is not None: + sig = inspect.signature(self.init_gptcache_func) + if len(sig.parameters) == 2: + self.init_gptcache_func(_gptcache, llm_string) # type: ignore[call-arg] else: - _gptcache.init( - pre_embedding_func=get_prompt, - data_manager=get_data_manager(data_path=llm_string), - ) - self.gptcache_dict[llm_string] = _gptcache + self.init_gptcache_func(_gptcache) # type: ignore[call-arg] + else: + _gptcache.init( + pre_embedding_func=get_prompt, + data_manager=get_data_manager(data_path=llm_string), + ) return _gptcache + def _get_gptcache(self, llm_string: str) -> Any: + """Get a cache object. + + When the corresponding llm model cache does not exist, it will be created.""" + + return self.gptcache_dict.get(llm_string, self._new_gptcache(llm_string)) + def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: """Look up the cache data. 
         First, retrieve the corresponding cache object using the `llm_string` parameter,
diff --git a/tests/integration_tests/cache/test_gptcache.py b/tests/integration_tests/cache/test_gptcache.py
index 471f959b..823ec0c3 100644
--- a/tests/integration_tests/cache/test_gptcache.py
+++ b/tests/integration_tests/cache/test_gptcache.py
@@ -1,5 +1,5 @@
 import os
-from typing import Any, Callable, Optional
+from typing import Any, Callable, Union
 
 import pytest
 
@@ -30,9 +30,23 @@ def init_gptcache_map(cache_obj: Cache) -> None:
     init_gptcache_map._i = i + 1  # type: ignore
 
 
+def init_gptcache_map_with_llm(cache_obj: Cache, llm: str) -> None:
+    cache_path = f"data_map_{llm}.txt"
+    if os.path.isfile(cache_path):
+        os.remove(cache_path)
+    cache_obj.init(
+        pre_embedding_func=get_prompt,
+        data_manager=get_data_manager(data_path=cache_path),
+    )
+
+
 @pytest.mark.skipif(not gptcache_installed, reason="gptcache not installed")
-@pytest.mark.parametrize("init_func", [None, init_gptcache_map])
-def test_gptcache_caching(init_func: Optional[Callable[[Any], None]]) -> None:
+@pytest.mark.parametrize(
+    "init_func", [None, init_gptcache_map, init_gptcache_map_with_llm]
+)
+def test_gptcache_caching(
+    init_func: Union[Callable[[Any, str], None], Callable[[Any], None], None]
+) -> None:
     """Test gptcache default caching behavior."""
     langchain.llm_cache = GPTCache(init_func)
     llm = FakeLLM()
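Usage note: after this patch, the `init_func` passed to `GPTCache` may take a second `llm` string, which lets each model configuration get its own cache directory. Below is a minimal sketch adapted from the notebook example edited above; the `OpenAI` wrapper, model settings, and prompt are taken from the surrounding docs notebook rather than from this diff, and running it assumes `gptcache` is installed and an OpenAI API key is configured.

```python
from gptcache import Cache
from gptcache.manager.factory import manager_factory
from gptcache.processor.pre import get_prompt

import langchain
from langchain.cache import GPTCache
from langchain.llms import OpenAI


def init_gptcache(cache_obj: Cache, llm: str) -> None:
    # LangChain passes the llm string of the model being cached; using it in
    # data_dir keeps the exact-match caches of different models separate.
    cache_obj.init(
        pre_embedding_func=get_prompt,
        data_manager=manager_factory(manager="map", data_dir=f"map_cache_{llm}"),
    )


langchain.llm_cache = GPTCache(init_gptcache)

llm = OpenAI(model_name="text-davinci-002", n=2, best_of=2)
llm("Tell me a joke")  # first call runs the model and populates the cache
llm("Tell me a joke")  # an identical prompt can now be served from GPTCache
```

Both one-argument and two-argument `init_func` callables are accepted: `GPTCache` inspects the function signature at runtime and only passes the llm string when the callable declares two parameters, so existing single-argument init functions keep working.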