mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
Add a conversation memory that combines an (optionally persistent) vectorstore history with a token buffer (#22155)
**langchain: ConversationVectorStoreTokenBufferMemory**

- **Description:** This PR adds ConversationVectorStoreTokenBufferMemory. It is similar in concept to ConversationSummaryBufferMemory: it maintains an in-memory buffer of messages up to a preset token limit. After the limit is hit, timestamped messages are written into a vectorstore retriever rather than into a summary. The user's prompt is then used to retrieve relevant fragments of the previous conversation. By persisting the vectorstore, one can maintain memory from session to session.
- **Issue:** n/a
- **Dependencies:** none
- **Twitter handle:** Please no!!!
- [X] **Add tests and docs**: I looked to see how the unit tests were written for the other ConversationMemory modules, but couldn't find anything other than a test for successful import. I need to know whether you are using pytest.mock or another fixture to simulate the LLM and vectorstore. In addition, I would like guidance on where to place the documentation. Should it be a notebook file in docs/docs?
- [X] **Lint and test**: I am seeing some linting errors from a couple of modules unrelated to this PR.

If no one reviews your PR within a few days, please @-mention one of baskaryan, efriis, eyurtsev, ccurme, vbarda, hwchase17.

---------

Co-authored-by: Lincoln Stein <lstein@gmail.com>
Co-authored-by: isaac hershenson <ihershenson@hmc.edu>
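A minimal sketch of the flow described above (it assumes OpenAI credentials and `chromadb` are installed; the collection name and `persist_directory` path are illustrative, not part of this PR):

```python
# Recent turns stay in the token buffer; once max_token_limit is exceeded,
# older turns are timestamped, written to the vectorstore, and later
# retrieved by similarity to the user's prompt.
from langchain_community.vectorstores import Chroma
from langchain_openai import OpenAI, OpenAIEmbeddings

from langchain.memory import ConversationVectorStoreTokenBufferMemory

chroma = Chroma(
    collection_name="chat_history",            # illustrative name
    embedding_function=OpenAIEmbeddings(),
    persist_directory="./chat_history_db",     # illustrative path
)
memory = ConversationVectorStoreTokenBufferMemory(
    llm=OpenAI(),                              # used by the parent class to count tokens
    retriever=chroma.as_retriever(search_kwargs={"k": 5}),
    max_token_limit=1000,                      # buffer size before overflow
    return_messages=True,
)

memory.save_context({"Human": "Hi there"}, {"AI": "Nice to meet you!"})
# "history" mixes retrieved vectorstore excerpts with the in-buffer turns:
print(memory.load_memory_variables({"input": "What did we discuss?"}))
```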
This commit is contained in:
parent
32f8f39974
commit
c314222796
@ -1,6 +1,7 @@
"""__ModuleName__ document loader."""

from typing import Iterator

from langchain_core.document_loaders.base import BaseLoader
from langchain_core.documents import Document

@ -48,6 +48,9 @@ from langchain.memory.summary import ConversationSummaryMemory
from langchain.memory.summary_buffer import ConversationSummaryBufferMemory
from langchain.memory.token_buffer import ConversationTokenBufferMemory
from langchain.memory.vectorstore import VectorStoreRetrieverMemory
from langchain.memory.vectorstore_token_buffer_memory import (
    ConversationVectorStoreTokenBufferMemory,  # avoid circular import
)

if TYPE_CHECKING:
    from langchain_community.chat_message_histories import (
@ -122,6 +125,7 @@ __all__ = [
    "ConversationSummaryBufferMemory",
    "ConversationSummaryMemory",
    "ConversationTokenBufferMemory",
    "ConversationVectorStoreTokenBufferMemory",
    "CosmosDBChatMessageHistory",
    "DynamoDBChatMessageHistory",
    "ElasticsearchChatMessageHistory",
@ -0,0 +1,184 @@
"""
Class for a conversation memory buffer with older messages stored in a vectorstore.

This implements a conversation memory in which the messages are stored in a memory
buffer up to a specified token limit. When the limit is exceeded, older messages are
saved to a vectorstore backing database. The vectorstore can be made persistent across
sessions.
"""

import warnings
from datetime import datetime
from typing import Any, Dict, List

from langchain_core.messages import BaseMessage
from langchain_core.prompts.chat import SystemMessagePromptTemplate
from langchain_core.pydantic_v1 import Field, PrivateAttr
from langchain_core.vectorstores import VectorStoreRetriever

from langchain.memory import ConversationTokenBufferMemory, VectorStoreRetrieverMemory
from langchain.memory.chat_memory import BaseChatMemory
from langchain.text_splitter import RecursiveCharacterTextSplitter

DEFAULT_HISTORY_TEMPLATE = """
Current date and time: {current_time}.

Potentially relevant timestamped excerpts of previous conversations (you
do not need to use these if irrelevant):
{previous_history}

"""

TIMESTAMP_FORMAT = "%Y-%m-%d %H:%M:%S %Z"


class ConversationVectorStoreTokenBufferMemory(ConversationTokenBufferMemory):
    """Conversation chat memory with token limit and vectordb backing.

    load_memory_variables() will return a dict with the key "history".
    It contains background information retrieved from the vector store
    plus recent lines of the current conversation.

    To help the LLM understand the part of the conversation stored in the
    vectorstore, each interaction is timestamped and the current date and
    time is also provided in the history. A side effect of this is that the
    LLM will have access to the current date and time.

    Initialization arguments:

    This class accepts all the initialization arguments of
    ConversationTokenBufferMemory, such as `llm`. In addition, it
    accepts the following additional arguments:

        retriever: (required) A VectorStoreRetriever object to use
            as the vector backing store.

        split_chunk_size: (optional, default 1000) Token chunk split size
            for long messages generated by the AI.

        previous_history_template: (optional) Template used to format
            the contents of the prompt history.

    Example using ChromaDB:

    .. code-block:: python

        from langchain.memory.vectorstore_token_buffer_memory import (
            ConversationVectorStoreTokenBufferMemory,
        )
        from langchain_community.vectorstores import Chroma
        from langchain_community.embeddings import HuggingFaceInstructEmbeddings
        from langchain_openai import OpenAI

        embedder = HuggingFaceInstructEmbeddings(
            query_instruction="Represent the query for retrieval: "
        )
        chroma = Chroma(
            collection_name="demo",
            embedding_function=embedder,
            collection_metadata={"hnsw:space": "cosine"},
        )

        retriever = chroma.as_retriever(
            search_type="similarity_score_threshold",
            search_kwargs={
                "k": 5,
                "score_threshold": 0.75,
            },
        )

        conversation_memory = ConversationVectorStoreTokenBufferMemory(
            return_messages=True,
            llm=OpenAI(),
            retriever=retriever,
            max_token_limit=1000,
        )

        conversation_memory.save_context(
            {"Human": "Hi there"},
            {"AI": "Nice to meet you!"},
        )
        conversation_memory.save_context(
            {"Human": "Nice day isn't it?"},
            {"AI": "I love Wednesdays."},
        )
        conversation_memory.load_memory_variables({"input": "What time is it?"})
    """

    retriever: VectorStoreRetriever = Field(exclude=True)
    memory_key: str = "history"
    previous_history_template: str = DEFAULT_HISTORY_TEMPLATE
    split_chunk_size: int = 1000

    _memory_retriever: VectorStoreRetrieverMemory = PrivateAttr(default=None)
    _timestamps: List[datetime] = PrivateAttr(default_factory=list)

    @property
    def memory_retriever(self) -> VectorStoreRetrieverMemory:
        """Return a memory retriever from the passed retriever object."""
        if self._memory_retriever is not None:
            return self._memory_retriever
        self._memory_retriever = VectorStoreRetrieverMemory(retriever=self.retriever)
        return self._memory_retriever

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Return history and memory buffer."""
        try:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                memory_variables = self.memory_retriever.load_memory_variables(inputs)
            previous_history = memory_variables[self.memory_retriever.memory_key]
        except AssertionError:  # happens when db is empty
            previous_history = ""
        current_history = super().load_memory_variables(inputs)
        template = SystemMessagePromptTemplate.from_template(
            self.previous_history_template
        )
        messages = [
            template.format(
                previous_history=previous_history,
                current_time=datetime.now().astimezone().strftime(TIMESTAMP_FORMAT),
            )
        ]
        messages.extend(current_history[self.memory_key])
        return {self.memory_key: messages}

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
        """Save context from this conversation to buffer. Pruned."""
        BaseChatMemory.save_context(self, inputs, outputs)
        self._timestamps.append(datetime.now().astimezone())
        # Prune buffer if it exceeds max token limit
        buffer = self.chat_memory.messages
        curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
        if curr_buffer_length > self.max_token_limit:
            while curr_buffer_length > self.max_token_limit:
                self._pop_and_store_interaction(buffer)
                curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)

    def save_remainder(self) -> None:
        """
        Save the remainder of the conversation buffer to the vector store.

        This is useful if you have made the vectorstore persistent, in which
        case this can be called before the end of the session to store the
        remainder of the conversation.
        """
        buffer = self.chat_memory.messages
        while len(buffer) > 0:
            self._pop_and_store_interaction(buffer)

    def _pop_and_store_interaction(self, buffer: List[BaseMessage]) -> None:
        input = buffer.pop(0)
        output = buffer.pop(0)
        timestamp = self._timestamps.pop(0).strftime(TIMESTAMP_FORMAT)
        # Split AI output into smaller chunks to avoid creating documents
        # that will overflow the context window
        ai_chunks = self._split_long_ai_text(str(output.content))
        for index, chunk in enumerate(ai_chunks):
            self.memory_retriever.save_context(
                {"Human": f"<{timestamp}/00> {str(input.content)}"},
                {"AI": f"<{timestamp}/{index:02}> {chunk}"},
            )

    def _split_long_ai_text(self, text: str) -> List[str]:
        splitter = RecursiveCharacterTextSplitter(chunk_size=self.split_chunk_size)
        return [chunk.page_content for chunk in splitter.create_documents([text])]
@ -13,6 +13,7 @@ EXPECTED_ALL = [
    "ConversationSummaryBufferMemory",
    "ConversationSummaryMemory",
    "ConversationTokenBufferMemory",
    "ConversationVectorStoreTokenBufferMemory",
    "CosmosDBChatMessageHistory",
    "DynamoDBChatMessageHistory",
    "ElasticsearchChatMessageHistory",