Allow clearing cache and fix gptcache (#3493)

This PR:

* Adds a `clear` method to `BaseCache` and implements it for the various
caches (see the usage sketch below)
* Adds the default `init_func=None` and fixes the gptcache integration test
* Since the integration tests are not yet run in CI, I've verified the
changes by running `docs/modules/models/llms/examples/llm_caching.ipynb`
(until a proper e2e integration test runs in CI)
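As a rough usage sketch of the new `clear` API (not part of the diff; it mirrors the pattern in the caching notebook, and the `.langchain.db` path is the one removed by the cleanup cell added below, with the OpenAI model chosen only for illustration):

```python
import langchain
from langchain.cache import SQLiteCache
from langchain.llms import OpenAI

# Route all LLM calls through a SQLite-backed cache.
langchain.llm_cache = SQLiteCache(database_path=".langchain.db")

llm = OpenAI(model_name="text-ada-001")
llm("Tell me a joke")  # first call hits the API and is written to the cache
llm("Tell me a joke")  # second call is served from the cache

# New in this PR: wipe the cache without deleting the database file by hand.
langchain.llm_cache.clear()
```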
Ehsan M. Kermani committed by GitHub
parent 83e871f1ff
commit 4a246e2fd6

.gitignore

@@ -144,4 +144,8 @@ wandb/
/.ruff_cache/
*.pkl
*.bin
# integration test artifacts
data_map*
[('_type', 'fake'), ('stop', None)]

@@ -785,7 +785,9 @@
"id": "9df0dab8",
"metadata": {},
"outputs": [],
"source": []
"source": [
"!rm .langchain.db sqlite.db"
]
}
],
"metadata": {

@@ -1,7 +1,7 @@
"""Beta Feature: base interface for cache."""
import json
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, cast
from sqlalchemy import Column, Integer, String, create_engine, select
from sqlalchemy.engine.base import Engine
@@ -28,6 +28,10 @@ class BaseCache(ABC):
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
"""Update cache based on prompt and llm_string."""
@abstractmethod
def clear(self, **kwargs: Any) -> None:
"""Clear cache that can take additional keyword arguments."""
class InMemoryCache(BaseCache):
"""Cache that stores things in memory."""
@@ -44,6 +44,10 @@ class InMemoryCache(BaseCache):
"""Update cache based on prompt and llm_string."""
self._cache[(prompt, llm_string)] = return_val
def clear(self, **kwargs: Any) -> None:
"""Clear cache."""
self._cache = {}
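For a quick sanity check of the new in-memory behavior, a minimal sketch (not part of the commit) exercising `update`, `lookup`, and `clear` directly on the cache object:

```python
from langchain.cache import InMemoryCache
from langchain.schema import Generation

cache = InMemoryCache()
cache.update("prompt", "llm-config", [Generation(text="cached answer")])
assert cache.lookup("prompt", "llm-config") == [Generation(text="cached answer")]

cache.clear()  # drops the backing dict
assert cache.lookup("prompt", "llm-config") is None
```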
Base = declarative_base()
@@ -61,7 +69,7 @@ class FullLLMCache(Base): # type: ignore
class SQLAlchemyCache(BaseCache):
"""Cache that uses SQAlchemy as a backend."""
def __init__(self, engine: Engine, cache_schema: Any = FullLLMCache):
def __init__(self, engine: Engine, cache_schema: Type[FullLLMCache] = FullLLMCache):
"""Initialize by creating all tables."""
self.engine = engine
self.cache_schema = cache_schema
@@ -76,20 +84,26 @@ class SQLAlchemyCache(BaseCache):
.order_by(self.cache_schema.idx)
)
with Session(self.engine) as session:
generations = [Generation(text=row[0]) for row in session.execute(stmt)]
if len(generations) > 0:
return generations
rows = session.execute(stmt).fetchall()
if rows:
return [Generation(text=row[0]) for row in rows]
return None
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
"""Look up based on prompt and llm_string."""
for i, generation in enumerate(return_val):
item = self.cache_schema(
prompt=prompt, llm=llm_string, response=generation.text, idx=i
)
with Session(self.engine) as session, session.begin():
"""Update based on prompt and llm_string."""
items = [
self.cache_schema(prompt=prompt, llm=llm_string, response=gen.text, idx=i)
for i, gen in enumerate(return_val)
]
with Session(self.engine) as session, session.begin():
for item in items:
session.merge(item)
def clear(self, **kwargs: Any) -> None:
"""Clear cache."""
with Session(self.engine) as session:
session.execute(self.cache_schema.delete())
class SQLiteCache(SQLAlchemyCache):
"""Cache that uses SQLite as a backend."""
@@ -139,19 +153,26 @@ class RedisCache(BaseCache):
for i, generation in enumerate(return_val):
self.redis.set(self._key(prompt, llm_string, i), generation.text)
def clear(self, **kwargs: Any) -> None:
"""Clear cache. If `asynchronous` is True, flush asynchronously."""
asynchronous = kwargs.get("asynchronous", False)
self.redis.flushdb(asynchronous=asynchronous, **kwargs)
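For context, a minimal sketch (not from the diff) of wiring up the Redis backend and clearing it; it assumes a local Redis server and that `RedisCache` takes the client as the `redis_` parameter:

```python
# Sketch only: assumes a reachable local Redis and the `redis_` constructor argument.
import redis
import langchain
from langchain.cache import RedisCache

langchain.llm_cache = RedisCache(redis_=redis.Redis())
# ... cached LLM calls happen here ...
langchain.llm_cache.clear()  # issues FLUSHDB on the configured database
```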
class GPTCache(BaseCache):
"""Cache that uses GPTCache as a backend."""
def __init__(self, init_func: Callable[[Any], None]):
"""Initialize by passing in the `init` GPTCache func
def __init__(self, init_func: Optional[Callable[[Any], None]] = None):
"""Initialize by passing in init function (default: `None`).
Args:
init_func (Callable[[Any], None]): init `GPTCache` function
init_func (Optional[Callable[[Any], None]]): init `GPTCache` function
(default: `None`)
Example:
.. code-block:: python
# Initialize GPTCache with a custom init function
import gptcache
from gptcache.processor.pre import get_prompt
from gptcache.manager.factory import get_data_manager
@@ -180,7 +201,8 @@ class GPTCache(BaseCache):
"Could not import gptcache python package. "
"Please install it with `pip install gptcache`."
)
self.init_gptcache_func: Callable[[Any], None] = init_func
self.init_gptcache_func: Optional[Callable[[Any], None]] = init_func
self.gptcache_dict: Dict[str, Any] = {}
@staticmethod
@@ -205,11 +227,19 @@
When the corresponding llm model cache does not exist, it will be created."""
from gptcache import Cache
from gptcache.manager.factory import get_data_manager
from gptcache.processor.pre import get_prompt
_gptcache = self.gptcache_dict.get(llm_string, None)
if _gptcache is None:
_gptcache = Cache()
self.init_gptcache_func(_gptcache)
if self.init_gptcache_func is not None:
self.init_gptcache_func(_gptcache)
else:
_gptcache.init(
pre_embedding_func=get_prompt,
data_manager=get_data_manager(data_path=llm_string),
)
self.gptcache_dict[llm_string] = _gptcache
return _gptcache
@@ -220,7 +250,7 @@
"""
from gptcache.adapter.adapter import adapt
_gptcache = self.gptcache_dict.get(llm_string)
_gptcache = self.gptcache_dict.get(llm_string, None)
if _gptcache is None:
return None
res = adapt(
@@ -234,7 +264,10 @@
@staticmethod
def _update_cache_callback(
llm_data: RETURN_VAL_TYPE, update_cache_func: Callable[[Any], None]
llm_data: RETURN_VAL_TYPE,
update_cache_func: Callable[[Any], None],
*args: Any,
**kwargs: Any,
) -> None:
"""Save the `llm_data` to cache storage"""
handled_data = json.dumps([generation.dict() for generation in llm_data])
@@ -260,3 +293,13 @@
cache_skip=True,
prompt=prompt,
)
def clear(self, **kwargs: Any) -> None:
"""Clear cache."""
from gptcache import Cache
for gptcache_instance in self.gptcache_dict.values():
gptcache_instance = cast(Cache, gptcache_instance)
gptcache_instance.flush()
self.gptcache_dict.clear()
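With `init_func` now optional, the simplest GPTCache setup needs no custom initializer. A minimal sketch (assuming `gptcache` is installed; not part of the diff):

```python
import langchain
from langchain.cache import GPTCache

# No init function: each llm_string gets its own gptcache Cache created on demand,
# initialized with get_prompt and get_data_manager(data_path=llm_string).
langchain.llm_cache = GPTCache()

# ... run LLM calls as usual; lookups and updates go through gptcache ...

langchain.llm_cache.clear()  # flushes every per-llm_string Cache and forgets it
```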

@@ -235,4 +235,5 @@ class ConversationEntityMemory(BaseChatMemory):
def clear(self) -> None:
"""Clear memory contents."""
self.chat_memory.clear()
self.entity_cache.clear()
self.entity_store.clear()
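A small sketch of the memory-side effect (the OpenAI LLM is only illustrative and not part of the diff):

```python
from langchain.llms import OpenAI
from langchain.memory import ConversationEntityMemory

memory = ConversationEntityMemory(llm=OpenAI())
# ... used inside a conversation chain, entities accumulate in memory.entity_store ...

memory.clear()  # now resets chat history, the entity cache, and the entity store
```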

@@ -1,61 +1,48 @@
import os
from typing import Any, Callable, Optional
import pytest
import langchain
from langchain.cache import GPTCache
from langchain.schema import Generation, LLMResult
from langchain.schema import Generation
from tests.unit_tests.llms.fake_llm import FakeLLM
try:
import gptcache # noqa: F401
from gptcache import Cache # noqa: F401
from gptcache.manager.factory import get_data_manager
from gptcache.processor.pre import get_prompt
gptcache_installed = True
except ImportError:
gptcache_installed = False
@pytest.mark.skipif(not gptcache_installed, reason="gptcache not installed")
def test_gptcache_map_caching() -> None:
"""Test gptcache caching behavior."""
from gptcache import Cache
from gptcache.manager.factory import get_data_manager
from gptcache.processor.pre import get_prompt
i = 0
file_prefix = "data_map"
def init_gptcache_map(cache_obj: Cache) -> None:
nonlocal i
cache_path = f"{file_prefix}_{i}.txt"
if os.path.isfile(cache_path):
os.remove(cache_path)
cache_obj.init(
pre_embedding_func=get_prompt,
data_manager=get_data_manager(data_path=cache_path),
)
i += 1
def init_gptcache_map(cache_obj: Cache) -> None:
i = getattr(init_gptcache_map, "_i", 0)
cache_path = f"data_map_{i}.txt"
if os.path.isfile(cache_path):
os.remove(cache_path)
cache_obj.init(
pre_embedding_func=get_prompt,
data_manager=get_data_manager(data_path=cache_path),
)
init_gptcache_map._i = i + 1 # type: ignore
langchain.llm_cache = GPTCache(init_gptcache_map)
@pytest.mark.skipif(not gptcache_installed, reason="gptcache not installed")
@pytest.mark.parametrize("init_func", [None, init_gptcache_map])
def test_gptcache_caching(init_func: Optional[Callable[[Any], None]]) -> None:
"""Test gptcache default caching behavior."""
langchain.llm_cache = GPTCache(init_func)
llm = FakeLLM()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
output = llm.generate(["foo", "bar", "foo"])
expected_cache_output = [Generation(text="foo")]
cache_output = langchain.llm_cache.lookup("bar", llm_string)
assert cache_output == expected_cache_output
langchain.llm_cache = None
expected_generations = [
[Generation(text="fizz")],
[Generation(text="foo")],
[Generation(text="fizz")],
]
expected_output = LLMResult(
generations=expected_generations,
llm_output=None,
)
assert output == expected_output
_ = llm.generate(["foo", "bar", "foo"])
cache_output = langchain.llm_cache.lookup("foo", llm_string)
assert cache_output == [Generation(text="fizz")]
langchain.llm_cache.clear()
assert langchain.llm_cache.lookup("bar", llm_string) is None
