forked from Archives/langchain
Harrison/summary message rol (#1783)
Co-authored-by: Aratako <127325395+Aratako@users.noreply.github.com>
This commit is contained in:
parent
85e4dd7fc3
commit
951c158106
@ -1,4 +1,4 @@
|
|||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List, Type, Union
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
@ -12,7 +12,12 @@ from langchain.memory.prompt import (
|
|||||||
)
|
)
|
||||||
from langchain.memory.utils import get_prompt_input_key
|
from langchain.memory.utils import get_prompt_input_key
|
||||||
from langchain.prompts.base import BasePromptTemplate
|
from langchain.prompts.base import BasePromptTemplate
|
||||||
from langchain.schema import BaseLanguageModel, SystemMessage, get_buffer_string
|
from langchain.schema import (
|
||||||
|
BaseLanguageModel,
|
||||||
|
BaseMessage,
|
||||||
|
SystemMessage,
|
||||||
|
get_buffer_string,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ConversationKGMemory(BaseChatMemory, BaseModel):
|
class ConversationKGMemory(BaseChatMemory, BaseModel):
|
||||||
@ -29,6 +34,7 @@ class ConversationKGMemory(BaseChatMemory, BaseModel):
|
|||||||
knowledge_extraction_prompt: BasePromptTemplate = KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT
|
knowledge_extraction_prompt: BasePromptTemplate = KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT
|
||||||
entity_extraction_prompt: BasePromptTemplate = ENTITY_EXTRACTION_PROMPT
|
entity_extraction_prompt: BasePromptTemplate = ENTITY_EXTRACTION_PROMPT
|
||||||
llm: BaseLanguageModel
|
llm: BaseLanguageModel
|
||||||
|
summary_message_cls: Type[BaseMessage] = SystemMessage
|
||||||
"""Number of previous utterances to include in the context."""
|
"""Number of previous utterances to include in the context."""
|
||||||
memory_key: str = "history" #: :meta private:
|
memory_key: str = "history" #: :meta private:
|
||||||
|
|
||||||
@ -40,12 +46,15 @@ class ConversationKGMemory(BaseChatMemory, BaseModel):
|
|||||||
knowledge = self.kg.get_entity_knowledge(entity)
|
knowledge = self.kg.get_entity_knowledge(entity)
|
||||||
if knowledge:
|
if knowledge:
|
||||||
summaries[entity] = ". ".join(knowledge) + "."
|
summaries[entity] = ". ".join(knowledge) + "."
|
||||||
|
context: Union[str, List]
|
||||||
if summaries:
|
if summaries:
|
||||||
summary_strings = [
|
summary_strings = [
|
||||||
f"On {entity}: {summary}" for entity, summary in summaries.items()
|
f"On {entity}: {summary}" for entity, summary in summaries.items()
|
||||||
]
|
]
|
||||||
if self.return_messages:
|
if self.return_messages:
|
||||||
context: Any = [SystemMessage(content=text) for text in summary_strings]
|
context = [
|
||||||
|
self.summary_message_cls(content=text) for text in summary_strings
|
||||||
|
]
|
||||||
else:
|
else:
|
||||||
context = "\n".join(summary_strings)
|
context = "\n".join(summary_strings)
|
||||||
else:
|
else:
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List, Type
|
||||||
|
|
||||||
from pydantic import BaseModel, root_validator
|
from pydantic import BaseModel, root_validator
|
||||||
|
|
||||||
@ -19,6 +19,7 @@ class SummarizerMixin(BaseModel):
|
|||||||
ai_prefix: str = "AI"
|
ai_prefix: str = "AI"
|
||||||
llm: BaseLanguageModel
|
llm: BaseLanguageModel
|
||||||
prompt: BasePromptTemplate = SUMMARY_PROMPT
|
prompt: BasePromptTemplate = SUMMARY_PROMPT
|
||||||
|
summary_message_cls: Type[BaseMessage] = SystemMessage
|
||||||
|
|
||||||
def predict_new_summary(
|
def predict_new_summary(
|
||||||
self, messages: List[BaseMessage], existing_summary: str
|
self, messages: List[BaseMessage], existing_summary: str
|
||||||
@ -50,7 +51,7 @@ class ConversationSummaryMemory(BaseChatMemory, SummarizerMixin, BaseModel):
|
|||||||
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
"""Return history buffer."""
|
"""Return history buffer."""
|
||||||
if self.return_messages:
|
if self.return_messages:
|
||||||
buffer: Any = [SystemMessage(content=self.buffer)]
|
buffer: Any = [self.summary_message_cls(content=self.buffer)]
|
||||||
else:
|
else:
|
||||||
buffer = self.buffer
|
buffer = self.buffer
|
||||||
return {self.memory_key: buffer}
|
return {self.memory_key: buffer}
|
||||||
|
@ -4,7 +4,7 @@ from pydantic import BaseModel, root_validator
|
|||||||
|
|
||||||
from langchain.memory.chat_memory import BaseChatMemory
|
from langchain.memory.chat_memory import BaseChatMemory
|
||||||
from langchain.memory.summary import SummarizerMixin
|
from langchain.memory.summary import SummarizerMixin
|
||||||
from langchain.schema import BaseMessage, SystemMessage, get_buffer_string
|
from langchain.schema import BaseMessage, get_buffer_string
|
||||||
|
|
||||||
|
|
||||||
class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin, BaseModel):
|
class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin, BaseModel):
|
||||||
@ -12,6 +12,7 @@ class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin, BaseModel
|
|||||||
|
|
||||||
max_token_limit: int = 2000
|
max_token_limit: int = 2000
|
||||||
moving_summary_buffer: str = ""
|
moving_summary_buffer: str = ""
|
||||||
|
summary_message_role: str = "system"
|
||||||
memory_key: str = "history"
|
memory_key: str = "history"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -31,7 +32,7 @@ class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin, BaseModel
|
|||||||
buffer = self.buffer
|
buffer = self.buffer
|
||||||
if self.moving_summary_buffer != "":
|
if self.moving_summary_buffer != "":
|
||||||
first_messages: List[BaseMessage] = [
|
first_messages: List[BaseMessage] = [
|
||||||
SystemMessage(content=self.moving_summary_buffer)
|
self.summary_message_cls(content=self.moving_summary_buffer)
|
||||||
]
|
]
|
||||||
buffer = first_messages + buffer
|
buffer = first_messages + buffer
|
||||||
if self.return_messages:
|
if self.return_messages:
|
||||||
|
Loading…
Reference in New Issue
Block a user