forked from Archives/langchain
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
81 lines
3.1 KiB
Python
81 lines
3.1 KiB
Python
from typing import Any, Dict, List
|
|
|
|
from pydantic import BaseModel, root_validator
|
|
|
|
from langchain.memory.chat_memory import BaseChatMemory
|
|
from langchain.memory.summary import SummarizerMixin
|
|
from langchain.memory.utils import get_buffer_string
|
|
from langchain.schema import BaseMessage, SystemMessage
|
|
|
|
|
|
class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin, BaseModel):
    """Buffer with summarizer for storing conversation memory.

    Recent messages are kept verbatim in ``chat_memory``. Once their combined
    token count exceeds ``max_token_limit``, the oldest messages are removed
    from the buffer and folded into ``moving_summary_buffer`` via
    ``predict_new_summary``.
    """

    # Token budget for the verbatim message buffer; exceeding it triggers pruning.
    max_token_limit: int = 2000
    # Running summary of the messages that have been pruned from the buffer.
    moving_summary_buffer: str = ""
    # Key under which load_memory_variables returns the history.
    memory_key: str = "history"

    @property
    def buffer(self) -> List[BaseMessage]:
        """Messages currently held verbatim (not yet summarized)."""
        return self.chat_memory.messages

    @property
    def memory_variables(self) -> List[str]:
        """Will always return list of memory variables.

        :meta private:
        """
        return [self.memory_key]

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Return history buffer.

        If a summary exists, it is prepended as a ``SystemMessage``. The history
        is returned as the message list when ``return_messages`` is set.
        Otherwise it is a single string rendered with the configured human/AI
        prefixes.

        Args:
            inputs: Not used by this method.

        Returns:
            ``{memory_key: history}``.
        """
        buffer = self.buffer
        if self.moving_summary_buffer:
            # Concatenation builds a new list, so the stored chat history is
            # not mutated.
            first_messages: List[BaseMessage] = [
                SystemMessage(content=self.moving_summary_buffer)
            ]
            buffer = first_messages + buffer
        if self.return_messages:
            final_buffer: Any = buffer
        else:
            final_buffer = get_buffer_string(
                buffer, human_prefix=self.human_prefix, ai_prefix=self.ai_prefix
            )
        return {self.memory_key: final_buffer}

    @root_validator()
    def validate_prompt_input_variables(cls, values: Dict) -> Dict:
        """Validate that prompt input variables are consistent.

        The summarization prompt must take exactly ``summary`` and
        ``new_lines`` as input variables.

        Raises:
            ValueError: If the prompt's input variables differ from that set.
        """
        prompt = values.get("prompt")
        if prompt is None:
            # NOTE(review): presumably ``prompt`` is absent only when its own
            # field validation failed. Returning lets pydantic report that
            # error instead of surfacing a bare KeyError from here.
            return values
        prompt_variables = prompt.input_variables
        expected_keys = {"summary", "new_lines"}
        if expected_keys != set(prompt_variables):
            raise ValueError(
                "Got unexpected prompt input variables. The prompt expects "
                f"{prompt_variables}, but it should have {expected_keys}."
            )
        return values

    def get_num_tokens_list(self, arr: List[BaseMessage]) -> List[int]:
        """Get list of number of tokens in each string in the input array.

        Each message is rendered with the same human/AI prefixes that
        ``load_memory_variables`` uses. This way the token budget is measured
        on the text that is actually emitted.
        """
        return [
            self.llm.get_num_tokens(
                get_buffer_string(
                    [message],
                    human_prefix=self.human_prefix,
                    ai_prefix=self.ai_prefix,
                )
            )
            for message in arr
        ]

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
        """Save context from this conversation to buffer.

        After the base class records the new turn, the oldest messages are
        removed from the buffer until its token count fits within
        ``max_token_limit``. The removed messages are then summarized into
        ``moving_summary_buffer``.
        """
        super().save_context(inputs, outputs)
        # Alias of the stored list: popping from it prunes chat_memory in place.
        buffer = self.chat_memory.messages
        # Tokenize every message once and keep a running total. Re-tokenizing
        # the whole remaining buffer after each pop is quadratic in the number
        # of messages, and each count is a tokenizer call.
        token_counts = self.get_num_tokens_list(buffer)
        curr_buffer_length = sum(token_counts)
        if curr_buffer_length <= self.max_token_limit:
            return
        pruned_memory: List[BaseMessage] = []
        # Iterating the per-message counts bounds the loop by the buffer
        # length, so an empty buffer can never be popped (possible only if
        # max_token_limit is negative).
        for count in token_counts:
            if curr_buffer_length <= self.max_token_limit:
                break
            pruned_memory.append(buffer.pop(0))
            curr_buffer_length -= count
        if pruned_memory:
            self.moving_summary_buffer = self.predict_new_summary(
                pruned_memory, self.moving_summary_buffer
            )

    def clear(self) -> None:
        """Clear memory contents, including the running summary."""
        super().clear()
        self.moving_summary_buffer = ""
|