From 1f18698b2a730e3fe922d5322364bfd6c18d3c0b Mon Sep 17 00:00:00 2001
From: Harrison Chase
Date: Sun, 19 Mar 2023 10:42:24 -0700
Subject: [PATCH] Harrison/token buffer memory (#1786)

Co-authored-by: Aratako <127325395+Aratako@users.noreply.github.com>
---
 docs/modules/memory/types/token_buffer.ipynb | 288 +++++++++++++++++++
 langchain/memory/__init__.py                 |   2 +
 langchain/memory/token_buffer.py             |  54 ++++
 3 files changed, 344 insertions(+)
 create mode 100644 docs/modules/memory/types/token_buffer.ipynb
 create mode 100644 langchain/memory/token_buffer.py

diff --git a/docs/modules/memory/types/token_buffer.ipynb b/docs/modules/memory/types/token_buffer.ipynb
new file mode 100644
index 0000000000..3ddbf44f16
--- /dev/null
+++ b/docs/modules/memory/types/token_buffer.ipynb
@@ -0,0 +1,288 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "ff4be5f3",
+   "metadata": {},
+   "source": [
+    "## ConversationTokenBufferMemory\n",
+    "\n",
+    "`ConversationTokenBufferMemory` keeps a buffer of recent interactions in memory, and uses token length rather than the number of interactions to determine when to flush interactions.\n",
+    "\n",
+    "Let's first walk through how to use the utilities."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "da3384db",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain.memory import ConversationTokenBufferMemory\n",
+    "from langchain.llms import OpenAI\n",
+    "llm = OpenAI()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "e00d4938",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "memory = ConversationTokenBufferMemory(llm=llm, max_token_limit=10)\n",
+    "memory.save_context({\"input\": \"hi\"}, {\"output\": \"whats up\"})\n",
+    "memory.save_context({\"input\": \"not much you\"}, {\"output\": \"not much\"})"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "2fe28a28",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "{'history': 'Human: not much you\\nAI: not much'}"
+      ]
+     },
+     "execution_count": 3,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "memory.load_memory_variables({})"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "cf57b97a",
+   "metadata": {},
+   "source": [
+    "We can also get the history as a list of messages (this is useful if you are using this with a chat model)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "3422a3a8",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "memory = ConversationTokenBufferMemory(llm=llm, max_token_limit=10, return_messages=True)\n",
+    "memory.save_context({\"input\": \"hi\"}, {\"output\": \"whats up\"})\n",
+    "memory.save_context({\"input\": \"not much you\"}, {\"output\": \"not much\"})"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "a6d2569f",
+   "metadata": {},
+   "source": [
+    "## Using in a chain\n",
+    "Let's walk through an example, setting `verbose=True` so we can see the prompt."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "id": "ebd68c10",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "\n",
+      "\u001b[1m> Entering new ConversationChain chain...\u001b[0m\n",
+      "Prompt after formatting:\n",
+      "\u001b[32;1m\u001b[1;3mThe following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know.\n",
+      "\n",
+      "Current conversation:\n",
+      "\n",
+      "Human: Hi, what's up?\n",
+      "AI:\u001b[0m\n",
+      "\n",
+      "\u001b[1m> Finished chain.\u001b[0m\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "\" Hi there! I'm doing great, just enjoying the day. How about you?\""
+      ]
+     },
+     "execution_count": 9,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "from langchain.chains import ConversationChain\n",
+    "conversation_with_summary = ConversationChain(\n",
+    "    llm=llm, \n",
+    "    # We set a very low max_token_limit for the purposes of testing.\n",
+    "    memory=ConversationTokenBufferMemory(llm=OpenAI(), max_token_limit=60),\n",
+    "    verbose=True\n",
+    ")\n",
+    "conversation_with_summary.predict(input=\"Hi, what's up?\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "id": "86207a61",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "\n",
+      "\u001b[1m> Entering new ConversationChain chain...\u001b[0m\n",
+      "Prompt after formatting:\n",
+      "\u001b[32;1m\u001b[1;3mThe following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know.\n",
+      "\n",
+      "Current conversation:\n",
+      "Human: Hi, what's up?\n",
+      "AI:  Hi there! I'm doing great, just enjoying the day. How about you?\n",
+      "Human: Just working on writing some documentation!\n",
+      "AI:\u001b[0m\n",
+      "\n",
+      "\u001b[1m> Finished chain.\u001b[0m\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "' Sounds like a productive day! What kind of documentation are you writing?'"
+      ]
+     },
+     "execution_count": 10,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "conversation_with_summary.predict(input=\"Just working on writing some documentation!\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "id": "76a0ab39",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "\n",
+      "\u001b[1m> Entering new ConversationChain chain...\u001b[0m\n",
+      "Prompt after formatting:\n",
+      "\u001b[32;1m\u001b[1;3mThe following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know.\n",
+      "\n",
+      "Current conversation:\n",
+      "Human: Hi, what's up?\n",
+      "AI:  Hi there! I'm doing great, just enjoying the day. How about you?\n",
+      "Human: Just working on writing some documentation!\n",
+      "AI:  Sounds like a productive day! What kind of documentation are you writing?\n",
+      "Human: For LangChain! Have you heard of it?\n",
+      "AI:\u001b[0m\n",
+      "\n",
+      "\u001b[1m> Finished chain.\u001b[0m\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "\" Yes, I have heard of LangChain! It is a decentralized language-learning platform that connects native speakers and learners in real time. Is that the documentation you're writing about?\""
+      ]
+     },
+     "execution_count": 11,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "conversation_with_summary.predict(input=\"For LangChain! Have you heard of it?\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "id": "8c669db1",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "\n",
+      "\u001b[1m> Entering new ConversationChain chain...\u001b[0m\n",
+      "Prompt after formatting:\n",
+      "\u001b[32;1m\u001b[1;3mThe following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know.\n",
+      "\n",
+      "Current conversation:\n",
+      "Human: For LangChain! Have you heard of it?\n",
+      "AI:  Yes, I have heard of LangChain! It is a decentralized language-learning platform that connects native speakers and learners in real time. Is that the documentation you're writing about?\n",
+      "Human: Haha nope, although a lot of people confuse it for that\n",
+      "AI:\u001b[0m\n",
+      "\n",
+      "\u001b[1m> Finished chain.\u001b[0m\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "\" Oh, I see. Is there another language learning platform you're referring to?\""
+      ]
+     },
+     "execution_count": 12,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "# We can see here that the buffer has been pruned: the earliest exchanges no longer appear in the prompt\n",
+    "conversation_with_summary.predict(input=\"Haha nope, although a lot of people confuse it for that\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "8c09a239",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.9.1"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/langchain/memory/__init__.py b/langchain/memory/__init__.py
index dce50a92e2..65fe5d628d 100644
--- a/langchain/memory/__init__.py
+++ b/langchain/memory/__init__.py
@@ -11,6 +11,7 @@ from langchain.memory.readonly import ReadOnlySharedMemory
 from langchain.memory.simple import SimpleMemory
 from langchain.memory.summary import ConversationSummaryMemory
 from langchain.memory.summary_buffer import ConversationSummaryBufferMemory
+from langchain.memory.token_buffer import ConversationTokenBufferMemory
 
 __all__ = [
     "CombinedMemory",
@@ -24,4 +25,5 @@ __all__ = [
     "ChatMessageHistory",
     "ConversationStringBufferMemory",
     "ReadOnlySharedMemory",
+    "ConversationTokenBufferMemory",
 ]
diff --git a/langchain/memory/token_buffer.py b/langchain/memory/token_buffer.py
new file mode 100644
index 0000000000..3bd9b68410
--- /dev/null
+++ b/langchain/memory/token_buffer.py
@@ -0,0 +1,54 @@
+from typing import Any, Dict, List
+
+from pydantic import BaseModel
+
+from langchain.memory.chat_memory import BaseChatMemory
+from langchain.schema import BaseLanguageModel, BaseMessage, get_buffer_string
+
+
+class ConversationTokenBufferMemory(BaseChatMemory, BaseModel):
+    """Buffer for storing conversation memory, pruned to a maximum token length."""
+
+    human_prefix: str = "Human"
+    ai_prefix: str = "AI"
+    llm: BaseLanguageModel
+    memory_key: str = "history"
+    max_token_limit: int = 2000
+
+    @property
+    def buffer(self) -> List[BaseMessage]:
+        """List of messages in the memory buffer."""
+        return self.chat_memory.messages
+
+    @property
+    def memory_variables(self) -> List[str]:
+        """Will always return list of memory variables.
+
+        :meta private:
+        """
+        return [self.memory_key]
+
+    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        """Return history buffer."""
+        buffer: Any = self.buffer
+        if self.return_messages:
+            final_buffer: Any = buffer
+        else:
+            final_buffer = get_buffer_string(
+                buffer,
+                human_prefix=self.human_prefix,
+                ai_prefix=self.ai_prefix,
+            )
+        return {self.memory_key: final_buffer}
+
+    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
+        """Save context from this conversation to buffer, pruning it if it exceeds max_token_limit."""
+        super().save_context(inputs, outputs)
+        # Prune buffer if it exceeds max token limit
+        buffer = self.chat_memory.messages
+        curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
+        if curr_buffer_length > self.max_token_limit:
+            pruned_memory = []
+            while curr_buffer_length > self.max_token_limit:
+                pruned_memory.append(buffer.pop(0))
+                curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
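
A minimal usage sketch of the new ConversationTokenBufferMemory follows. It uses only the APIs added in this patch and assumes an OpenAI API key is available in the environment; the max_token_limit of 50 and the sample messages are illustrative values, and how many exchanges survive pruning depends on the model's tokenizer.

from langchain.llms import OpenAI
from langchain.memory import ConversationTokenBufferMemory

llm = OpenAI()  # assumes OPENAI_API_KEY is set in the environment

# Keep at most ~50 tokens of history; save_context prunes the oldest
# messages from the front of the buffer once the limit is exceeded.
memory = ConversationTokenBufferMemory(llm=llm, max_token_limit=50)
for i in range(5):
    memory.save_context(
        {"input": f"question number {i}"},
        {"output": f"answer number {i}"},
    )

# Only the most recent exchanges that fit under max_token_limit remain.
print(memory.load_memory_variables({})["history"])

Because save_context prunes after every turn, the buffer never exceeds max_token_limit between calls, which keeps the prompt size bounded in a long-running ConversationChain.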