|
|
|
@ -1,3 +1,5 @@
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
|
|
from typing import Any, Dict, List, Type
|
|
|
|
|
|
|
|
|
|
from pydantic import BaseModel, root_validator
|
|
|
|
@ -8,6 +10,7 @@ from langchain.memory.chat_memory import BaseChatMemory
|
|
|
|
|
from langchain.memory.prompt import SUMMARY_PROMPT
|
|
|
|
|
from langchain.prompts.base import BasePromptTemplate
|
|
|
|
|
from langchain.schema import (
|
|
|
|
|
BaseChatMessageHistory,
|
|
|
|
|
BaseMessage,
|
|
|
|
|
SystemMessage,
|
|
|
|
|
get_buffer_string,
|
|
|
|
@ -40,6 +43,22 @@ class ConversationSummaryMemory(BaseChatMemory, SummarizerMixin):
|
|
|
|
|
buffer: str = ""
|
|
|
|
|
memory_key: str = "history" #: :meta private:
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def from_messages(
|
|
|
|
|
cls,
|
|
|
|
|
llm: BaseLanguageModel,
|
|
|
|
|
chat_memory: BaseChatMessageHistory,
|
|
|
|
|
*,
|
|
|
|
|
summarize_step: int = 2,
|
|
|
|
|
**kwargs: Any,
|
|
|
|
|
) -> ConversationSummaryMemory:
|
|
|
|
|
obj = cls(llm=llm, chat_memory=chat_memory, **kwargs)
|
|
|
|
|
for i in range(0, len(obj.chat_memory.messages), summarize_step):
|
|
|
|
|
obj.buffer = obj.predict_new_summary(
|
|
|
|
|
obj.chat_memory.messages[i : i + summarize_step], obj.buffer
|
|
|
|
|
)
|
|
|
|
|
return obj
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def memory_variables(self) -> List[str]:
|
|
|
|
|
"""Will always return list of memory variables.
|
|
|
|
|