forked from Archives/langchain
b7747017d7
In [pyproject.toml](https://github.com/hwchase17/langchain/blob/master/pyproject.toml), the declared requirement is `SQLAlchemy = "^1"`. However, the way `declarative_base` is imported in [cache.py](https://github.com/hwchase17/langchain/blob/master/langchain/cache.py) only works with SQLAlchemy >=1.4. This PR makes sure LangChain can run in environments with SQLAlchemy <1.4.
76 lines
2.5 KiB
Python
76 lines
2.5 KiB
Python
"""Test base LLM functionality."""
|
|
from sqlalchemy import Column, Integer, Sequence, String, create_engine
|
|
|
|
try:
|
|
from sqlalchemy.orm import declarative_base
|
|
except ImportError:
|
|
from sqlalchemy.ext.declarative import declarative_base
|
|
|
|
import langchain
|
|
from langchain.cache import InMemoryCache, SQLAlchemyCache
|
|
from langchain.schema import Generation, LLMResult
|
|
from tests.unit_tests.llms.fake_llm import FakeLLM
|
|
|
|
|
|
def test_caching() -> None:
    """Test caching behavior with the in-memory LLM cache.

    Seeds the global cache with an entry for the prompt "foo", runs the fake
    LLM over a batch containing that prompt twice plus an unseeded prompt, and
    checks both the generations returned and the entry found in the cache for
    the unseeded prompt afterwards.
    """
    langchain.llm_cache = InMemoryCache()
    # The cache is process-wide state: always clear it, even when an assertion
    # fails, so a failure here cannot leak a populated cache into other tests.
    try:
        llm = FakeLLM()
        params = llm.dict()
        params["stop"] = None
        # Cache key component describing the LLM configuration.
        # NOTE(review): presumably mirrors the key the LLM derives internally;
        # the "fizz" expectations below rely on the two matching.
        llm_string = str(sorted(params.items()))
        langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])

        output = llm.generate(["foo", "bar", "foo"])

        # "bar" was not seeded, so after generate() the cache is expected to
        # hold the fake LLM's own answer for it.
        expected_cache_output = [Generation(text="foo")]
        cache_output = langchain.llm_cache.lookup("bar", llm_string)
        assert cache_output == expected_cache_output
    finally:
        langchain.llm_cache = None

    # Both "foo" prompts should come back with the seeded value; "bar" with
    # the freshly generated one.
    expected_generations = [
        [Generation(text="fizz")],
        [Generation(text="foo")],
        [Generation(text="fizz")],
    ]
    expected_output = LLMResult(
        expected_generations,
        llm_output=None,
    )
    assert output == expected_output
|
|
|
|
|
|
def test_custom_caching() -> None:
    """Test caching behavior with a user-supplied SQLAlchemy cache schema.

    Declares a custom declarative model, hands it to ``SQLAlchemyCache``
    together with a ``sqlite://`` engine, and then runs the same
    seed / generate / lookup scenario as the in-memory cache test.
    """
    Base = declarative_base()

    class FulltextLLMCache(Base):  # type: ignore
        """Postgres table for fulltext-indexed LLM Cache."""

        # NOTE(review): the docstring mentions Postgres, but this test binds
        # the model to a SQLite engine below.
        __tablename__ = "llm_cache_fulltext"
        id = Column(Integer, Sequence("cache_id"), primary_key=True)
        prompt = Column(String, nullable=False)
        llm = Column(String, nullable=False)
        idx = Column(Integer)
        response = Column(String)

    engine = create_engine("sqlite://")
    langchain.llm_cache = SQLAlchemyCache(engine, FulltextLLMCache)
    # The cache is process-wide state: always clear it, even when an assertion
    # fails, so a failure here cannot leak a populated cache into other tests.
    try:
        llm = FakeLLM()
        params = llm.dict()
        params["stop"] = None
        # Cache key component describing the LLM configuration.
        # NOTE(review): presumably mirrors the key the LLM derives internally;
        # the "fizz" expectations below rely on the two matching.
        llm_string = str(sorted(params.items()))
        langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])

        output = llm.generate(["foo", "bar", "foo"])

        # "bar" was not seeded, so after generate() the cache is expected to
        # hold the fake LLM's own answer for it.
        expected_cache_output = [Generation(text="foo")]
        cache_output = langchain.llm_cache.lookup("bar", llm_string)
        assert cache_output == expected_cache_output
    finally:
        langchain.llm_cache = None

    # Both "foo" prompts should come back with the seeded value; "bar" with
    # the freshly generated one.
    expected_generations = [
        [Generation(text="fizz")],
        [Generation(text="foo")],
        [Generation(text="fizz")],
    ]
    expected_output = LLMResult(
        expected_generations,
        llm_output=None,
    )
    assert output == expected_output
|