Harrison/chat memory (#1495)

Harrison Chase 2023-03-07 09:02:40 -08:00 committed by GitHub
parent 7bec461782
commit f276bfad8e
10 changed files with 226 additions and 97 deletions

View File

@ -26,7 +26,7 @@
{
"cell_type": "code",
"execution_count": 1,
"id": "356e7c73",
"id": "87235cf1",
"metadata": {},
"outputs": [],
"source": [
@ -36,7 +36,7 @@
{
"cell_type": "code",
"execution_count": 2,
"id": "54f90bd7",
"id": "4404d509",
"metadata": {},
"outputs": [],
"source": [
@ -46,7 +46,7 @@
{
"cell_type": "code",
"execution_count": 3,
"id": "3884ff81",
"id": "78c1a67b",
"metadata": {},
"outputs": [],
"source": [
@ -56,7 +56,7 @@
{
"cell_type": "code",
"execution_count": 4,
"id": "b48d5844",
"id": "525ce606",
"metadata": {},
"outputs": [],
"source": [
@ -66,7 +66,7 @@
{
"cell_type": "code",
"execution_count": 5,
"id": "495dcff0",
"id": "be030822",
"metadata": {},
"outputs": [
{
@ -87,7 +87,7 @@
},
{
"cell_type": "markdown",
"id": "46196aa3",
"id": "2c0328fb",
"metadata": {},
"source": [
"## ConversationBufferMemory\n",
@ -100,7 +100,7 @@
{
"cell_type": "code",
"execution_count": 7,
"id": "3bac84f3",
"id": "a382b160",
"metadata": {},
"outputs": [],
"source": [
@ -110,7 +110,7 @@
{
"cell_type": "code",
"execution_count": 10,
"id": "cef35e7f",
"id": "a280d337",
"metadata": {},
"outputs": [],
"source": [
@ -122,7 +122,7 @@
{
"cell_type": "code",
"execution_count": 12,
"id": "2c9b39af",
"id": "1b739c0a",
"metadata": {},
"outputs": [
{
@ -142,7 +142,7 @@
},
{
"cell_type": "markdown",
"id": "567f7c16",
"id": "989e9425",
"metadata": {},
"source": [
"We can also get the history as a list of messages"
@ -151,7 +151,7 @@
{
"cell_type": "code",
"execution_count": 13,
"id": "a481a415",
"id": "798ceb1c",
"metadata": {},
"outputs": [],
"source": [
@ -163,7 +163,7 @@
{
"cell_type": "code",
"execution_count": 14,
"id": "86a56348",
"id": "698688fd",
"metadata": {},
"outputs": [
{
@ -187,6 +187,7 @@
"id": "d051c1da",
"metadata": {},
"source": [
"## Using in a chain\n",
"Finally, let's take a look at using this in a chain (setting `verbose=True` so we can see the prompt)."
]
},
@ -332,7 +333,7 @@
},
{
"cell_type": "markdown",
"id": "bd0146c2",
"id": "7826c210",
"metadata": {},
"source": [
"And that's it for the getting started! There are plenty of different types of memory, check out our examples to see them all"
@ -341,7 +342,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "447c138d",
"id": "3dd37d93",
"metadata": {},
"outputs": [],
"source": []

View File

@ -100,6 +100,7 @@
"id": "d051c1da",
"metadata": {},
"source": [
"## Using in a chain\n",
"Finally, let's take a look at using this in a chain (setting `verbose=True` so we can see the prompt)."
]
},

View File

@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "markdown",
"id": "716c8cab",
"id": "a20c4e38",
"metadata": {},
"source": [
"## ConversationBufferWindowMemory\n",
@ -15,7 +15,7 @@
{
"cell_type": "code",
"execution_count": 3,
"id": "dc9d10b0",
"id": "1196da3f",
"metadata": {},
"outputs": [],
"source": [
@ -37,7 +37,7 @@
{
"cell_type": "code",
"execution_count": 6,
"id": "40664f10",
"id": "0c034a90",
"metadata": {},
"outputs": [
{
@ -57,7 +57,7 @@
},
{
"cell_type": "markdown",
"id": "b1932b49",
"id": "8c5cce1d",
"metadata": {},
"source": [
"We can also get the history as a list of messages (this is useful if you are using this with a chat model)."
@ -66,7 +66,7 @@
{
"cell_type": "code",
"execution_count": 8,
"id": "5fd077d5",
"id": "9b15b427",
"metadata": {},
"outputs": [],
"source": [
@ -78,7 +78,7 @@
{
"cell_type": "code",
"execution_count": 9,
"id": "b94b750f",
"id": "3bb47191",
"metadata": {},
"outputs": [
{
@ -99,9 +99,10 @@
},
{
"cell_type": "markdown",
"id": "ac59a682",
"id": "a95af04c",
"metadata": {},
"source": [
"## Using in a chain\n",
"Let's walk through an example, again setting `verbose=True` so we can see the prompt."
]
},
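
This notebook covers ConversationBufferWindowMemory, which keeps only the last k interactions verbatim. A minimal sketch of that behavior; the top-level import path and the k value are assumptions:

# Windowed buffer: only the last k exchanges are kept (import path assumed).
from langchain.memory import ConversationBufferWindowMemory

memory = ConversationBufferWindowMemory(k=1)
memory.save_context({"input": "hi"}, {"output": "whats up"})
memory.save_context({"input": "not much you"}, {"output": "not much"})

# With k=1, only the most recent exchange survives in the history.
print(memory.load_memory_variables({}))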

View File

@ -14,7 +14,7 @@
{
"cell_type": "code",
"execution_count": 1,
"id": "7b165e71",
"id": "1bea1181",
"metadata": {},
"outputs": [],
"source": [
@ -26,7 +26,7 @@
{
"cell_type": "code",
"execution_count": 2,
"id": "25354d39",
"id": "34425079",
"metadata": {},
"outputs": [],
"source": [
@ -42,7 +42,7 @@
{
"cell_type": "code",
"execution_count": 3,
"id": "71c75295",
"id": "b425642c",
"metadata": {},
"outputs": [
{
@ -64,7 +64,7 @@
{
"cell_type": "code",
"execution_count": 4,
"id": "bb8a7943",
"id": "3bf89b46",
"metadata": {},
"outputs": [],
"source": [
@ -80,7 +80,7 @@
{
"cell_type": "code",
"execution_count": 5,
"id": "a37ac236",
"id": "3e37d126",
"metadata": {},
"outputs": [
{
@ -102,9 +102,10 @@
},
{
"cell_type": "markdown",
"id": "655ab2c4",
"id": "ee5ad043",
"metadata": {},
"source": [
"## Using in a chain\n",
"Let's now use it in a chain!"
]
},
@ -189,7 +190,7 @@
{
"cell_type": "code",
"execution_count": 9,
"id": "dc1c0d5e",
"id": "0269f513",
"metadata": {},
"outputs": [
{

View File

@ -26,7 +26,7 @@
{
"cell_type": "code",
"execution_count": 2,
"id": "1d9a5b18",
"id": "2f4a3c85",
"metadata": {},
"outputs": [],
"source": [
@ -39,7 +39,7 @@
{
"cell_type": "code",
"execution_count": 3,
"id": "3f5da288",
"id": "72283b4f",
"metadata": {},
"outputs": [
{
@ -59,7 +59,7 @@
},
{
"cell_type": "markdown",
"id": "2ab23035",
"id": "0c8ff11e",
"metadata": {},
"source": [
"We can also get the history as a list of messages (this is useful if you are using this with a chat model)."
@ -68,7 +68,7 @@
{
"cell_type": "code",
"execution_count": 4,
"id": "43dbbc43",
"id": "44df43af",
"metadata": {},
"outputs": [],
"source": [
@ -80,7 +80,7 @@
{
"cell_type": "code",
"execution_count": 5,
"id": "ef6a4f60",
"id": "4726b1c8",
"metadata": {},
"outputs": [
{
@ -100,9 +100,68 @@
},
{
"cell_type": "markdown",
"id": "5ba0dde9",
"id": "dc956b0e",
"metadata": {},
"source": [
"We can also more modularly get current entities from a new message (will use previous messages as context.)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "36331ca5",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['Sam']"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"memory.get_current_entities(\"what's Sams favorite color?\")"
]
},
{
"cell_type": "markdown",
"id": "e8749134",
"metadata": {},
"source": [
"We can also more modularly get knowledge triplets from a new message (will use previous messages as context.)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "b02d44db",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[KnowledgeTriple(subject='Sam', predicate='favorite color', object_='red')]"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"memory.get_knowledge_triplets(\"her favorite color is red\")"
]
},
{
"cell_type": "markdown",
"id": "f7a02ef3",
"metadata": {},
"source": [
"## Using in a chain\n",
"Let's now use this in a chain!"
]
},
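
The two new cells call the knowledge-graph memory's extraction steps directly. One natural follow-on, sketched here, is to apply the extracted triplets to the graph yourself via kg.add_triple, which is exactly what _get_and_update_kg does in the kg.py diff near the end of this commit (imports and conversation setup are assumptions):

# Extract triplets from a new utterance, then apply them to the graph by hand
# (mirrors _get_and_update_kg; setup lines are assumed).
from langchain.llms import OpenAI
from langchain.memory.kg import ConversationKGMemory

memory = ConversationKGMemory(llm=OpenAI(temperature=0))
memory.save_context({"input": "say hi to sam"}, {"output": "who is sam"})
memory.save_context({"input": "sam is a friend"}, {"output": "okay"})

triples = memory.get_knowledge_triplets("her favorite color is red")
for triple in triples:
    memory.kg.add_triple(triple)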

View File

@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "markdown",
"id": "5d3f2f0f",
"id": "1674bfd6",
"metadata": {},
"source": [
"## ConversationSummaryMemory\n",
@ -14,7 +14,7 @@
{
"cell_type": "code",
"execution_count": 1,
"id": "a36c34d6",
"id": "c5565e5c",
"metadata": {},
"outputs": [],
"source": [
@ -25,7 +25,7 @@
{
"cell_type": "code",
"execution_count": 2,
"id": "89646c34",
"id": "61621239",
"metadata": {},
"outputs": [],
"source": [
@ -36,7 +36,7 @@
{
"cell_type": "code",
"execution_count": 3,
"id": "dea210aa",
"id": "3bcb8b02",
"metadata": {},
"outputs": [
{
@ -56,7 +56,7 @@
},
{
"cell_type": "markdown",
"id": "3838fe93",
"id": "dedf0698",
"metadata": {},
"source": [
"We can also get the history as a list of messages (this is useful if you are using this with a chat model)."
@ -65,7 +65,7 @@
{
"cell_type": "code",
"execution_count": 4,
"id": "114dbef6",
"id": "6cb06b22",
"metadata": {},
"outputs": [],
"source": [
@ -76,13 +76,13 @@
{
"cell_type": "code",
"execution_count": 5,
"id": "39c8c106",
"id": "47b03ed7",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'history': [SystemMessage(content='\\nThe human greets the AI and the AI responds with a casual greeting.', additional_kwargs={})]}"
"{'history': [SystemMessage(content='\\nThe human greets the AI, to which the AI responds.', additional_kwargs={})]}"
]
},
"execution_count": 5,
@ -94,11 +94,43 @@
"memory.load_memory_variables({})"
]
},
{
"cell_type": "markdown",
"id": "9ec0a0ee",
"metadata": {},
"source": [
"We can also utilize the `predict_new_summary` method directly."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "9c4dafb9",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'\\nThe human greets the AI, to which the AI responds.'"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"messages = memory.chat_memory.messages\n",
"previous_summary = \"\"\n",
"memory.predict_new_summary(messages, previous_summary)"
]
},
{
"cell_type": "markdown",
"id": "4fad9448",
"metadata": {},
"source": [
"## Using in a chain\n",
"Let's walk through an example of using this in a chain, again setting `verbose=True` so we can see the prompt."
]
},
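
Under the hood, save_context keeps the summary rolling: it re-summarizes only the two newest messages against the existing buffer, as the summary.py diff later in this commit shows. A sketch; the LLM choice is an assumption:

# Rolling summary: each save_context folds the newest exchange into the buffer.
from langchain.llms import OpenAI
from langchain.memory.summary import ConversationSummaryMemory

memory = ConversationSummaryMemory(llm=OpenAI(temperature=0))
memory.save_context({"input": "hi"}, {"output": "whats up"})

# Equivalent to what save_context just did internally:
#   memory.buffer = memory.predict_new_summary(
#       memory.chat_memory.messages[-2:], memory.buffer
#   )
print(memory.buffer)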

View File

@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "markdown",
"id": "2e90817b",
"id": "ff4be5f3",
"metadata": {},
"source": [
"## ConversationSummaryBufferMemory\n",
@ -14,8 +14,8 @@
},
{
"cell_type": "code",
"execution_count": 2,
"id": "b8aec547",
"execution_count": 1,
"id": "da3384db",
"metadata": {},
"outputs": [],
"source": [
@ -26,8 +26,8 @@
},
{
"cell_type": "code",
"execution_count": 8,
"id": "2594c8f1",
"execution_count": 2,
"id": "e00d4938",
"metadata": {},
"outputs": [],
"source": [
@ -38,17 +38,17 @@
},
{
"cell_type": "code",
"execution_count": 9,
"id": "a25087e0",
"execution_count": 3,
"id": "2fe28a28",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'history': 'System: \\nThe human greets the AI, to which the AI responds inquiring what is up.\\nHuman: not much you\\nAI: not much'}"
"{'history': 'System: \\nThe human says \"hi\", and the AI responds with \"whats up\".\\nHuman: not much you\\nAI: not much'}"
]
},
"execution_count": 9,
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
@ -59,7 +59,7 @@
},
{
"cell_type": "markdown",
"id": "3a6e7905",
"id": "cf57b97a",
"metadata": {},
"source": [
"We can also get the history as a list of messages (this is useful if you are using this with a chat model)."
@ -67,8 +67,8 @@
},
{
"cell_type": "code",
"execution_count": 10,
"id": "e439451f",
"execution_count": 4,
"id": "3422a3a8",
"metadata": {},
"outputs": [],
"source": [
@ -77,17 +77,49 @@
"memory.save_context({\"input\": \"not much you\"}, {\"ouput\": \"not much\"})"
]
},
{
"cell_type": "markdown",
"id": "a1dcaaee",
"metadata": {},
"source": [
"We can also utilize the `predict_new_summary` method directly."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "fd7d7d6b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'\\nThe human and AI state that they are not doing much.'"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"messages = memory.chat_memory.messages\n",
"previous_summary = \"\"\n",
"memory.predict_new_summary(messages, previous_summary)"
]
},
{
"cell_type": "markdown",
"id": "a6d2569f",
"metadata": {},
"source": [
"## Using in a chain\n",
"Let's walk through an example, again setting `verbose=True` so we can see the prompt."
]
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 6,
"id": "ebd68c10",
"metadata": {},
"outputs": [
@ -112,10 +144,10 @@
{
"data": {
"text/plain": [
"\" Hi there! I'm doing great. I'm spending some time learning about the latest developments in AI technology. How about you?\""
"\" Hi there! I'm doing great. I'm learning about the latest advances in artificial intelligence. What about you?\""
]
},
"execution_count": 11,
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
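
ConversationSummaryBufferMemory keeps recent turns verbatim and folds older turns into a rolling summary once the buffer exceeds max_token_limit. A sketch; the deliberately small limit and the top-level import path are assumptions:

# Hybrid memory: raw recent turns plus a summary of pruned older turns.
from langchain.llms import OpenAI
from langchain.memory import ConversationSummaryBufferMemory

memory = ConversationSummaryBufferMemory(llm=OpenAI(temperature=0), max_token_limit=40)
memory.save_context({"input": "hi"}, {"output": "whats up"})
memory.save_context({"input": "not much you"}, {"output": "not much"})

# Once over the limit, history reads "System: <summary>" plus the surviving turns.
print(memory.load_memory_variables({}))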

View File

@ -4,7 +4,7 @@ from pydantic import BaseModel, Field
from langchain.chains.llm import LLMChain
from langchain.graphs import NetworkxEntityGraph
from langchain.graphs.networkx_graph import get_entities, parse_triples
from langchain.graphs.networkx_graph import KnowledgeTriple, get_entities, parse_triples
from langchain.llms.base import BaseLLM
from langchain.memory.chat_memory import BaseChatMemory
from langchain.memory.prompt import (
@ -78,9 +78,7 @@ class ConversationKGMemory(BaseChatMemory, BaseModel):
return list(outputs.keys())[0]
return self.output_key
def _get_current_entities(self, inputs: Dict[str, Any]) -> List[str]:
"""Get the current entities in the conversation."""
prompt_input_key = self._get_prompt_input_key(inputs)
def get_current_entities(self, input_string: str) -> List[str]:
chain = LLMChain(llm=self.llm, prompt=self.entity_extraction_prompt)
buffer_string = get_buffer_string(
self.chat_memory.messages[-self.k * 2 :],
@ -89,14 +87,17 @@ class ConversationKGMemory(BaseChatMemory, BaseModel):
)
output = chain.predict(
history=buffer_string,
input=inputs[prompt_input_key],
input=input_string,
)
return get_entities(output)
def _get_and_update_kg(self, inputs: Dict[str, Any]) -> None:
"""Get and update knowledge graph from the conversation history."""
chain = LLMChain(llm=self.llm, prompt=self.knowledge_extraction_prompt)
def _get_current_entities(self, inputs: Dict[str, Any]) -> List[str]:
"""Get the current entities in the conversation."""
prompt_input_key = self._get_prompt_input_key(inputs)
return self.get_current_entities(inputs[prompt_input_key])
def get_knowledge_triplets(self, input_string: str) -> List[KnowledgeTriple]:
chain = LLMChain(llm=self.llm, prompt=self.knowledge_extraction_prompt)
buffer_string = get_buffer_string(
self.chat_memory.messages[-self.k * 2 :],
human_prefix=self.human_prefix,
@ -104,10 +105,16 @@ class ConversationKGMemory(BaseChatMemory, BaseModel):
)
output = chain.predict(
history=buffer_string,
input=inputs[prompt_input_key],
input=input_string,
verbose=True,
)
knowledge = parse_triples(output)
return knowledge
def _get_and_update_kg(self, inputs: Dict[str, Any]) -> None:
"""Get and update knowledge graph from the conversation history."""
prompt_input_key = self._get_prompt_input_key(inputs)
knowledge = self.get_knowledge_triplets(inputs[prompt_input_key])
for triple in knowledge:
self.kg.add_triple(triple)
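
The shape of this refactor: the public methods take a plain string, and the old private methods become thin wrappers that pull that string out of the chain's inputs. That makes the extraction step callable in isolation, e.g. (imports and fixture values are assumptions, mirroring the notebook above):

# Entity extraction without running a full chain (setup assumed).
from langchain.llms import OpenAI
from langchain.memory.kg import ConversationKGMemory

memory = ConversationKGMemory(llm=OpenAI(temperature=0))
memory.save_context({"input": "sam is my friend"}, {"output": "got it"})

# Context is the last k*2 stored messages, per get_buffer_string above.
entities = memory.get_current_entities("what's Sam's favorite color?")
# In the notebook's run this returned ['Sam'].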

View File

@ -8,17 +8,32 @@ from langchain.memory.chat_memory import BaseChatMemory
from langchain.memory.prompt import SUMMARY_PROMPT
from langchain.memory.utils import get_buffer_string
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import SystemMessage
from langchain.schema import BaseMessage, SystemMessage
class ConversationSummaryMemory(BaseChatMemory, BaseModel):
"""Conversation summarizer to memory."""
buffer: str = ""
class SummarizerMixin(BaseModel):
human_prefix: str = "Human"
ai_prefix: str = "AI"
llm: BaseLLM
prompt: BasePromptTemplate = SUMMARY_PROMPT
def predict_new_summary(
self, messages: List[BaseMessage], existing_summary: str
) -> str:
new_lines = get_buffer_string(
messages,
human_prefix=self.human_prefix,
ai_prefix=self.ai_prefix,
)
chain = LLMChain(llm=self.llm, prompt=self.prompt)
return chain.predict(summary=existing_summary, new_lines=new_lines)
class ConversationSummaryMemory(BaseChatMemory, SummarizerMixin, BaseModel):
"""Conversation summarizer to memory."""
buffer: str = ""
memory_key: str = "history" #: :meta private:
@property
@ -52,15 +67,10 @@ class ConversationSummaryMemory(BaseChatMemory, BaseModel):
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""Save context from this conversation to buffer."""
super().save_context(inputs, outputs)
new_lines = get_buffer_string(
self.chat_memory.messages[-2:],
human_prefix=self.human_prefix,
ai_prefix=self.ai_prefix,
self.buffer = self.predict_new_summary(
self.chat_memory.messages[-2:], self.buffer
)
chain = LLMChain(llm=self.llm, prompt=self.prompt)
self.buffer = chain.predict(summary=self.buffer, new_lines=new_lines)
def clear(self) -> None:
"""Clear memory contents."""
super().clear()
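
SummarizerMixin factors the "fold new lines into an existing summary" step out of ConversationSummaryMemory so the summary-buffer class in the next file can reuse it. Because it is a plain pydantic model, it can also be used standalone; a sketch, with the message classes and LLM setup as assumptions:

# The mixin only needs an LLM (prompt and prefixes have defaults); imports assumed.
from langchain.llms import OpenAI
from langchain.memory.summary import SummarizerMixin
from langchain.schema import AIMessage, HumanMessage

summarizer = SummarizerMixin(llm=OpenAI(temperature=0))
messages = [HumanMessage(content="hi"), AIMessage(content="whats up")]
print(summarizer.predict_new_summary(messages, existing_summary=""))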

View File

@ -2,24 +2,17 @@ from typing import Any, Dict, List
from pydantic import BaseModel, root_validator
from langchain.chains.llm import LLMChain
from langchain.llms.base import BaseLLM
from langchain.memory.chat_memory import BaseChatMemory
from langchain.memory.prompt import SUMMARY_PROMPT
from langchain.memory.summary import SummarizerMixin
from langchain.memory.utils import get_buffer_string
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseMessage, SystemMessage
class ConversationSummaryBufferMemory(BaseChatMemory, BaseModel):
class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin, BaseModel):
"""Buffer with summarizer for storing conversation memory."""
max_token_limit: int = 2000
human_prefix: str = "Human"
ai_prefix: str = "AI"
moving_summary_buffer: str = ""
llm: BaseLLM
prompt: BasePromptTemplate = SUMMARY_PROMPT
memory_key: str = "history"
@property
@ -77,16 +70,8 @@ class ConversationSummaryBufferMemory(BaseChatMemory, BaseModel):
while curr_buffer_length > self.max_token_limit:
pruned_memory.append(buffer.pop(0))
curr_buffer_length = sum(self.get_num_tokens_list(buffer))
chain = LLMChain(llm=self.llm, prompt=self.prompt)
self.moving_summary_buffer = chain.predict(
summary=self.moving_summary_buffer,
new_lines=(
get_buffer_string(
pruned_memory,
human_prefix=self.human_prefix,
ai_prefix=self.ai_prefix,
)
),
self.moving_summary_buffer = self.predict_new_summary(
pruned_memory, self.moving_summary_buffer
)
def clear(self) -> None: