forked from Archives/langchain
Harrison/retriever memory (#2804)
Co-authored-by: vowelparrot <130414180+vowelparrot@users.noreply.github.com>
This commit is contained in:
parent
7688bf9182
commit
1609950597
327
docs/modules/memory/types/vectorstore_retriever_memory.ipynb
Normal file
327
docs/modules/memory/types/vectorstore_retriever_memory.ipynb
Normal file
@ -0,0 +1,327 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "ff4be5f3",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# VectorStore-Backed Memory\n",
|
||||
"\n",
|
||||
"`VectorStoreRetrieverMemory` stores interactions in a VectorDB and queries the top-K most \"salient\" interactions every type it is called.\n",
|
||||
"\n",
|
||||
"This differs from most of the other Memory classes in that "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "da3384db",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from datetime import datetime\n",
|
||||
"from langchain.embeddings.openai import OpenAIEmbeddings\n",
|
||||
"from langchain.llms import OpenAI\n",
|
||||
"from langchain.memory import VectorStoreRetrieverMemory\n",
|
||||
"from langchain.chains import ConversationChain"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "c2e7abdf",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Initialize your VectorStore\n",
|
||||
"\n",
|
||||
"Depending on the store you choose, this step may look different. Consult the relevant VectorStore documentation for more details."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "eef56f65",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import faiss\n",
|
||||
"\n",
|
||||
"from langchain.docstore import InMemoryDocstore\n",
|
||||
"from langchain.vectorstores import FAISS\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"embedding_size = 1536 # Dimensions of the OpenAIEmbeddings\n",
|
||||
"index = faiss.IndexFlatL2(embedding_size)\n",
|
||||
"embedding_fn = OpenAIEmbeddings().embed_query\n",
|
||||
"vectorstore = FAISS(embedding_fn, index, InMemoryDocstore({}), {})"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "8f4bdf92",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Create your the VectorStoreRetrieverMemory\n",
|
||||
"\n",
|
||||
"The memory object is instantiated from "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "e00d4938",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# In actual usage, you would set `k` to be a higher value, but we use k=1 to show that\n",
|
||||
"# the vector lookup still returns the semantically relevant information\n",
|
||||
"retriever = vectorstore.as_retriever(search_kwargs=dict(k=1))\n",
|
||||
"memory = VectorStoreRetrieverMemory(retriever=retriever)\n",
|
||||
"\n",
|
||||
"# When added to an agent, the memory object can save pertinent information from conversations or used tools\n",
|
||||
"memory.save_context({\"input\": \"check the latest scores of the Warriors game\"}, {\"output\": \"the Warriors are up against the Astros 88 to 84\"})\n",
|
||||
"memory.save_context({\"input\": \"I need help doing my taxes - what's the standard deduction this year?\"}, {\"output\": \"...\"})\n",
|
||||
"memory.save_context({\"input\": \"What's the the time?\"}, {\"output\": f\"It's {datetime.now()}\"}) # "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "2fe28a28",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"input: I need help doing my taxes - what's the standard deduction this year?\n",
|
||||
"output: ...\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Notice the first result returned is the memory pertaining to tax help, which the language model deems more semantically relevant\n",
|
||||
"# to a 1099 than the other documents, despite them both containing numbers.\n",
|
||||
"print(memory.load_memory_variables({\"prompt\": \"What's a 1099?\"})[\"history\"])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "a6d2569f",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Using in a chain\n",
|
||||
"Let's walk through an example, again setting `verbose=True` so we can see the prompt."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "ebd68c10",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new ConversationChain chain...\u001b[0m\n",
|
||||
"Prompt after formatting:\n",
|
||||
"\u001b[32;1m\u001b[1;3mThe following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know.\n",
|
||||
"\n",
|
||||
"Current conversation:\n",
|
||||
"input: I need help doing my taxes - what's the standard deduction this year?\n",
|
||||
"output: ...\n",
|
||||
"Human: Hi, my name is Perry, what's up?\n",
|
||||
"AI:\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"\" Hi Perry, my name is AI. I'm doing great, how about you? I understand you need help with your taxes. What specifically do you need help with?\""
|
||||
]
|
||||
},
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"llm = OpenAI(temperature=0) # Can be any valid LLM\n",
|
||||
"conversation_with_summary = ConversationChain(\n",
|
||||
" llm=llm, \n",
|
||||
" # We set a very low max_token_limit for the purposes of testing.\n",
|
||||
" memory=memory,\n",
|
||||
" verbose=True\n",
|
||||
")\n",
|
||||
"conversation_with_summary.predict(input=\"Hi, my name is Perry, what's up?\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "86207a61",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new ConversationChain chain...\u001b[0m\n",
|
||||
"Prompt after formatting:\n",
|
||||
"\u001b[32;1m\u001b[1;3mThe following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know.\n",
|
||||
"\n",
|
||||
"Current conversation:\n",
|
||||
"input: check the latest scores of the Warriors game\n",
|
||||
"output: the Warriors are up against the Astros 88 to 84\n",
|
||||
"Human: If the Cavaliers were to face off against the Warriers or the Astros, who would they most stand a chance to beat?\n",
|
||||
"AI:\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"\" It's hard to say without knowing the current form of the teams. However, based on the current scores, it looks like the Cavaliers would have a better chance of beating the Astros than the Warriors.\""
|
||||
]
|
||||
},
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Here, the basketball related content is surfaced\n",
|
||||
"conversation_with_summary.predict(input=\"If the Cavaliers were to face off against the Warriers or the Astros, who would they most stand a chance to beat?\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "8c669db1",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new ConversationChain chain...\u001b[0m\n",
|
||||
"Prompt after formatting:\n",
|
||||
"\u001b[32;1m\u001b[1;3mThe following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know.\n",
|
||||
"\n",
|
||||
"Current conversation:\n",
|
||||
"input: What's the the time?\n",
|
||||
"output: It's 2023-04-13 09:18:55.623736\n",
|
||||
"Human: What day is it tomorrow?\n",
|
||||
"AI:\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"' Tomorrow is 2023-04-14.'"
|
||||
]
|
||||
},
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Even though the language model is stateless, since relavent memory is fetched, it can \"reason\" about the time.\n",
|
||||
"# Timestamping memories and data is useful in general to let the agent determine temporal relevance\n",
|
||||
"conversation_with_summary.predict(input=\"What day is it tomorrow?\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "8c09a239",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new ConversationChain chain...\u001b[0m\n",
|
||||
"Prompt after formatting:\n",
|
||||
"\u001b[32;1m\u001b[1;3mThe following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know.\n",
|
||||
"\n",
|
||||
"Current conversation:\n",
|
||||
"input: Hi, my name is Perry, what's up?\n",
|
||||
"response: Hi Perry, my name is AI. I'm doing great, how about you? I understand you need help with your taxes. What specifically do you need help with?\n",
|
||||
"Human: What's your name?\n",
|
||||
"AI:\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"\" My name is AI. It's nice to meet you, Perry.\""
|
||||
]
|
||||
},
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# The memories from the conversation are automatically stored,\n",
|
||||
"# since this query best matches the introduction chat above,\n",
|
||||
"# the agent is able to 'remember' the user's name.\n",
|
||||
"conversation_with_summary.predict(input=\"What's your name?\")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.2"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
@ -19,6 +19,7 @@ from langchain.memory.simple import SimpleMemory
|
||||
from langchain.memory.summary import ConversationSummaryMemory
|
||||
from langchain.memory.summary_buffer import ConversationSummaryBufferMemory
|
||||
from langchain.memory.token_buffer import ConversationTokenBufferMemory
|
||||
from langchain.memory.vectorstore import VectorStoreRetrieverMemory
|
||||
|
||||
__all__ = [
|
||||
"CombinedMemory",
|
||||
@ -38,4 +39,5 @@ __all__ = [
|
||||
"RedisChatMessageHistory",
|
||||
"DynamoDBChatMessageHistory",
|
||||
"PostgresChatMessageHistory",
|
||||
"VectorStoreRetrieverMemory",
|
||||
]
|
||||
|
72
langchain/memory/vectorstore.py
Normal file
72
langchain/memory/vectorstore.py
Normal file
@ -0,0 +1,72 @@
|
||||
"""Class for a VectorStore-backed memory object."""
|
||||
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from langchain.memory.chat_memory import BaseMemory
|
||||
from langchain.memory.utils import get_prompt_input_key
|
||||
from langchain.schema import Document
|
||||
from langchain.vectorstores.base import VectorStoreRetriever
|
||||
|
||||
|
||||
class VectorStoreRetrieverMemory(BaseMemory):
|
||||
"""Class for a VectorStore-backed memory object."""
|
||||
|
||||
retriever: VectorStoreRetriever = Field(exclude=True)
|
||||
"""VectorStoreRetriever object to connect to."""
|
||||
|
||||
memory_key: str = "history" #: :meta private:
|
||||
"""Key name to locate the memories in the result of load_memory_variables."""
|
||||
|
||||
input_key: Optional[str] = None
|
||||
"""Key name to index the inputs to load_memory_variables."""
|
||||
|
||||
return_docs: bool = False
|
||||
"""Whether or not to return the result of querying the database directly."""
|
||||
|
||||
@property
|
||||
def memory_variables(self) -> List[str]:
|
||||
"""The list of keys emitted from the load_memory_variables method."""
|
||||
return [self.memory_key]
|
||||
|
||||
def _get_prompt_input_key(self, inputs: Dict[str, Any]) -> str:
|
||||
"""Get the input key for the prompt."""
|
||||
if self.input_key is None:
|
||||
return get_prompt_input_key(inputs, self.memory_variables)
|
||||
return self.input_key
|
||||
|
||||
def load_memory_variables(
|
||||
self, inputs: Dict[str, Any]
|
||||
) -> Dict[str, Union[List[Document], str]]:
|
||||
"""Return history buffer."""
|
||||
input_key = self._get_prompt_input_key(inputs)
|
||||
query = inputs[input_key]
|
||||
docs = self.retriever.get_relevant_documents(query)
|
||||
result: Union[List[Document], str]
|
||||
if not self.return_docs:
|
||||
result = "\n".join([doc.page_content for doc in docs])
|
||||
else:
|
||||
result = docs
|
||||
return {self.memory_key: result}
|
||||
|
||||
def _form_documents(
|
||||
self, inputs: Dict[str, Any], outputs: Dict[str, str]
|
||||
) -> List[Document]:
|
||||
"""Format context from this conversation to buffer."""
|
||||
# Each document should only include the current turn, not the chat history
|
||||
filtered_inputs = {k: v for k, v in inputs.items() if k != self.memory_key}
|
||||
texts = [
|
||||
f"{k}: {v}"
|
||||
for k, v in list(filtered_inputs.items()) + list(outputs.items())
|
||||
]
|
||||
page_content = "\n".join(texts)
|
||||
return [Document(page_content=page_content)]
|
||||
|
||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||
"""Save context from this conversation to buffer."""
|
||||
documents = self._form_documents(inputs, outputs)
|
||||
self.retriever.add_documents(documents)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Nothing to clear."""
|
@ -262,3 +262,13 @@ class VectorStoreRetriever(BaseRetriever, BaseModel):
|
||||
else:
|
||||
raise ValueError(f"search_type of {self.search_type} not allowed.")
|
||||
return docs
|
||||
|
||||
def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
|
||||
"""Add documents to vectorstore."""
|
||||
return self.vectorstore.add_documents(documents, **kwargs)
|
||||
|
||||
async def aadd_documents(
|
||||
self, documents: List[Document], **kwargs: Any
|
||||
) -> List[str]:
|
||||
"""Add documents to vectorstore."""
|
||||
return await self.vectorstore.aadd_documents(documents, **kwargs)
|
||||
|
Loading…
Reference in New Issue
Block a user