forked from Archives/langchain
Harrison/summary message rol (#1783)
Co-authored-by: Aratako <127325395+Aratako@users.noreply.github.com>
This commit is contained in:
parent
85e4dd7fc3
commit
951c158106
@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any, Dict, List, Type, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@ -12,7 +12,12 @@ from langchain.memory.prompt import (
|
||||
)
|
||||
from langchain.memory.utils import get_prompt_input_key
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema import BaseLanguageModel, SystemMessage, get_buffer_string
|
||||
from langchain.schema import (
|
||||
BaseLanguageModel,
|
||||
BaseMessage,
|
||||
SystemMessage,
|
||||
get_buffer_string,
|
||||
)
|
||||
|
||||
|
||||
class ConversationKGMemory(BaseChatMemory, BaseModel):
|
||||
@ -29,6 +34,7 @@ class ConversationKGMemory(BaseChatMemory, BaseModel):
|
||||
knowledge_extraction_prompt: BasePromptTemplate = KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT
|
||||
entity_extraction_prompt: BasePromptTemplate = ENTITY_EXTRACTION_PROMPT
|
||||
llm: BaseLanguageModel
|
||||
summary_message_cls: Type[BaseMessage] = SystemMessage
|
||||
"""Number of previous utterances to include in the context."""
|
||||
memory_key: str = "history" #: :meta private:
|
||||
|
||||
@ -40,12 +46,15 @@ class ConversationKGMemory(BaseChatMemory, BaseModel):
|
||||
knowledge = self.kg.get_entity_knowledge(entity)
|
||||
if knowledge:
|
||||
summaries[entity] = ". ".join(knowledge) + "."
|
||||
context: Union[str, List]
|
||||
if summaries:
|
||||
summary_strings = [
|
||||
f"On {entity}: {summary}" for entity, summary in summaries.items()
|
||||
]
|
||||
if self.return_messages:
|
||||
context: Any = [SystemMessage(content=text) for text in summary_strings]
|
||||
context = [
|
||||
self.summary_message_cls(content=text) for text in summary_strings
|
||||
]
|
||||
else:
|
||||
context = "\n".join(summary_strings)
|
||||
else:
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any, Dict, List, Type
|
||||
|
||||
from pydantic import BaseModel, root_validator
|
||||
|
||||
@ -19,6 +19,7 @@ class SummarizerMixin(BaseModel):
|
||||
ai_prefix: str = "AI"
|
||||
llm: BaseLanguageModel
|
||||
prompt: BasePromptTemplate = SUMMARY_PROMPT
|
||||
summary_message_cls: Type[BaseMessage] = SystemMessage
|
||||
|
||||
def predict_new_summary(
|
||||
self, messages: List[BaseMessage], existing_summary: str
|
||||
@ -50,7 +51,7 @@ class ConversationSummaryMemory(BaseChatMemory, SummarizerMixin, BaseModel):
|
||||
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Return history buffer."""
|
||||
if self.return_messages:
|
||||
buffer: Any = [SystemMessage(content=self.buffer)]
|
||||
buffer: Any = [self.summary_message_cls(content=self.buffer)]
|
||||
else:
|
||||
buffer = self.buffer
|
||||
return {self.memory_key: buffer}
|
||||
|
@ -4,7 +4,7 @@ from pydantic import BaseModel, root_validator
|
||||
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
from langchain.memory.summary import SummarizerMixin
|
||||
from langchain.schema import BaseMessage, SystemMessage, get_buffer_string
|
||||
from langchain.schema import BaseMessage, get_buffer_string
|
||||
|
||||
|
||||
class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin, BaseModel):
|
||||
@ -12,6 +12,7 @@ class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin, BaseModel
|
||||
|
||||
max_token_limit: int = 2000
|
||||
moving_summary_buffer: str = ""
|
||||
summary_message_role: str = "system"
|
||||
memory_key: str = "history"
|
||||
|
||||
@property
|
||||
@ -31,7 +32,7 @@ class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin, BaseModel
|
||||
buffer = self.buffer
|
||||
if self.moving_summary_buffer != "":
|
||||
first_messages: List[BaseMessage] = [
|
||||
SystemMessage(content=self.moving_summary_buffer)
|
||||
self.summary_message_cls(content=self.moving_summary_buffer)
|
||||
]
|
||||
buffer = first_messages + buffer
|
||||
if self.return_messages:
|
||||
|
Loading…
Reference in New Issue
Block a user