Harrison/summary message rol (#1783)

Co-authored-by: Aratako <127325395+Aratako@users.noreply.github.com>
This commit is contained in:
Harrison Chase 2023-03-19 10:09:18 -07:00 committed by GitHub
parent 85e4dd7fc3
commit 951c158106
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 18 additions and 7 deletions

View File

@ -1,4 +1,4 @@
from typing import Any, Dict, List from typing import Any, Dict, List, Type, Union
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@ -12,7 +12,12 @@ from langchain.memory.prompt import (
) )
from langchain.memory.utils import get_prompt_input_key from langchain.memory.utils import get_prompt_input_key
from langchain.prompts.base import BasePromptTemplate from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel, SystemMessage, get_buffer_string from langchain.schema import (
BaseLanguageModel,
BaseMessage,
SystemMessage,
get_buffer_string,
)
class ConversationKGMemory(BaseChatMemory, BaseModel): class ConversationKGMemory(BaseChatMemory, BaseModel):
@ -29,6 +34,7 @@ class ConversationKGMemory(BaseChatMemory, BaseModel):
knowledge_extraction_prompt: BasePromptTemplate = KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT knowledge_extraction_prompt: BasePromptTemplate = KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT
entity_extraction_prompt: BasePromptTemplate = ENTITY_EXTRACTION_PROMPT entity_extraction_prompt: BasePromptTemplate = ENTITY_EXTRACTION_PROMPT
llm: BaseLanguageModel llm: BaseLanguageModel
summary_message_cls: Type[BaseMessage] = SystemMessage
"""Number of previous utterances to include in the context.""" """Number of previous utterances to include in the context."""
memory_key: str = "history" #: :meta private: memory_key: str = "history" #: :meta private:
@ -40,12 +46,15 @@ class ConversationKGMemory(BaseChatMemory, BaseModel):
knowledge = self.kg.get_entity_knowledge(entity) knowledge = self.kg.get_entity_knowledge(entity)
if knowledge: if knowledge:
summaries[entity] = ". ".join(knowledge) + "." summaries[entity] = ". ".join(knowledge) + "."
context: Union[str, List]
if summaries: if summaries:
summary_strings = [ summary_strings = [
f"On {entity}: {summary}" for entity, summary in summaries.items() f"On {entity}: {summary}" for entity, summary in summaries.items()
] ]
if self.return_messages: if self.return_messages:
context: Any = [SystemMessage(content=text) for text in summary_strings] context = [
self.summary_message_cls(content=text) for text in summary_strings
]
else: else:
context = "\n".join(summary_strings) context = "\n".join(summary_strings)
else: else:

View File

@ -1,4 +1,4 @@
from typing import Any, Dict, List from typing import Any, Dict, List, Type
from pydantic import BaseModel, root_validator from pydantic import BaseModel, root_validator
@ -19,6 +19,7 @@ class SummarizerMixin(BaseModel):
ai_prefix: str = "AI" ai_prefix: str = "AI"
llm: BaseLanguageModel llm: BaseLanguageModel
prompt: BasePromptTemplate = SUMMARY_PROMPT prompt: BasePromptTemplate = SUMMARY_PROMPT
summary_message_cls: Type[BaseMessage] = SystemMessage
def predict_new_summary( def predict_new_summary(
self, messages: List[BaseMessage], existing_summary: str self, messages: List[BaseMessage], existing_summary: str
@ -50,7 +51,7 @@ class ConversationSummaryMemory(BaseChatMemory, SummarizerMixin, BaseModel):
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]: def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""Return history buffer.""" """Return history buffer."""
if self.return_messages: if self.return_messages:
buffer: Any = [SystemMessage(content=self.buffer)] buffer: Any = [self.summary_message_cls(content=self.buffer)]
else: else:
buffer = self.buffer buffer = self.buffer
return {self.memory_key: buffer} return {self.memory_key: buffer}

View File

@ -4,7 +4,7 @@ from pydantic import BaseModel, root_validator
from langchain.memory.chat_memory import BaseChatMemory from langchain.memory.chat_memory import BaseChatMemory
from langchain.memory.summary import SummarizerMixin from langchain.memory.summary import SummarizerMixin
from langchain.schema import BaseMessage, SystemMessage, get_buffer_string from langchain.schema import BaseMessage, get_buffer_string
class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin, BaseModel): class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin, BaseModel):
@ -12,6 +12,7 @@ class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin, BaseModel
max_token_limit: int = 2000 max_token_limit: int = 2000
moving_summary_buffer: str = "" moving_summary_buffer: str = ""
summary_message_role: str = "system"
memory_key: str = "history" memory_key: str = "history"
@property @property
@ -31,7 +32,7 @@ class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin, BaseModel
buffer = self.buffer buffer = self.buffer
if self.moving_summary_buffer != "": if self.moving_summary_buffer != "":
first_messages: List[BaseMessage] = [ first_messages: List[BaseMessage] = [
SystemMessage(content=self.moving_summary_buffer) self.summary_message_cls(content=self.moving_summary_buffer)
] ]
buffer = first_messages + buffer buffer = first_messages + buffer
if self.return_messages: if self.return_messages: