add output key for memory (#443)

This allows chains that return multiple values to use memory.
harrison/callback-updates
Harrison Chase 1 year ago committed by GitHub
parent 5208bb8c36
commit 55007e71be
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -167,7 +167,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
"version": "3.10.8"
}
},
"nbformat": 4,

@ -0,0 +1,174 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "e42733c5",
"metadata": {},
"source": [
"# Adding Memory to a Multi-Input Chain\n",
"\n",
"Most memory objects assume a single output. In this notebook, we go over how to add memory to a chain that has multiple outputs. As an example of such a chain, we will add memory to a question/answering chain. This chain takes as inputs both related documents and a user question."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "978ba52b",
"metadata": {},
"outputs": [],
"source": [
"from langchain.embeddings.openai import OpenAIEmbeddings\n",
"from langchain.embeddings.cohere import CohereEmbeddings\n",
"from langchain.text_splitter import CharacterTextSplitter\n",
"from langchain.vectorstores.elastic_vector_search import ElasticVectorSearch\n",
"from langchain.vectorstores.faiss import FAISS\n",
"from langchain.docstore.document import Document"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "2ee8628b",
"metadata": {},
"outputs": [],
"source": [
"with open('../state_of_the_union.txt') as f:\n",
" state_of_the_union = f.read()\n",
"text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n",
"texts = text_splitter.split_text(state_of_the_union)\n",
"\n",
"embeddings = OpenAIEmbeddings()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "aa70c847",
"metadata": {},
"outputs": [],
"source": [
"docsearch = FAISS.from_texts(texts, embeddings, metadatas=[{\"source\": i} for i in range(len(texts))])"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "ea4f7d82",
"metadata": {},
"outputs": [],
"source": [
"query = \"What did the president say about Justice Breyer\"\n",
"docs = docsearch.similarity_search(query)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "d3dc4ed5",
"metadata": {},
"outputs": [],
"source": [
"from langchain.chains.question_answering import load_qa_chain\n",
"from langchain.llms import OpenAI\n",
"from langchain.prompts import PromptTemplate\n",
"from langchain.chains.conversation.memory import ConversationBufferMemory"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "9a530742",
"metadata": {},
"outputs": [],
"source": [
"template = \"\"\"You are a chatbot having a conversation with a human.\n",
"\n",
"Given the following extracted parts of a long document and a question, create a final answer.\n",
"\n",
"{context}\n",
"\n",
"{chat_history}\n",
"Human: {human_input}\n",
"Chatbot:\"\"\"\n",
"\n",
"prompt = PromptTemplate(\n",
" input_variables=[\"chat_history\", \"human_input\", \"context\"], \n",
" template=template\n",
")\n",
"memory = ConversationBufferMemory(memory_key=\"chat_history\", input_key=\"human_input\")\n",
"chain = load_qa_chain(OpenAI(temperature=0), chain_type=\"stuff\", memory=memory, prompt=prompt)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "9bb8a8b4",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'output_text': \" President Biden honored Justice Stephen Breyer, an Army veteran, Constitutional scholar, and retiring Justice of the United States Supreme Court. He thanked Justice Breyer for his service and said that one of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court. He then announced his nomination of Circuit Court of Appeals Judge Ketanji Brown Jackson to continue Justice Breyer's legacy of excellence.\"}"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"query = \"What did the president say about Justice Breyer\"\n",
"chain({\"input_documents\": docs, \"human_input\": query}, return_only_outputs=True)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "82593148",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Human: What did the president say about Justice Breyer\n",
"AI: President Biden honored Justice Stephen Breyer, an Army veteran, Constitutional scholar, and retiring Justice of the United States Supreme Court. He thanked Justice Breyer for his service and said that one of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court. He then announced his nomination of Circuit Court of Appeals Judge Ketanji Brown Jackson to continue Justice Breyer's legacy of excellence.\n"
]
}
],
"source": [
"print(chain.memory.buffer)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f262b2fb",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

@ -1,5 +1,5 @@
"""Memory modules for conversation prompts."""
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field, root_validator
@ -25,6 +25,8 @@ class ConversationBufferMemory(Memory, BaseModel):
ai_prefix: str = "AI"
"""Prefix to use for AI generated responses."""
buffer: str = ""
output_key: Optional[str] = None
input_key: Optional[str] = None
memory_key: str = "history" #: :meta private:
@property
@ -41,11 +43,18 @@ class ConversationBufferMemory(Memory, BaseModel):
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
    """Save the latest human/AI exchange to the string buffer.

    Args:
        inputs: Mapping of chain input keys to values. The human turn is
            read from ``self.input_key`` when set; otherwise it is inferred
            via ``_get_prompt_input_key``.
        outputs: Mapping of chain output keys to values. The AI turn is
            read from ``self.output_key`` when set; otherwise ``outputs``
            must contain exactly one key.

    Raises:
        ValueError: If ``self.output_key`` is unset and ``outputs`` does
            not contain exactly one key.
    """
    if self.input_key is None:
        prompt_input_key = _get_prompt_input_key(inputs, self.memory_variables)
    else:
        prompt_input_key = self.input_key
    if self.output_key is None:
        # Only enforce the single-output constraint when no explicit
        # output key was configured.
        if len(outputs) != 1:
            raise ValueError(f"One output key expected, got {outputs.keys()}")
        output_key = list(outputs.keys())[0]
    else:
        output_key = self.output_key
    human = "Human: " + inputs[prompt_input_key]
    ai = f"{self.ai_prefix}: " + outputs[output_key]
    self.buffer += "\n" + "\n".join([human, ai])
def clear(self) -> None:
@ -60,6 +69,8 @@ class ConversationBufferWindowMemory(Memory, BaseModel):
"""Prefix to use for AI generated responses."""
buffer: List[str] = Field(default_factory=list)
memory_key: str = "history" #: :meta private:
output_key: Optional[str] = None
input_key: Optional[str] = None
k: int = 5
@property
@ -76,11 +87,18 @@ class ConversationBufferWindowMemory(Memory, BaseModel):
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
    """Save the latest human/AI exchange as one entry in the list buffer.

    Args:
        inputs: Mapping of chain input keys to values. The human turn is
            read from ``self.input_key`` when set; otherwise it is inferred
            via ``_get_prompt_input_key``.
        outputs: Mapping of chain output keys to values. The AI turn is
            read from ``self.output_key`` when set; otherwise ``outputs``
            must contain exactly one key.

    Raises:
        ValueError: If ``self.output_key`` is unset and ``outputs`` does
            not contain exactly one key.
    """
    if self.input_key is None:
        prompt_input_key = _get_prompt_input_key(inputs, self.memory_variables)
    else:
        prompt_input_key = self.input_key
    if self.output_key is None:
        # Only enforce the single-output constraint when no explicit
        # output key was configured.
        if len(outputs) != 1:
            raise ValueError(f"One output key expected, got {outputs.keys()}")
        output_key = list(outputs.keys())[0]
    else:
        output_key = self.output_key
    human = "Human: " + inputs[prompt_input_key]
    ai = f"{self.ai_prefix}: " + outputs[output_key]
    # Each exchange is one list element; windowing over the last k
    # exchanges happens elsewhere in the class.
    self.buffer.append("\n".join([human, ai]))
def clear(self) -> None:
@ -101,6 +119,8 @@ class ConversationSummaryMemory(Memory, BaseModel):
llm: BaseLLM
prompt: BasePromptTemplate = SUMMARY_PROMPT
memory_key: str = "history" #: :meta private:
output_key: Optional[str] = None
input_key: Optional[str] = None
@property
def memory_variables(self) -> List[str]:
@ -128,11 +148,18 @@ class ConversationSummaryMemory(Memory, BaseModel):
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
    """Summarize the latest exchange into the running summary buffer.

    Args:
        inputs: Mapping of chain input keys to values. The human turn is
            read from ``self.input_key`` when set; otherwise it is inferred
            via ``_get_prompt_input_key``.
        outputs: Mapping of chain output keys to values. The AI turn is
            read from ``self.output_key`` when set; otherwise ``outputs``
            must contain exactly one key.

    Raises:
        ValueError: If ``self.output_key`` is unset and ``outputs`` does
            not contain exactly one key.
    """
    if self.input_key is None:
        prompt_input_key = _get_prompt_input_key(inputs, self.memory_variables)
    else:
        prompt_input_key = self.input_key
    if self.output_key is None:
        if len(outputs) != 1:
            raise ValueError(f"One output key expected, got {outputs.keys()}")
        output_key = list(outputs.keys())[0]
    else:
        output_key = self.output_key
    human = f"Human: {inputs[prompt_input_key]}"
    # BUG FIX: interpolate the selected output's VALUE, not the key name
    # itself (the previous line rendered e.g. "AI: output_text").
    ai = f"{self.ai_prefix}: {outputs[output_key]}"
    new_lines = "\n".join([human, ai])
    chain = LLMChain(llm=self.llm, prompt=self.prompt)
    self.buffer = chain.predict(summary=self.buffer, new_lines=new_lines)
@ -153,6 +180,8 @@ class ConversationSummaryBufferMemory(Memory, BaseModel):
memory_key: str = "history"
ai_prefix: str = "AI"
"""Prefix to use for AI generated responses."""
output_key: Optional[str] = None
input_key: Optional[str] = None
@property
def memory_variables(self) -> List[str]:
@ -187,11 +216,18 @@ class ConversationSummaryBufferMemory(Memory, BaseModel):
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
    """Append the latest exchange to the buffer (pruning follows below).

    Args:
        inputs: Mapping of chain input keys to values. The human turn is
            read from ``self.input_key`` when set; otherwise it is inferred
            via ``_get_prompt_input_key``.
        outputs: Mapping of chain output keys to values. The AI turn is
            read from ``self.output_key`` when set; otherwise ``outputs``
            must contain exactly one key.

    Raises:
        ValueError: If ``self.output_key`` is unset and ``outputs`` does
            not contain exactly one key.
    """
    if self.input_key is None:
        prompt_input_key = _get_prompt_input_key(inputs, self.memory_variables)
    else:
        prompt_input_key = self.input_key
    if self.output_key is None:
        if len(outputs) != 1:
            raise ValueError(f"One output key expected, got {outputs.keys()}")
        output_key = list(outputs.keys())[0]
    else:
        output_key = self.output_key
    human = f"Human: {inputs[prompt_input_key]}"
    # BUG FIX: interpolate the selected output's VALUE, not the key name
    # itself (the previous line rendered e.g. "AI: output_text").
    ai = f"{self.ai_prefix}: {outputs[output_key]}"
    new_lines = "\n".join([human, ai])
    self.buffer.append(new_lines)
# Prune buffer if it exceeds max token limit

Loading…
Cancel
Save