langchain/tests/unit_tests/llms/test_base.py

72 lines
2.4 KiB
Python
Raw Normal View History

2023-01-05 02:39:06 +00:00
"""Test base LLM functionality."""
2023-01-19 23:33:45 +00:00
from sqlalchemy import Column, Integer, Sequence, String, create_engine
from sqlalchemy.ext.declarative import declarative_base
2023-01-05 02:39:06 +00:00
import langchain
2023-01-19 23:33:45 +00:00
from langchain.cache import InMemoryCache, SQLAlchemyCache
2023-01-05 02:39:06 +00:00
from langchain.schema import Generation, LLMResult
from tests.unit_tests.llms.fake_llm import FakeLLM
def test_caching() -> None:
    """Test that ``generate`` serves cache hits and writes back cache misses.

    The in-memory cache is pre-seeded with ``"foo" -> "fizz"``. Generating for
    ``["foo", "bar", "foo"]`` must return the cached ``"fizz"`` for both
    ``"foo"`` prompts. The miss on ``"bar"`` must be stored in the cache with
    FakeLLM's own output (``"foo"``, per the expectations below).
    """
    langchain.llm_cache = InMemoryCache()
    try:
        llm = FakeLLM()
        params = llm.dict()
        params["stop"] = None
        # Cache key built from the sorted (key, value) pairs of the LLM params
        # plus ``stop``; presumably mirrors the key used inside
        # ``llm.generate`` -- confirm against langchain.llms.base if it drifts.
        llm_string = str(sorted(params.items()))
        langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
        output = llm.generate(["foo", "bar", "foo"])

        # The uncached prompt "bar" should now have been written to the cache.
        expected_cache_output = [Generation(text="foo")]
        cache_output = langchain.llm_cache.lookup("bar", llm_string)
        assert cache_output == expected_cache_output
    finally:
        # ``llm_cache`` is module-global state: reset it even if an assertion
        # above fails, so subsequent tests do not run against this cache.
        langchain.llm_cache = None

    expected_generations = [
        [Generation(text="fizz")],
        [Generation(text="foo")],
        [Generation(text="fizz")],
    ]
    expected_output = LLMResult(
        expected_generations,
        llm_output=None,
    )
    assert output == expected_output
2023-01-19 23:33:45 +00:00
def test_custom_caching() -> None:
    """Test ``SQLAlchemyCache`` with a caller-supplied table model.

    A custom declarative model is bound to an in-memory SQLite engine and used
    as the cache table. The pre-seeded ``"foo" -> "fizz"`` entry must be served
    for both ``"foo"`` prompts. The miss on ``"bar"`` must be written back with
    FakeLLM's own output (``"foo"``, per the expectations below).
    """
    Base = declarative_base()

    class FulltextLLMCache(Base):  # type: ignore
        """Custom LLM cache table model (SQLite-backed in this test).

        The name suggests a full-text-indexed table such as one might use
        on Postgres; no index is declared here -- it only exercises a
        user-provided schema with ``SQLAlchemyCache``.
        """

        __tablename__ = "llm_cache_fulltext"
        id = Column(Integer, Sequence("cache_id"), primary_key=True)
        prompt = Column(String, nullable=False)
        llm = Column(String, nullable=False)
        idx = Column(Integer)
        response = Column(String)

    engine = create_engine("sqlite://")
    langchain.llm_cache = SQLAlchemyCache(engine, FulltextLLMCache)
    try:
        llm = FakeLLM()
        params = llm.dict()
        params["stop"] = None
        # Cache key built from the sorted (key, value) pairs of the LLM params
        # plus ``stop``; presumably mirrors the key used inside
        # ``llm.generate`` -- confirm against langchain.llms.base if it drifts.
        llm_string = str(sorted(params.items()))
        langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
        output = llm.generate(["foo", "bar", "foo"])

        # The uncached prompt "bar" should now have been written to the cache.
        expected_cache_output = [Generation(text="foo")]
        cache_output = langchain.llm_cache.lookup("bar", llm_string)
        assert cache_output == expected_cache_output
    finally:
        # Reset module-global cache state and release pooled DB connections
        # even if an assertion above fails.
        langchain.llm_cache = None
        engine.dispose()

    expected_generations = [
        [Generation(text="fizz")],
        [Generation(text="foo")],
        [Generation(text="fizz")],
    ]
    expected_output = LLMResult(
        expected_generations,
        llm_output=None,
    )
    assert output == expected_output