Harrison/fix save context (#2082)

Co-authored-by: Saurabh Misra <misra.saurabh1@gmail.com>
This commit is contained in:
Harrison Chase 2023-03-27 23:10:46 -07:00 committed by GitHub
parent b7f392fdd6
commit 365669a7fd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -74,25 +74,27 @@ class ConversationEntityMemory(BaseChatMemory, BaseModel):
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None: def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""Save context from this conversation to buffer.""" """Save context from this conversation to buffer."""
super().save_context(inputs, outputs) super().save_context(inputs, outputs)
if self.input_key is None: if self.input_key is None:
prompt_input_key = get_prompt_input_key(inputs, self.memory_variables) prompt_input_key = get_prompt_input_key(inputs, self.memory_variables)
else: else:
prompt_input_key = self.input_key prompt_input_key = self.input_key
for entity in self.entity_cache:
chain = LLMChain(llm=self.llm, prompt=self.entity_summarization_prompt)
# key value store for entity
existing_summary = self.store.get(entity, "")
buffer_string = get_buffer_string(
self.buffer[-self.k * 2 :],
human_prefix=self.human_prefix,
ai_prefix=self.ai_prefix,
)
buffer_string = get_buffer_string(
self.buffer[-self.k * 2 :],
human_prefix=self.human_prefix,
ai_prefix=self.ai_prefix,
)
input_data = inputs[prompt_input_key]
chain = LLMChain(llm=self.llm, prompt=self.entity_summarization_prompt)
for entity in self.entity_cache:
existing_summary = self.store.get(entity, "")
output = chain.predict( output = chain.predict(
summary=existing_summary, summary=existing_summary,
history=buffer_string,
input=inputs[prompt_input_key],
entity=entity, entity=entity,
history=buffer_string,
input=input_data,
) )
self.store[entity] = output.strip() self.store[entity] = output.strip()