From 54d7f1c9330b054a357e5a438f78c13643f55013 Mon Sep 17 00:00:00 2001
From: Harrison Chase
Date: Thu, 19 Jan 2023 15:33:45 -0800
Subject: [PATCH] fix caching (#658)

---
 langchain/cache.py                 |  2 +-
 tests/unit_tests/llms/test_base.py | 43 +++++++++++++++++++++++++++++-
 2 files changed, 43 insertions(+), 2 deletions(-)

diff --git a/langchain/cache.py b/langchain/cache.py
index 869e155611..52c272c4fb 100644
--- a/langchain/cache.py
+++ b/langchain/cache.py
@@ -60,7 +60,7 @@ class SQLAlchemyCache(BaseCache):
         """Initialize by creating all tables."""
         self.engine = engine
         self.cache_schema = cache_schema
-        Base.metadata.create_all(self.engine)
+        self.cache_schema.metadata.create_all(self.engine)
 
     def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
         """Look up based on prompt and llm_string."""
diff --git a/tests/unit_tests/llms/test_base.py b/tests/unit_tests/llms/test_base.py
index 05f97d6efc..a5c9cc00ba 100644
--- a/tests/unit_tests/llms/test_base.py
+++ b/tests/unit_tests/llms/test_base.py
@@ -1,6 +1,9 @@
 """Test base LLM functionality."""
+from sqlalchemy import Column, Integer, Sequence, String, create_engine
+from sqlalchemy.ext.declarative import declarative_base
+
 import langchain
-from langchain.cache import InMemoryCache
+from langchain.cache import InMemoryCache, SQLAlchemyCache
 from langchain.schema import Generation, LLMResult
 from tests.unit_tests.llms.fake_llm import FakeLLM
 
@@ -28,3 +31,41 @@ def test_caching() -> None:
         llm_output=None,
     )
     assert output == expected_output
+
+
+def test_custom_caching() -> None:
+    """Test custom_caching behavior."""
+    Base = declarative_base()
+
+    class FulltextLLMCache(Base):  # type: ignore
+        """Postgres table for fulltext-indexed LLM Cache."""
+
+        __tablename__ = "llm_cache_fulltext"
+        id = Column(Integer, Sequence("cache_id"), primary_key=True)
+        prompt = Column(String, nullable=False)
+        llm = Column(String, nullable=False)
+        idx = Column(Integer)
+        response = Column(String)
+
+    engine = create_engine("sqlite://")
+    langchain.llm_cache = SQLAlchemyCache(engine, FulltextLLMCache)
+    llm = FakeLLM()
+    params = llm._llm_dict()
+    params["stop"] = None
+    llm_string = str(sorted([(k, v) for k, v in params.items()]))
+    langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
+    output = llm.generate(["foo", "bar", "foo"])
+    expected_cache_output = [Generation(text="foo")]
+    cache_output = langchain.llm_cache.lookup("bar", llm_string)
+    assert cache_output == expected_cache_output
+    langchain.llm_cache = None
+    expected_generations = [
+        [Generation(text="fizz")],
+        [Generation(text="foo")],
+        [Generation(text="fizz")],
+    ]
+    expected_output = LLMResult(
+        expected_generations,
+        llm_output=None,
+    )
+    assert output == expected_output
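
Why the one-line change in langchain/cache.py matters: SQLAlchemyCache.__init__ previously called create_all() on the module-level Base, whose metadata only tracks tables declared against that particular base. A cache_schema declared against a caller's own declarative base, as in the new test_custom_caching test, was therefore never created. Below is a minimal standalone sketch of that failure mode, assuming SQLAlchemy 1.4; ModuleBase, CustomBase, and CustomCache are hypothetical names invented for the demonstration and do not appear in the patch:

    from sqlalchemy import Column, Integer, String, create_engine, inspect
    from sqlalchemy.ext.declarative import declarative_base

    ModuleBase = declarative_base()  # stands in for the module-level Base in langchain/cache.py
    CustomBase = declarative_base()  # stands in for a caller's own declarative base


    class CustomCache(CustomBase):  # type: ignore
        """Hypothetical user-supplied cache table, registered on CustomBase only."""

        __tablename__ = "custom_cache"
        id = Column(Integer, primary_key=True)
        prompt = Column(String)


    engine = create_engine("sqlite://")

    # Pre-fix behavior: ModuleBase.metadata has no record of CustomCache,
    # so create_all() silently creates nothing.
    ModuleBase.metadata.create_all(engine)
    assert "custom_cache" not in inspect(engine).get_table_names()

    # Post-fix behavior: calling create_all() through the schema that was
    # actually passed in creates the expected table.
    CustomCache.metadata.create_all(engine)
    assert "custom_cache" in inspect(engine).get_table_names()

The new test exercises the same path end to end through SQLAlchemyCache(engine, FulltextLLMCache): with the old code, the update() call would fail because the llm_cache_fulltext table was never created.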