fix caching (#555)

pull/556/head^2
Harrison Chase 1 year ago committed by GitHub
parent 74932f2516
commit 9833fcfe32

@@ -109,7 +109,7 @@ class BaseLLM(BaseModel, ABC):
             self.callback_manager.on_llm_end(new_results)
             for i, result in enumerate(new_results.generations):
                 existing_prompts[missing_prompt_idxs[i]] = result
-                prompt = prompts[i]
+                prompt = prompts[missing_prompt_idxs[i]]
                 langchain.llm_cache.update(prompt, llm_string, result)
         generations = [existing_prompts[i] for i in range(len(prompts))]
         return LLMResult(generations=generations, llm_output=new_results.llm_output)
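
The bug being fixed: inside the loop, `i` indexes `new_results.generations`, which only holds results for the prompts that missed the cache, while `prompts` is the full input list. `prompts[i]` therefore cached each fresh result under the wrong prompt whenever cache hits and misses were interleaved; `missing_prompt_idxs[i]` maps the loop index back to the prompt's original position. A minimal sketch of the index mapping (the toy lists below are illustrative, not library code):

```python
# Toy data: prompts 0 and 2 were cache hits, only prompt 1 was missing.
prompts = ["a", "b", "c"]
missing_prompt_idxs = [1]               # positions in `prompts` that missed the cache
new_generations = ["generation-for-b"]  # one fresh result per missing prompt

for i, result in enumerate(new_generations):
    # Before the fix: prompts[i] == "a", so the result for "b" would have
    # been written into the cache under the key for "a".
    # After the fix: prompts[missing_prompt_idxs[i]] == "b", the right key.
    assert prompts[i] == "a"
    assert prompts[missing_prompt_idxs[i]] == "b"
```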

@@ -14,6 +14,9 @@ def test_caching() -> None:
     llm_string = str(sorted([(k, v) for k, v in params.items()]))
     langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
     output = llm.generate(["foo", "bar", "foo"])
+    expected_cache_output = [Generation(text="foo")]
+    cache_output = langchain.llm_cache.lookup("bar", llm_string)
+    assert cache_output == expected_cache_output
     langchain.llm_cache = None
     expected_generations = [
         [Generation(text="fizz")],
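
The new assertions pin down the behavior the fix restores: after `llm.generate(["foo", "bar", "foo"])`, the previously uncached prompt "bar" must appear in the cache under its own key with the fake model's output ("foo"), while the pre-seeded entry for "foo" keeps its "fizz" generation. A rough stand-in sketch of that scenario, using a plain dict in place of the test's cache and a constant answer in place of the fake LLM (all names below are illustrative):

```python
# Stand-ins for the test fixtures: a dict for the cache, a constant "foo"
# answer for the fake model. Keys pair the prompt with an llm_string.
llm_string = "fake-llm-params"                     # illustrative key component
cache = {("foo", llm_string): [{"text": "fizz"}]}  # pre-seeded entry

def fake_generate(prompt: str) -> list:
    return [{"text": "foo"}]                       # the fake model always says "foo"

for prompt in ["foo", "bar", "foo"]:
    key = (prompt, llm_string)
    if key not in cache:                           # only cache misses hit the model
        cache[key] = fake_generate(prompt)

# "bar" was a miss, so it now holds the fake model's output under its own key...
assert cache[("bar", llm_string)] == [{"text": "foo"}]
# ...while the pre-seeded "foo" entry is untouched.
assert cache[("foo", llm_string)] == [{"text": "fizz"}]
```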
