From 365669a7fd21dcfc9a6cae2c8b24fc63f5499054 Mon Sep 17 00:00:00 2001
From: Harrison Chase
Date: Mon, 27 Mar 2023 23:10:46 -0700
Subject: [PATCH] Harrison/fix save context (#2082)

Co-authored-by: Saurabh Misra
---
 langchain/memory/entity.py | 24 +++++++++++++-----------
 1 file changed, 13 insertions(+), 11 deletions(-)

diff --git a/langchain/memory/entity.py b/langchain/memory/entity.py
index 73f0bc15..779a5e1e 100644
--- a/langchain/memory/entity.py
+++ b/langchain/memory/entity.py
@@ -74,25 +74,27 @@ class ConversationEntityMemory(BaseChatMemory, BaseModel):
     def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
         """Save context from this conversation to buffer."""
         super().save_context(inputs, outputs)
+
         if self.input_key is None:
             prompt_input_key = get_prompt_input_key(inputs, self.memory_variables)
         else:
             prompt_input_key = self.input_key
-        for entity in self.entity_cache:
-            chain = LLMChain(llm=self.llm, prompt=self.entity_summarization_prompt)
-            # key value store for entity
-            existing_summary = self.store.get(entity, "")
-            buffer_string = get_buffer_string(
-                self.buffer[-self.k * 2 :],
-                human_prefix=self.human_prefix,
-                ai_prefix=self.ai_prefix,
-            )
+        buffer_string = get_buffer_string(
+            self.buffer[-self.k * 2 :],
+            human_prefix=self.human_prefix,
+            ai_prefix=self.ai_prefix,
+        )
+        input_data = inputs[prompt_input_key]
+        chain = LLMChain(llm=self.llm, prompt=self.entity_summarization_prompt)
+
+        for entity in self.entity_cache:
+            existing_summary = self.store.get(entity, "")
             output = chain.predict(
                 summary=existing_summary,
-                history=buffer_string,
-                input=inputs[prompt_input_key],
                 entity=entity,
+                history=buffer_string,
+                input=input_data,
             )
             self.store[entity] = output.strip()
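
The substantive change in this patch is that `ConversationEntityMemory.save_context` no longer rebuilds loop-invariant values on every iteration: the recent-history string from `get_buffer_string`, the `LLMChain`, and the resolved `inputs[prompt_input_key]` are now computed once before the per-entity loop. Below is a minimal, self-contained sketch of the same loop-invariant-hoisting pattern, with a hypothetical `summarize` callable standing in for `chain.predict` (names other than the pattern itself are assumptions for illustration, not langchain APIs):

```python
from typing import Dict, List


# Hypothetical stand-in for chain.predict: assume each call is an
# expensive LLM round-trip, so inputs that do not vary across entities
# should be prepared only once.
def summarize(summary: str, entity: str, history: str, input: str) -> str:
    joined = f"{summary}; " if summary else ""
    return f"{joined}{entity} mentioned in: {input}"


def save_context(
    store: Dict[str, str],
    entity_cache: List[str],
    buffer: List[str],
    input_data: str,
    k: int = 3,
) -> None:
    # Loop-invariant: the history window is identical for every entity,
    # so it is built once before the loop, mirroring how the patch moves
    # get_buffer_string() and the LLMChain construction out of the loop.
    history = "\n".join(buffer[-k * 2 :])
    for entity in entity_cache:
        existing_summary = store.get(entity, "")
        # Keyword arguments keep each value bound to the matching
        # prompt variable regardless of ordering.
        store[entity] = summarize(
            summary=existing_summary,
            entity=entity,
            history=history,
            input=input_data,
        ).strip()


store: Dict[str, str] = {}
save_context(
    store,
    entity_cache=["Alice"],
    buffer=["Human: hi, I'm Alice", "AI: hello Alice!"],
    input_data="hi, I'm Alice",
)
print(store["Alice"])  # -> "Alice mentioned in: hi, I'm Alice"
```

Note that `chain.predict` already took keyword arguments in the old code, so the reordering of `entity`, `history`, and `input` in the patch is cosmetic; the performance-relevant part is hoisting the buffer string, the chain construction, and the input lookup out of the loop.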