From 364f8e7b5d3fc6db1fd58626d95eef3abed9ae5b Mon Sep 17 00:00:00 2001 From: Luke Stanley <306671+lukestanley@users.noreply.github.com> Date: Sat, 17 Jun 2023 01:08:44 +0000 Subject: [PATCH] Better Entity Memory code documentation (#6318) Just adds some comments and docstring improvements. There was some behaviour that was quite unclear to me at first, like: - "when do things get updated?" - "why are there only entity names and no summaries?" - "why do the entity names disappear?" Now it can be much more obvious to many. I am lukestanley on Twitter. --- langchain/memory/entity.py | 68 ++++++++++++++++++++++++++++++++++++-- 1 file changed, 65 insertions(+), 3 deletions(-) diff --git a/langchain/memory/entity.py b/langchain/memory/entity.py index 27f903b18f..759da031e8 100644 --- a/langchain/memory/entity.py +++ b/langchain/memory/entity.py @@ -241,20 +241,35 @@ class SQLiteEntityStore(BaseEntityStore): class ConversationEntityMemory(BaseChatMemory): - """Entity extractor & summarizer to memory.""" + """Entity extractor & summarizer memory. + + Extracts named entities from the recent chat history and generates summaries. + With a swappable entity store, it can persist entities across conversations. + Defaults to an in-memory entity store, and can be swapped out for a Redis, + SQLite, or other entity store.
+ """ human_prefix: str = "Human" ai_prefix: str = "AI" llm: BaseLanguageModel entity_extraction_prompt: BasePromptTemplate = ENTITY_EXTRACTION_PROMPT entity_summarization_prompt: BasePromptTemplate = ENTITY_SUMMARIZATION_PROMPT + + # Cache of recently detected entity names, if any + # It is updated when load_memory_variables is called: entity_cache: List[str] = [] + + # Number of recent message pairs to consider when updating entities: k: int = 3 + chat_history_key: str = "history" + + # Store to manage entity-related data: entity_store: BaseEntityStore = Field(default_factory=InMemoryEntityStore) @property def buffer(self) -> List[BaseMessage]: + """Access chat memory messages.""" return self.chat_memory.messages @property @@ -266,40 +281,78 @@ class ConversationEntityMemory(BaseChatMemory): return ["entities", self.chat_history_key] def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]: - """Return history buffer.""" + """ + Returns chat history and all generated entities with summaries if available, + and updates or clears the recent entity cache. + + New entity names can be found when calling this method, before the entity + summaries are generated, so the entity cache values may be empty if no entity + descriptions are generated yet. + """ + + # Create an LLMChain for predicting entity names from the recent chat history: chain = LLMChain(llm=self.llm, prompt=self.entity_extraction_prompt) + if self.input_key is None: prompt_input_key = get_prompt_input_key(inputs, self.memory_variables) else: prompt_input_key = self.input_key + + # Extract an arbitrary window of the last message pairs from + # the chat history, where the hyperparameter k is the + # number of message pairs: buffer_string = get_buffer_string( self.buffer[-self.k * 2 :], human_prefix=self.human_prefix, ai_prefix=self.ai_prefix, ) + + # Generates a comma-separated list of named entities, + # e.g.
"Jane, White House, UFO" + # or "NONE" if no named entities are extracted: output = chain.predict( history=buffer_string, input=inputs[prompt_input_key], ) + + # If no named entities are extracted, assigns an empty list. if output.strip() == "NONE": entities = [] else: + # Make a list of the extracted entities: entities = [w.strip() for w in output.split(",")] + + # Make a dictionary of entities with their summaries, if they exist: entity_summaries = {} + for entity in entities: entity_summaries[entity] = self.entity_store.get(entity, "") + + # Replaces the entity name cache with the most recently discussed entities, + # or if no entities were extracted, clears the cache: self.entity_cache = entities + + # Should we return as message objects or as a string? if self.return_messages: + # Get the last `k` pairs of chat messages: buffer: Any = self.buffer[-self.k * 2 :] else: + # Reuse the string we made earlier: buffer = buffer_string + return { self.chat_history_key: buffer, "entities": entity_summaries, } def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None: - """Save context from this conversation to buffer.""" + """ + Save context from this conversation history to the entity store. + + Generates a summary for each entity in the entity cache by prompting + the model, and saves these summaries to the entity store.
+ """ + super().save_context(inputs, outputs) if self.input_key is None: @@ -307,15 +360,23 @@ class ConversationEntityMemory(BaseChatMemory): else: prompt_input_key = self.input_key + # Extract an arbitrary window of the last message pairs from + # the chat history, where the hyperparameter k is the + # number of message pairs: buffer_string = get_buffer_string( self.buffer[-self.k * 2 :], human_prefix=self.human_prefix, ai_prefix=self.ai_prefix, ) + input_data = inputs[prompt_input_key] + + # Create an LLMChain for predicting entity summarization from the context: chain = LLMChain(llm=self.llm, prompt=self.entity_summarization_prompt) + # Generate new summaries for entities and save them in the entity store for entity in self.entity_cache: + # Get existing summary if it exists existing_summary = self.entity_store.get(entity, "") output = chain.predict( summary=existing_summary, @@ -323,6 +384,7 @@ class ConversationEntityMemory(BaseChatMemory): history=buffer_string, input=input_data, ) + # Save the updated summary to the entity store self.entity_store.set(entity, output.strip()) def clear(self) -> None: