diff --git a/langchain/memory/kg.py b/langchain/memory/kg.py index 7d23f60b..b0aaa4c0 100644 --- a/langchain/memory/kg.py +++ b/langchain/memory/kg.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List +from typing import Any, Dict, List, Type, Union from pydantic import BaseModel, Field @@ -12,7 +12,12 @@ from langchain.memory.prompt import ( ) from langchain.memory.utils import get_prompt_input_key from langchain.prompts.base import BasePromptTemplate -from langchain.schema import BaseLanguageModel, SystemMessage, get_buffer_string +from langchain.schema import ( + BaseLanguageModel, + BaseMessage, + SystemMessage, + get_buffer_string, +) class ConversationKGMemory(BaseChatMemory, BaseModel): @@ -29,6 +34,7 @@ class ConversationKGMemory(BaseChatMemory, BaseModel): knowledge_extraction_prompt: BasePromptTemplate = KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT entity_extraction_prompt: BasePromptTemplate = ENTITY_EXTRACTION_PROMPT llm: BaseLanguageModel + summary_message_cls: Type[BaseMessage] = SystemMessage """Number of previous utterances to include in the context.""" memory_key: str = "history" #: :meta private: @@ -40,12 +46,15 @@ class ConversationKGMemory(BaseChatMemory, BaseModel): knowledge = self.kg.get_entity_knowledge(entity) if knowledge: summaries[entity] = ". ".join(knowledge) + "." + context: Union[str, List] if summaries: summary_strings = [ f"On {entity}: {summary}" for entity, summary in summaries.items() ] if self.return_messages: - context: Any = [SystemMessage(content=text) for text in summary_strings] + context = [ + self.summary_message_cls(content=text) for text in summary_strings + ] else: context = "\n".join(summary_strings) else: diff --git a/langchain/memory/summary.py b/langchain/memory/summary.py index 082e3a43..4f9b27e1 100644 --- a/langchain/memory/summary.py +++ b/langchain/memory/summary.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List +from typing import Any, Dict, List, Type from pydantic import BaseModel, root_validator @@ -19,6 +19,7 @@ class SummarizerMixin(BaseModel): ai_prefix: str = "AI" llm: BaseLanguageModel prompt: BasePromptTemplate = SUMMARY_PROMPT + summary_message_cls: Type[BaseMessage] = SystemMessage def predict_new_summary( self, messages: List[BaseMessage], existing_summary: str @@ -50,7 +51,7 @@ class ConversationSummaryMemory(BaseChatMemory, SummarizerMixin, BaseModel): def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]: """Return history buffer.""" if self.return_messages: - buffer: Any = [SystemMessage(content=self.buffer)] + buffer: Any = [self.summary_message_cls(content=self.buffer)] else: buffer = self.buffer return {self.memory_key: buffer} diff --git a/langchain/memory/summary_buffer.py b/langchain/memory/summary_buffer.py index 9f0f5c6b..ae29f38a 100644 --- a/langchain/memory/summary_buffer.py +++ b/langchain/memory/summary_buffer.py @@ -4,7 +4,7 @@ from pydantic import BaseModel, root_validator from langchain.memory.chat_memory import BaseChatMemory from langchain.memory.summary import SummarizerMixin -from langchain.schema import BaseMessage, SystemMessage, get_buffer_string +from langchain.schema import BaseMessage, get_buffer_string class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin, BaseModel): @@ -12,6 +12,7 @@ class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin, BaseModel max_token_limit: int = 2000 moving_summary_buffer: str = "" + summary_message_role: str = "system" memory_key: str = "history" @property @@ -31,7 +32,7 @@ class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin, BaseModel buffer = self.buffer if self.moving_summary_buffer != "": first_messages: List[BaseMessage] = [ - SystemMessage(content=self.moving_summary_buffer) + self.summary_message_cls(content=self.moving_summary_buffer) ] buffer = first_messages + buffer if self.return_messages: