From 16099505972e9f34e751b6ffbed5b7837a87e515 Mon Sep 17 00:00:00 2001
From: Harrison Chase
Date: Thu, 13 Apr 2023 10:03:43 -0700
Subject: [PATCH] Harrison/retriever memory (#2804)

Co-authored-by: vowelparrot <130414180+vowelparrot@users.noreply.github.com>
---
 .../types/vectorstore_retriever_memory.ipynb | 327 ++++++++++++++++++
 langchain/memory/__init__.py                 |   2 +
 langchain/memory/vectorstore.py              |  72 ++++
 langchain/vectorstores/base.py               |  10 +
 4 files changed, 411 insertions(+)
 create mode 100644 docs/modules/memory/types/vectorstore_retriever_memory.ipynb
 create mode 100644 langchain/memory/vectorstore.py

diff --git a/docs/modules/memory/types/vectorstore_retriever_memory.ipynb b/docs/modules/memory/types/vectorstore_retriever_memory.ipynb
new file mode 100644
index 00000000..e44ac4d3
--- /dev/null
+++ b/docs/modules/memory/types/vectorstore_retriever_memory.ipynb
@@ -0,0 +1,327 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "ff4be5f3",
+   "metadata": {},
+   "source": [
+    "# VectorStore-Backed Memory\n",
+    "\n",
+    "`VectorStoreRetrieverMemory` stores interactions in a VectorDB and queries the top-K most \"salient\" interactions every time it is called.\n",
+    "\n",
+    "This differs from most of the other Memory classes in that it doesn't explicitly track the order of interactions."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "da3384db",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "from datetime import datetime\n",
+    "from langchain.embeddings.openai import OpenAIEmbeddings\n",
+    "from langchain.llms import OpenAI\n",
+    "from langchain.memory import VectorStoreRetrieverMemory\n",
+    "from langchain.chains import ConversationChain"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "c2e7abdf",
+   "metadata": {},
+   "source": [
+    "### Initialize your VectorStore\n",
+    "\n",
+    "Depending on the store you choose, this step may look different. Consult the relevant VectorStore documentation for more details.\n",
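+    "\n",
+    "For illustration only (not part of this patch), a Chroma-backed store could be swapped in for the FAISS store used below; this sketch assumes the `chromadb` package is installed:\n",
+    "\n",
+    "```python\n",
+    "from langchain.embeddings.openai import OpenAIEmbeddings\n",
+    "from langchain.vectorstores import Chroma\n",
+    "\n",
+    "# In-memory Chroma collection; any VectorStore exposing as_retriever() works\n",
+    "vectorstore = Chroma(embedding_function=OpenAIEmbeddings())\n",
+    "```"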
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "eef56f65",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "import faiss\n",
+    "\n",
+    "from langchain.docstore import InMemoryDocstore\n",
+    "from langchain.vectorstores import FAISS\n",
+    "\n",
+    "\n",
+    "embedding_size = 1536 # Dimensions of the OpenAIEmbeddings\n",
+    "index = faiss.IndexFlatL2(embedding_size)\n",
+    "embedding_fn = OpenAIEmbeddings().embed_query\n",
+    "vectorstore = FAISS(embedding_fn, index, InMemoryDocstore({}), {})"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "8f4bdf92",
+   "metadata": {},
+   "source": [
+    "### Create your VectorStoreRetrieverMemory\n",
+    "\n",
+    "The memory object is instantiated from any vector store retriever."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "e00d4938",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "# In actual usage, you would set `k` to be a higher value, but we use k=1 to show that\n",
+    "# the vector lookup still returns the semantically relevant information\n",
+    "retriever = vectorstore.as_retriever(search_kwargs=dict(k=1))\n",
+    "memory = VectorStoreRetrieverMemory(retriever=retriever)\n",
+    "\n",
+    "# When added to an agent, the memory object can save pertinent information from conversations or used tools\n",
+    "memory.save_context({\"input\": \"check the latest scores of the Warriors game\"}, {\"output\": \"the Warriors are up against the Astros 88 to 84\"})\n",
+    "memory.save_context({\"input\": \"I need help doing my taxes - what's the standard deduction this year?\"}, {\"output\": \"...\"})\n",
+    "memory.save_context({\"input\": \"What's the time?\"}, {\"output\": f\"It's {datetime.now()}\"})"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "2fe28a28",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "input: I need help doing my taxes - what's the standard deduction this year?\n",
+      "output: ...\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Notice the first result returned is the memory pertaining to tax help, which the retriever deems more semantically relevant\n",
+    "# to a 1099 than the other documents, despite them both containing numbers.\n",
+    "print(memory.load_memory_variables({\"prompt\": \"What's a 1099?\"})[\"history\"])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "a6d2569f",
+   "metadata": {},
+   "source": [
+    "## Using in a chain\n",
+    "Let's walk through an example, again setting `verbose=True` so we can see the prompt."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "id": "ebd68c10",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "\n",
+      "\u001b[1m> Entering new ConversationChain chain...\u001b[0m\n",
+      "Prompt after formatting:\n",
+      "\u001b[32;1m\u001b[1;3mThe following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know.\n",
+      "\n",
+      "Current conversation:\n",
+      "input: I need help doing my taxes - what's the standard deduction this year?\n",
+      "output: ...\n",
+      "Human: Hi, my name is Perry, what's up?\n",
+      "AI:\u001b[0m\n",
+      "\n",
+      "\u001b[1m> Finished chain.\u001b[0m\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "\" Hi Perry, my name is AI. I'm doing great, how about you? I understand you need help with your taxes. What specifically do you need help with?\""
+      ]
+     },
+     "execution_count": 5,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "llm = OpenAI(temperature=0) # Can be any valid LLM\n",
+    "conversation_with_summary = ConversationChain(\n",
+    "    llm=llm, \n",
+    "    # We use the vector-store-backed memory defined above.\n",
+    "    memory=memory,\n",
+    "    verbose=True\n",
+    ")\n",
+    "conversation_with_summary.predict(input=\"Hi, my name is Perry, what's up?\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "id": "86207a61",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "\n",
+      "\u001b[1m> Entering new ConversationChain chain...\u001b[0m\n",
+      "Prompt after formatting:\n",
+      "\u001b[32;1m\u001b[1;3mThe following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know.\n",
+      "\n",
+      "Current conversation:\n",
+      "input: check the latest scores of the Warriors game\n",
+      "output: the Warriors are up against the Astros 88 to 84\n",
+      "Human: If the Cavaliers were to face off against the Warriors or the Astros, who would they most stand a chance to beat?\n",
+      "AI:\u001b[0m\n",
+      "\n",
+      "\u001b[1m> Finished chain.\u001b[0m\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "\" It's hard to say without knowing the current form of the teams. However, based on the current scores, it looks like the Cavaliers would have a better chance of beating the Astros than the Warriors.\""
+      ]
+     },
+     "execution_count": 6,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "# Here, the basketball-related content is surfaced\n",
+    "conversation_with_summary.predict(input=\"If the Cavaliers were to face off against the Warriors or the Astros, who would they most stand a chance to beat?\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "id": "8c669db1",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "\n",
+      "\u001b[1m> Entering new ConversationChain chain...\u001b[0m\n",
+      "Prompt after formatting:\n",
+      "\u001b[32;1m\u001b[1;3mThe following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know.\n",
+      "\n",
+      "Current conversation:\n",
+      "input: What's the time?\n",
+      "output: It's 2023-04-13 09:18:55.623736\n",
+      "Human: What day is it tomorrow?\n",
+      "AI:\u001b[0m\n",
+      "\n",
+      "\u001b[1m> Finished chain.\u001b[0m\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "' Tomorrow is 2023-04-14.'"
+      ]
+     },
+     "execution_count": 7,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "# Even though the language model is stateless, since relevant memory is fetched, it can \"reason\" about the time.\n",
+    "# Timestamping memories and data is useful in general to let the agent determine temporal relevance\n",
+    "conversation_with_summary.predict(input=\"What day is it tomorrow?\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "id": "8c09a239",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "\n",
+      "\u001b[1m> Entering new ConversationChain chain...\u001b[0m\n",
+      "Prompt after formatting:\n",
+      "\u001b[32;1m\u001b[1;3mThe following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know.\n",
+      "\n",
+      "Current conversation:\n",
+      "input: Hi, my name is Perry, what's up?\n",
+      "response: Hi Perry, my name is AI. I'm doing great, how about you? I understand you need help with your taxes. What specifically do you need help with?\n",
+      "Human: What's your name?\n",
+      "AI:\u001b[0m\n",
+      "\n",
+      "\u001b[1m> Finished chain.\u001b[0m\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "\" My name is AI. It's nice to meet you, Perry.\""
+      ]
+     },
+     "execution_count": 8,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "# The memories from the conversation are automatically stored.\n",
+    "# Since this query best matches the introduction chat above,\n",
+    "# the agent is able to 'remember' the user's name.\n",
+    "conversation_with_summary.predict(input=\"What's your name?\")"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.11.2"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/langchain/memory/__init__.py b/langchain/memory/__init__.py
index b085a336..643eecf2 100644
--- a/langchain/memory/__init__.py
+++ b/langchain/memory/__init__.py
@@ -19,6 +19,7 @@ from langchain.memory.simple import SimpleMemory
 from langchain.memory.summary import ConversationSummaryMemory
 from langchain.memory.summary_buffer import ConversationSummaryBufferMemory
 from langchain.memory.token_buffer import ConversationTokenBufferMemory
+from langchain.memory.vectorstore import VectorStoreRetrieverMemory
 
 __all__ = [
     "CombinedMemory",
@@ -38,4 +39,5 @@ __all__ = [
     "RedisChatMessageHistory",
     "DynamoDBChatMessageHistory",
     "PostgresChatMessageHistory",
+    "VectorStoreRetrieverMemory",
 ]
diff --git a/langchain/memory/vectorstore.py b/langchain/memory/vectorstore.py
new file mode 100644
index 00000000..d5c40f26
--- /dev/null
+++ b/langchain/memory/vectorstore.py
@@ -0,0 +1,72 @@
+"""Class for a VectorStore-backed memory object."""
+
+from typing import Any, Dict, List, Optional, Union
+
+from pydantic import Field
+
+from langchain.memory.chat_memory import BaseMemory
+from langchain.memory.utils import get_prompt_input_key
+from langchain.schema import Document
+from langchain.vectorstores.base import VectorStoreRetriever
+
+
+class VectorStoreRetrieverMemory(BaseMemory):
+    """Class for a VectorStore-backed memory object."""
+
+    retriever: VectorStoreRetriever = Field(exclude=True)
+    """VectorStoreRetriever object to connect to."""
+
+    memory_key: str = "history"  #: :meta private:
+    """Key name to locate the memories in the result of load_memory_variables."""
+
+    input_key: Optional[str] = None
+    """Key name to index the inputs to load_memory_variables."""
+
+    return_docs: bool = False
+    """Whether or not to return the result of querying the database directly."""
+
+    @property
+    def memory_variables(self) -> List[str]:
+        """The list of keys emitted from the load_memory_variables method."""
+        return [self.memory_key]
+
+    def _get_prompt_input_key(self, inputs: Dict[str, Any]) -> str:
+        """Get the input key for the prompt."""
+        if self.input_key is None:
+            return get_prompt_input_key(inputs, self.memory_variables)
+        return self.input_key
+
+    def load_memory_variables(
+        self, inputs: Dict[str, Any]
+    ) -> Dict[str, Union[List[Document], str]]:
+        """Return history buffer."""
+        input_key = self._get_prompt_input_key(inputs)
+        query = inputs[input_key]
+        docs = self.retriever.get_relevant_documents(query)
+        result: Union[List[Document], str]
+        if not self.return_docs:
+            result = "\n".join([doc.page_content for doc in docs])
+        else:
+            result = docs
+        return {self.memory_key: result}
+
+    def _form_documents(
+        self, inputs: Dict[str, Any], outputs: Dict[str, str]
+    ) -> List[Document]:
+        """Format context from this conversation to buffer."""
+        # Each document should only include the current turn, not the chat history
+        filtered_inputs = {k: v for k, v in inputs.items() if k != self.memory_key}
+        texts = [
+            f"{k}: {v}"
+            for k, v in list(filtered_inputs.items()) + list(outputs.items())
+        ]
+        page_content = "\n".join(texts)
+        return [Document(page_content=page_content)]
+
+    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
+        """Save context from this conversation to buffer."""
+        documents = self._form_documents(inputs, outputs)
+        self.retriever.add_documents(documents)
+
+    def clear(self) -> None:
+        """Nothing to clear."""
diff --git a/langchain/vectorstores/base.py b/langchain/vectorstores/base.py
index b1c9bf64..f995d3dd 100644
--- a/langchain/vectorstores/base.py
+++ b/langchain/vectorstores/base.py
@@ -262,3 +262,13 @@ class VectorStoreRetriever(BaseRetriever, BaseModel):
         else:
             raise ValueError(f"search_type of {self.search_type} not allowed.")
         return docs
+
+    def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
+        """Add documents to vectorstore."""
+        return self.vectorstore.add_documents(documents, **kwargs)
+
+    async def aadd_documents(
+        self, documents: List[Document], **kwargs: Any
+    ) -> List[str]:
+        """Add documents to vectorstore."""
+        return await self.vectorstore.aadd_documents(documents, **kwargs)
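
Reviewer note: below is a minimal end-to-end sketch (not part of the diff) of how the pieces added in this patch fit together. save_context writes each turn through the new VectorStoreRetriever.add_documents hook, and load_memory_variables embeds the query and returns the best-matching turn. FAISS and OpenAI embeddings simply mirror the notebook above; any VectorStore would do, and the sample strings are invented.

    import faiss

    from langchain.docstore import InMemoryDocstore
    from langchain.embeddings.openai import OpenAIEmbeddings
    from langchain.memory import VectorStoreRetrieverMemory
    from langchain.vectorstores import FAISS

    # Empty FAISS index; 1536 matches the OpenAI embedding dimension
    embedding_fn = OpenAIEmbeddings().embed_query
    vectorstore = FAISS(embedding_fn, faiss.IndexFlatL2(1536), InMemoryDocstore({}), {})
    retriever = vectorstore.as_retriever(search_kwargs=dict(k=1))
    memory = VectorStoreRetrieverMemory(retriever=retriever)

    # Each saved turn becomes one Document in the store
    memory.save_context(
        {"input": "check the latest scores of the Warriors game"},
        {"output": "the Warriors are up against the Astros 88 to 84"},
    )
    # Retrieval is by semantic similarity, not recency
    print(memory.load_memory_variables({"input": "how did the Warriors do?"})["history"])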