|
|
|
@ -4,8 +4,7 @@ from pydantic import BaseModel, root_validator
|
|
|
|
|
|
|
|
|
|
from langchain.memory.chat_memory import BaseChatMemory
|
|
|
|
|
from langchain.memory.summary import SummarizerMixin
|
|
|
|
|
from langchain.memory.utils import get_buffer_string
|
|
|
|
|
from langchain.schema import BaseMessage, SystemMessage
|
|
|
|
|
from langchain.schema import BaseMessage, SystemMessage, get_buffer_string
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin, BaseModel):
|
|
|
|
@ -55,21 +54,17 @@ class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin, BaseModel
|
|
|
|
|
)
|
|
|
|
|
return values
|
|
|
|
|
|
|
|
|
|
def get_num_tokens_list(self, arr: List[BaseMessage]) -> List[int]:
    """Return the token count of every message in ``arr``, in input order.

    Each message is rendered on its own with ``get_buffer_string`` and the
    resulting text is measured with ``self.llm.get_num_tokens``.
    """
    token_counts = []
    for message in arr:
        rendered = get_buffer_string([message])
        token_counts.append(self.llm.get_num_tokens(rendered))
    return token_counts
|
|
|
|
|
|
|
|
|
|
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
    """Save context from this conversation to buffer.

    The exchange is first stored through the parent ``save_context``. If the
    LLM's token count for the stored messages then exceeds
    ``self.max_token_limit``, the oldest messages are popped from the front
    of the history until the count fits, and the popped messages are folded
    into ``self.moving_summary_buffer`` via ``self.predict_new_summary``.

    Args:
        inputs: Inputs of the exchange, forwarded to the parent class.
        outputs: Outputs of the exchange, forwarded to the parent class.
    """
    super().save_context(inputs, outputs)
    # Prune buffer if it exceeds max token limit
    buffer = self.chat_memory.messages
    # Count tokens once, on the messages themselves; a preliminary per-message
    # string count would be overwritten immediately and only cost extra
    # tokenization work.
    curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
    if curr_buffer_length > self.max_token_limit:
        pruned_memory = []
        # Stop when the history is exhausted as well: if the reported count
        # for an empty history still exceeds the limit, pop(0) would raise
        # IndexError instead of finishing the prune.
        while buffer and curr_buffer_length > self.max_token_limit:
            pruned_memory.append(buffer.pop(0))
            curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
        self.moving_summary_buffer = self.predict_new_summary(
            pruned_memory, self.moving_summary_buffer
        )
|
|
|
|
|