diff --git a/langchain/llms/base.py b/langchain/llms/base.py
index 22e63b4908..90452ed5ce 100644
--- a/langchain/llms/base.py
+++ b/langchain/llms/base.py
@@ -96,7 +96,7 @@ class BaseLLM(BaseModel, ABC):
         new_results = self._generate(missing_prompts, stop=stop)
         self.callback_manager.on_llm_end(new_results)
         for i, result in enumerate(new_results.generations):
-            existing_prompts[i] = result
+            existing_prompts[missing_prompt_idxs[i]] = result
             prompt = prompts[i]
             langchain.llm_cache.update(prompt, llm_string, result)
         generations = [existing_prompts[i] for i in range(len(prompts))]
diff --git a/tests/unit_tests/llms/test_base.py b/tests/unit_tests/llms/test_base.py
new file mode 100644
index 0000000000..da67a9da85
--- /dev/null
+++ b/tests/unit_tests/llms/test_base.py
@@ -0,0 +1,27 @@
+"""Test base LLM functionality."""
+import langchain
+from langchain.cache import InMemoryCache
+from langchain.schema import Generation, LLMResult
+from tests.unit_tests.llms.fake_llm import FakeLLM
+
+
+def test_caching() -> None:
+    """Test caching behavior."""
+    langchain.llm_cache = InMemoryCache()
+    llm = FakeLLM()
+    params = llm._llm_dict()
+    params["stop"] = None
+    llm_string = str(sorted([(k, v) for k, v in params.items()]))
+    langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
+    output = llm.generate(["foo", "bar", "foo"])
+    langchain.llm_cache = None
+    expected_generations = [
+        [Generation(text="fizz")],
+        [Generation(text="foo")],
+        [Generation(text="fizz")],
+    ]
+    expected_output = LLMResult(
+        expected_generations,
+        llm_output=None,
+    )
+    assert output == expected_output
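
For reviewers, a minimal sketch of why the fix indexes by `missing_prompt_idxs[i]`: in the loop, `i` enumerates only the freshly generated results, not the original prompt list, so writing `existing_prompts[i]` overwrote cached entries at the front of the batch. This is not the library code; the cache is reduced to a plain dict and generation is stubbed with an echo, purely to illustrate the index mapping.

```python
# Illustrative sketch only: names mirror the diff, but the cache is a plain
# dict and "generation" just echoes the prompt.
from typing import Dict, List


def fill_from_cache(prompts: List[str], cache: Dict[str, str]) -> List[str]:
    existing: Dict[int, str] = {}
    missing_prompts: List[str] = []
    missing_prompt_idxs: List[int] = []

    # Partition prompts into cache hits (kept with their original index)
    # and misses (whose original positions are recorded).
    for i, prompt in enumerate(prompts):
        if prompt in cache:
            existing[i] = cache[prompt]
        else:
            missing_prompts.append(prompt)
            missing_prompt_idxs.append(i)

    # Stand-in for self._generate: one result per missing prompt.
    new_results = [f"gen({p})" for p in missing_prompts]

    for i, result in enumerate(new_results):
        # The bug wrote existing[i]; here i indexes new_results, not prompts,
        # so cached entries at positions 0..len(new_results)-1 were clobbered.
        existing[missing_prompt_idxs[i]] = result
        cache[missing_prompts[i]] = result

    return [existing[i] for i in range(len(prompts))]


if __name__ == "__main__":
    print(fill_from_cache(["foo", "bar", "foo"], {"foo": "fizz"}))
    # -> ['fizz', 'gen(bar)', 'fizz'], mirroring the expected_generations
    #    asserted in the new unit test above.
```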