From 4a246e2fd62d5beb348d0e2c0b4be1e6046ae3a6 Mon Sep 17 00:00:00 2001
From: "Ehsan M. Kermani" <6980212+ehsanmok@users.noreply.github.com>
Date: Wed, 26 Apr 2023 22:03:50 -0700
Subject: [PATCH] Allow clearing cache and fix gptcache (#3493)

This PR
* Adds a `clear` method to `BaseCache` and implements it for the various caches (a short usage sketch follows the diff below)
* Adds a default `init_func=None` and fixes the gptcache integration test
* Since the integration tests do not currently run in CI, I've verified the changes by running `docs/modules/models/llms/examples/llm_caching.ipynb` (until a proper e2e integration test runs in CI)
---
 .gitignore                                    |  6 +-
 .../models/llms/examples/llm_caching.ipynb    |  4 +-
 langchain/cache.py                            | 79 ++++++++++++++-----
 langchain/memory/entity.py                    |  1 +
 .../integration_tests/cache/test_gptcache.py  | 65 ++++++---------
 5 files changed, 96 insertions(+), 59 deletions(-)

diff --git a/.gitignore b/.gitignore
index 4301f27d..69fb95f0 100644
--- a/.gitignore
+++ b/.gitignore
@@ -144,4 +144,8 @@ wandb/
 /.ruff_cache/
 
 *.pkl
-*.bin
\ No newline at end of file
+*.bin
+
+# integration test artifacts
+data_map*
+\[('_type', 'fake'), ('stop', None)]
\ No newline at end of file
diff --git a/docs/modules/models/llms/examples/llm_caching.ipynb b/docs/modules/models/llms/examples/llm_caching.ipynb
index 8b65d7ba..a37adaad 100644
--- a/docs/modules/models/llms/examples/llm_caching.ipynb
+++ b/docs/modules/models/llms/examples/llm_caching.ipynb
@@ -785,7 +785,9 @@
    "id": "9df0dab8",
    "metadata": {},
    "outputs": [],
-   "source": []
+   "source": [
+    "!rm .langchain.db sqlite.db"
+   ]
   }
  ],
  "metadata": {
diff --git a/langchain/cache.py b/langchain/cache.py
index 6f388827..74d32d30 100644
--- a/langchain/cache.py
+++ b/langchain/cache.py
@@ -1,7 +1,7 @@
 """Beta Feature: base interface for cache."""
 import json
 from abc import ABC, abstractmethod
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple, Type, cast
 
 from sqlalchemy import Column, Integer, String, create_engine, select
 from sqlalchemy.engine.base import Engine
@@ -28,6 +28,10 @@ class BaseCache(ABC):
     def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
         """Update cache based on prompt and llm_string."""
 
+    @abstractmethod
+    def clear(self, **kwargs: Any) -> None:
+        """Clear cache; implementations may accept additional keyword arguments."""
+
 
 class InMemoryCache(BaseCache):
     """Cache that stores things in memory."""
@@ -44,6 +48,10 @@ class InMemoryCache(BaseCache):
         """Update cache based on prompt and llm_string."""
         self._cache[(prompt, llm_string)] = return_val
 
+    def clear(self, **kwargs: Any) -> None:
+        """Clear cache."""
+        self._cache = {}
+
 
 Base = declarative_base()
 
@@ -61,7 +69,7 @@ class FullLLMCache(Base):  # type: ignore
 class SQLAlchemyCache(BaseCache):
     """Cache that uses SQLAlchemy as a backend."""
 
-    def __init__(self, engine: Engine, cache_schema: Any = FullLLMCache):
+    def __init__(self, engine: Engine, cache_schema: Type[FullLLMCache] = FullLLMCache):
         """Initialize by creating all tables."""
         self.engine = engine
         self.cache_schema = cache_schema
@@ -76,20 +84,26 @@ class SQLAlchemyCache(BaseCache):
             .order_by(self.cache_schema.idx)
         )
         with Session(self.engine) as session:
-            generations = [Generation(text=row[0]) for row in session.execute(stmt)]
-            if len(generations) > 0:
-                return generations
+            rows = session.execute(stmt).fetchall()
+            if rows:
+                return [Generation(text=row[0]) for row in rows]
         return None
 
     def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
-        """Look up based on prompt and llm_string."""
-        for i, generation in enumerate(return_val):
-            item = self.cache_schema(
-                prompt=prompt, llm=llm_string, response=generation.text, idx=i
-            )
-            with Session(self.engine) as session, session.begin():
+        """Update based on prompt and llm_string."""
+        items = [
+            self.cache_schema(prompt=prompt, llm=llm_string, response=gen.text, idx=i)
+            for i, gen in enumerate(return_val)
+        ]
+        with Session(self.engine) as session, session.begin():
+            for item in items:
                 session.merge(item)
 
+    def clear(self, **kwargs: Any) -> None:
+        """Clear cache."""
+        with Session(self.engine) as session, session.begin():
+            session.query(self.cache_schema).delete()
+
 
 class SQLiteCache(SQLAlchemyCache):
     """Cache that uses SQLite as a backend."""
@@ -139,19 +153,26 @@ class RedisCache(BaseCache):
         for i, generation in enumerate(return_val):
             self.redis.set(self._key(prompt, llm_string, i), generation.text)
 
+    def clear(self, **kwargs: Any) -> None:
+        """Clear cache. If `asynchronous` is True, flush asynchronously."""
+        asynchronous = kwargs.pop("asynchronous", False)
+        self.redis.flushdb(asynchronous=asynchronous, **kwargs)
+
 
 class GPTCache(BaseCache):
     """Cache that uses GPTCache as a backend."""
 
-    def __init__(self, init_func: Callable[[Any], None]):
-        """Initialize by passing in the `init` GPTCache func
+    def __init__(self, init_func: Optional[Callable[[Any], None]] = None):
+        """Initialize by passing in an init function (default: `None`).
 
         Args:
-            init_func (Callable[[Any], None]): init `GPTCache` function
+            init_func (Optional[Callable[[Any], None]]): init `GPTCache` function
+            (default: `None`)
 
         Example:
         .. code-block:: python
 
+            # Initialize GPTCache with a custom init function
            import gptcache
            from gptcache.processor.pre import get_prompt
            from gptcache.manager.factory import get_data_manager
@@ -180,7 +201,8 @@ class GPTCache(BaseCache):
                 "Could not import gptcache python package. "
                 "Please install it with `pip install gptcache`."
             )
-        self.init_gptcache_func: Callable[[Any], None] = init_func
+
+        self.init_gptcache_func: Optional[Callable[[Any], None]] = init_func
         self.gptcache_dict: Dict[str, Any] = {}
 
     @staticmethod
@@ -205,11 +227,19 @@ class GPTCache(BaseCache):
         When the corresponding llm model cache does not exist, it will be created."""
         from gptcache import Cache
+        from gptcache.manager.factory import get_data_manager
+        from gptcache.processor.pre import get_prompt
 
         _gptcache = self.gptcache_dict.get(llm_string, None)
         if _gptcache is None:
             _gptcache = Cache()
-            self.init_gptcache_func(_gptcache)
+            if self.init_gptcache_func is not None:
+                self.init_gptcache_func(_gptcache)
+            else:
+                _gptcache.init(
+                    pre_embedding_func=get_prompt,
+                    data_manager=get_data_manager(data_path=llm_string),
+                )
             self.gptcache_dict[llm_string] = _gptcache
         return _gptcache
@@ -220,7 +250,7 @@ class GPTCache(BaseCache):
         """
         from gptcache.adapter.adapter import adapt
 
-        _gptcache = self.gptcache_dict.get(llm_string)
+        _gptcache = self.gptcache_dict.get(llm_string, None)
         if _gptcache is None:
             return None
         res = adapt(
@@ -234,7 +264,10 @@ class GPTCache(BaseCache):
 
     @staticmethod
     def _update_cache_callback(
-        llm_data: RETURN_VAL_TYPE, update_cache_func: Callable[[Any], None]
+        llm_data: RETURN_VAL_TYPE,
+        update_cache_func: Callable[[Any], None],
+        *args: Any,
+        **kwargs: Any,
     ) -> None:
         """Save the `llm_data` to cache storage"""
         handled_data = json.dumps([generation.dict() for generation in llm_data])
@@ -260,3 +293,13 @@ class GPTCache(BaseCache):
             cache_skip=True,
             prompt=prompt,
         )
+
+    def clear(self, **kwargs: Any) -> None:
+        """Clear cache."""
+        from gptcache import Cache
+
+        for gptcache_instance in self.gptcache_dict.values():
+            gptcache_instance = cast(Cache, gptcache_instance)
+            gptcache_instance.flush()
+
+        self.gptcache_dict.clear()
diff --git a/langchain/memory/entity.py b/langchain/memory/entity.py
index 8863c076..e4a1aed0 100644
--- a/langchain/memory/entity.py
+++ b/langchain/memory/entity.py
@@ -235,4 +235,5 @@ class ConversationEntityMemory(BaseChatMemory):
     def clear(self) -> None:
         """Clear memory contents."""
         self.chat_memory.clear()
+        self.entity_cache.clear()
         self.entity_store.clear()
diff --git a/tests/integration_tests/cache/test_gptcache.py b/tests/integration_tests/cache/test_gptcache.py
index 8a7f6cdb..471f959b 100644
--- a/tests/integration_tests/cache/test_gptcache.py
+++ b/tests/integration_tests/cache/test_gptcache.py
@@ -1,61 +1,48 @@
 import os
+from typing import Any, Callable, Optional
 
 import pytest
 
 import langchain
 from langchain.cache import GPTCache
-from langchain.schema import Generation, LLMResult
+from langchain.schema import Generation
 from tests.unit_tests.llms.fake_llm import FakeLLM
 
 try:
-    import gptcache  # noqa: F401
+    from gptcache import Cache  # noqa: F401
+    from gptcache.manager.factory import get_data_manager
+    from gptcache.processor.pre import get_prompt
 
     gptcache_installed = True
 except ImportError:
     gptcache_installed = False
 
 
-@pytest.mark.skipif(not gptcache_installed, reason="gptcache not installed")
-def test_gptcache_map_caching() -> None:
-    """Test gptcache caching behavior."""
-
-    from gptcache import Cache
-    from gptcache.manager.factory import get_data_manager
-    from gptcache.processor.pre import get_prompt
-
-    i = 0
-    file_prefix = "data_map"
-
-    def init_gptcache_map(cache_obj: Cache) -> None:
-        nonlocal i
-        cache_path = f"{file_prefix}_{i}.txt"
-        if os.path.isfile(cache_path):
-            os.remove(cache_path)
-        cache_obj.init(
-            pre_embedding_func=get_prompt,
-            data_manager=get_data_manager(data_path=cache_path),
-        )
-        i += 1
+def init_gptcache_map(cache_obj: Cache) -> None:
+    i = getattr(init_gptcache_map, "_i", 0)
+    cache_path = f"data_map_{i}.txt"
+    if os.path.isfile(cache_path):
+        os.remove(cache_path)
+    cache_obj.init(
+        pre_embedding_func=get_prompt,
+        data_manager=get_data_manager(data_path=cache_path),
+    )
+    init_gptcache_map._i = i + 1  # type: ignore
 
-    langchain.llm_cache = GPTCache(init_gptcache_map)
+
+@pytest.mark.skipif(not gptcache_installed, reason="gptcache not installed")
+@pytest.mark.parametrize("init_func", [None, init_gptcache_map])
+def test_gptcache_caching(init_func: Optional[Callable[[Any], None]]) -> None:
+    """Test gptcache caching with both the default and a custom init function."""
+    langchain.llm_cache = GPTCache(init_func)
     llm = FakeLLM()
     params = llm.dict()
     params["stop"] = None
     llm_string = str(sorted([(k, v) for k, v in params.items()]))
     langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
-    output = llm.generate(["foo", "bar", "foo"])
-    expected_cache_output = [Generation(text="foo")]
-    cache_output = langchain.llm_cache.lookup("bar", llm_string)
-    assert cache_output == expected_cache_output
-    langchain.llm_cache = None
-    expected_generations = [
-        [Generation(text="fizz")],
-        [Generation(text="foo")],
-        [Generation(text="fizz")],
-    ]
-    expected_output = LLMResult(
-        generations=expected_generations,
-        llm_output=None,
-    )
-    assert output == expected_output
+    _ = llm.generate(["foo", "bar", "foo"])
+    cache_output = langchain.llm_cache.lookup("foo", llm_string)
+    assert cache_output == [Generation(text="fizz")]
+
+    langchain.llm_cache.clear()
+    assert langchain.llm_cache.lookup("bar", llm_string) is None
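
Below is a brief usage sketch of the two changes above (not part of the patch itself). It assumes a langchain build containing this change and, for the GPTCache part, `pip install gptcache`; the `"llm-config"` string is a made-up stand-in for the `llm_string` that langchain normally derives from the model parameters.

    import langchain
    from langchain.cache import GPTCache, InMemoryCache
    from langchain.schema import Generation

    # Any BaseCache implementation can now be wiped via clear().
    langchain.llm_cache = InMemoryCache()
    langchain.llm_cache.update("foo", "llm-config", [Generation(text="fizz")])
    assert langchain.llm_cache.lookup("foo", "llm-config") == [Generation(text="fizz")]
    langchain.llm_cache.clear()
    assert langchain.llm_cache.lookup("foo", "llm-config") is None

    # GPTCache no longer requires an init function: with init_func=None it falls back
    # to get_prompt plus a map data manager whose data_path is the llm string.
    gpt_cache = GPTCache()
    gpt_cache.update("foo", "llm-config", [Generation(text="fizz")])
    print(gpt_cache.lookup("foo", "llm-config"))  # -> [Generation(text="fizz")]
    gpt_cache.clear()  # flushes every underlying gptcache instance and drops it

With the default init function, GPTCache writes its map file under the llm string's name in the working directory, which is why the integration test artifacts (`data_map*` and the FakeLLM parameter string) are added to `.gitignore` above.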