Use the GPTCache api interface (#3693)

Use the GPTCache api interface (get/put) instead of the lower-level adapter to reduce the possibility of compatibility issues.
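
For reference, the api interface in question is gptcache.adapter.api, whose get/put functions take the prompt plus a cache_obj keyword, exactly as used in the diff below. A minimal sketch of that interface follows; the Cache setup and the get_prompt pre-processor are taken from the gptcache documentation, not from this change:

    from gptcache import Cache
    from gptcache.adapter.api import get, put
    from gptcache.processor.pre import get_prompt

    # Assumed setup: key cached entries on the raw prompt string.
    cache = Cache()
    cache.init(pre_embedding_func=get_prompt)

    put("hello", "world", cache_obj=cache)  # store a value for a prompt
    print(get("hello", cache_obj=cache))    # -> "world" on a cache hit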
SimFG 2023-04-29 11:18:51 +08:00 committed by GitHub
parent f37a932b24
commit 5998b53596


@@ -205,23 +205,6 @@ class GPTCache(BaseCache):
         self.init_gptcache_func: Optional[Callable[[Any], None]] = init_func
         self.gptcache_dict: Dict[str, Any] = {}
 
-    @staticmethod
-    def _update_cache_callback_none(*_: Any, **__: Any) -> None:
-        """When updating cached data, do nothing.
-
-        Because currently only cached queries are processed."""
-        return None
-
-    @staticmethod
-    def _llm_handle_none(*_: Any, **__: Any) -> None:
-        """Do nothing on a cache miss"""
-        return None
-
-    @staticmethod
-    def _cache_data_converter(data: str) -> RETURN_VAL_TYPE:
-        """Convert the `data` in the cache to the `RETURN_VAL_TYPE` data format."""
-        return [Generation(**generation_dict) for generation_dict in json.loads(data)]
-
     def _get_gptcache(self, llm_string: str) -> Any:
         """Get a cache object.
@@ -248,51 +231,29 @@ class GPTCache(BaseCache):
         First, retrieve the corresponding cache object using the `llm_string` parameter,
         and then retrieve the data from the cache based on the `prompt`.
         """
-        from gptcache.adapter.adapter import adapt
+        from gptcache.adapter.api import get
 
         _gptcache = self.gptcache_dict.get(llm_string, None)
         if _gptcache is None:
             return None
-        res = adapt(
-            GPTCache._llm_handle_none,
-            GPTCache._cache_data_converter,
-            GPTCache._update_cache_callback_none,
-            cache_obj=_gptcache,
-            prompt=prompt,
-        )
-        return res
-
-    @staticmethod
-    def _update_cache_callback(
-        llm_data: RETURN_VAL_TYPE,
-        update_cache_func: Callable[[Any], None],
-        *args: Any,
-        **kwargs: Any,
-    ) -> None:
-        """Save the `llm_data` to cache storage"""
-        handled_data = json.dumps([generation.dict() for generation in llm_data])
-        update_cache_func(handled_data)
+        res = get(prompt, cache_obj=_gptcache)
+        if res:
+            return [
+                Generation(**generation_dict) for generation_dict in json.loads(res)
+            ]
+        return None
 
     def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
         """Update cache.
 
         First, retrieve the corresponding cache object using the `llm_string` parameter,
         and then store the `prompt` and `return_val` in the cache object.
         """
-        from gptcache.adapter.adapter import adapt
+        from gptcache.adapter.api import put
 
         _gptcache = self._get_gptcache(llm_string)
-
-        def llm_handle(*_: Any, **__: Any) -> RETURN_VAL_TYPE:
-            return return_val
-
-        return adapt(
-            llm_handle,
-            GPTCache._cache_data_converter,
-            GPTCache._update_cache_callback,
-            cache_obj=_gptcache,
-            cache_skip=True,
-            prompt=prompt,
-        )
+        handled_data = json.dumps([generation.dict() for generation in return_val])
+        put(prompt, handled_data, cache_obj=_gptcache)
+        return None
 
     def clear(self, **kwargs: Any) -> None:
         """Clear cache."""