diff --git a/docs/modules/memory/types/summary.ipynb b/docs/modules/memory/types/summary.ipynb index 89b5865dc8..b2dcbb9bf8 100644 --- a/docs/modules/memory/types/summary.ipynb +++ b/docs/modules/memory/types/summary.ipynb @@ -18,7 +18,7 @@ "metadata": {}, "outputs": [], "source": [ - "from langchain.memory import ConversationSummaryMemory\n", + "from langchain.memory import ConversationSummaryMemory, ChatMessageHistory\n", "from langchain.llms import OpenAI" ] }, @@ -125,6 +125,59 @@ "memory.predict_new_summary(messages, previous_summary)" ] }, + { + "cell_type": "markdown", + "id": "fa3ad83f", + "metadata": {}, + "source": [ + "## Initializing with messages\n", + "\n", + "If you have messages outside this class, you can easily initialize the class with ChatMessageHistory. During loading, a summary will be calculated." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "80fd072b", + "metadata": {}, + "outputs": [], + "source": [ + "history = ChatMessageHistory()\n", + "history.add_user_message(\"hi\")\n", + "history.add_ai_message(\"hi there!\")" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "ee9c74ad", + "metadata": {}, + "outputs": [], + "source": [ + "memory = ConversationSummaryMemory.from_messages(llm=OpenAI(temperature=0), chat_memory=history, return_messages=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "0ce6924d", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'\\nThe human greets the AI, to which the AI responds with a friendly greeting.'" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "memory.buffer" + ] + }, { "cell_type": "markdown", "id": "4fad9448", diff --git a/langchain/memory/summary.py b/langchain/memory/summary.py index 7a2d04f47c..c35bd70b93 100644 --- a/langchain/memory/summary.py +++ b/langchain/memory/summary.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import Any, Dict, List, Type from pydantic import BaseModel, root_validator @@ -8,6 +10,7 @@ from langchain.memory.chat_memory import BaseChatMemory from langchain.memory.prompt import SUMMARY_PROMPT from langchain.prompts.base import BasePromptTemplate from langchain.schema import ( + BaseChatMessageHistory, BaseMessage, SystemMessage, get_buffer_string, @@ -40,6 +43,22 @@ class ConversationSummaryMemory(BaseChatMemory, SummarizerMixin): buffer: str = "" memory_key: str = "history" #: :meta private: + @classmethod + def from_messages( + cls, + llm: BaseLanguageModel, + chat_memory: BaseChatMessageHistory, + *, + summarize_step: int = 2, + **kwargs: Any, + ) -> ConversationSummaryMemory: + obj = cls(llm=llm, chat_memory=chat_memory, **kwargs) + for i in range(0, len(obj.chat_memory.messages), summarize_step): + obj.buffer = obj.predict_new_summary( + obj.chat_memory.messages[i : i + summarize_step], obj.buffer + ) + return obj + @property def memory_variables(self) -> List[str]: """Will always return list of memory variables.