Harrison/fix and test caching (#538)

Harrison Chase authored 2023-01-04 18:39:06 -08:00; committed by GitHub
parent 73f7ebd9d1
commit 1631981f84
2 changed files with 28 additions and 1 deletion

langchain/llms/base.py

@@ -96,7 +96,7 @@ class BaseLLM(BaseModel, ABC):
             new_results = self._generate(missing_prompts, stop=stop)
             self.callback_manager.on_llm_end(new_results)
             for i, result in enumerate(new_results.generations):
-                existing_prompts[i] = result
+                existing_prompts[missing_prompt_idxs[i]] = result
                 prompt = prompts[i]
                 langchain.llm_cache.update(prompt, llm_string, result)
         generations = [existing_prompts[i] for i in range(len(prompts))]
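
The fix targets partial cache hits: new_results.generations[i] corresponds to missing_prompts[i], so each fresh result must be written back at that prompt's original position, missing_prompt_idxs[i]. Indexing by i alone scattered results into the wrong slots whenever a cached prompt preceded an uncached one. A minimal sketch of the bookkeeping that feeds the fixed loop, using a hypothetical helper name (split_cached) and assumed scaffolding around the variable names visible in the diff:

    from typing import Any, Dict, List, Tuple

    import langchain


    def split_cached(
        prompts: List[str], llm_string: str
    ) -> Tuple[Dict[int, Any], List[str], List[int]]:
        """Assumed pre-loop bookkeeping: partition prompts into hits and misses."""
        existing_prompts: Dict[int, Any] = {}  # original index -> cached generations
        missing_prompts: List[str] = []        # prompts needing a real LLM call
        missing_prompt_idxs: List[int] = []    # their positions in the original list
        for i, prompt in enumerate(prompts):
            cached = langchain.llm_cache.lookup(prompt, llm_string)
            if cached is not None:
                existing_prompts[i] = cached
            else:
                missing_prompts.append(prompt)
                missing_prompt_idxs.append(i)
        return existing_prompts, missing_prompts, missing_prompt_idxs

After generation, new_results.generations[i] pairs with missing_prompts[i], whose original slot is missing_prompt_idxs[i]; that mapping is exactly what the corrected assignment restores.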

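For context on the test below: langchain.llm_cache is keyed by the (prompt, llm_string) pair, which is why the test serializes the model parameters into llm_string before seeding the cache. A minimal sketch of an in-memory cache with that interface; the shipped class is langchain.cache.InMemoryCache, and this sketch only assumes the lookup/update signatures used in the diff above:

    from typing import Any, Dict, Optional, Tuple


    class InMemoryCacheSketch:
        """Illustrative cache keyed by (prompt, llm_string)."""

        def __init__(self) -> None:
            self._cache: Dict[Tuple[str, str], Any] = {}

        def lookup(self, prompt: str, llm_string: str) -> Optional[Any]:
            # Return the cached generations, or None on a miss.
            return self._cache.get((prompt, llm_string))

        def update(self, prompt: str, llm_string: str, return_val: Any) -> None:
            self._cache[(prompt, llm_string)] = return_val
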
tests/unit_tests/llms/test_base.py

@@ -0,0 +1,27 @@
+"""Test base LLM functionality."""
+import langchain
+from langchain.cache import InMemoryCache
+from langchain.schema import Generation, LLMResult
+from tests.unit_tests.llms.fake_llm import FakeLLM
+
+
+def test_caching() -> None:
+    """Test caching behavior."""
+    langchain.llm_cache = InMemoryCache()
+    llm = FakeLLM()
+    params = llm._llm_dict()
+    params["stop"] = None
+    llm_string = str(sorted([(k, v) for k, v in params.items()]))
+    langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
+    output = llm.generate(["foo", "bar", "foo"])
+    langchain.llm_cache = None
+    expected_generations = [
+        [Generation(text="fizz")],
+        [Generation(text="foo")],
+        [Generation(text="fizz")],
+    ]
+    expected_output = LLMResult(
+        expected_generations,
+        llm_output=None,
+    )
+    assert output == expected_output
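
Why the expected generations come out as fizz/foo/fizz: both "foo" prompts hit the pre-seeded cache entry, while "bar" misses and goes to FakeLLM, which answers with a fixed string. A sketch of a fake LLM with that behavior; the real helper lives in tests/unit_tests/llms/fake_llm.py, and the exact class body here is an assumption:

    from typing import List, Mapping, Optional

    from langchain.llms.base import LLM


    class FakeLLMSketch(LLM):
        """Assumed stand-in: returns a canned reply if one exists, else "foo"."""

        queries: Optional[Mapping[str, str]] = None

        @property
        def _llm_type(self) -> str:
            return "fake"

        def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
            # Look up a canned answer first; fall back to a fixed string.
            if self.queries is not None:
                return self.queries[prompt]
            return "foo"

Two details of the test worth noting: llm_string is built the same way the generate path appears to build its cache key, so the seeded "foo" entry actually matches on lookup, and resetting langchain.llm_cache to None at the end keeps the global cache from leaking into other tests.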