You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
langchain/langchain/memory/vectorstore.py

73 lines
2.7 KiB
Python

"""Class for a VectorStore-backed memory object."""
from typing import Any, Dict, List, Optional, Union
from pydantic import Field
from langchain.memory.chat_memory import BaseMemory
from langchain.memory.utils import get_prompt_input_key
from langchain.schema import Document
from langchain.vectorstores.base import VectorStoreRetriever
class VectorStoreRetrieverMemory(BaseMemory):
    """Memory object that stores and looks up conversation turns via a retriever."""

    retriever: VectorStoreRetriever = Field(exclude=True)
    """Retriever used both to query for memories and to store new documents."""

    memory_key: str = "history"  #: :meta private:
    """Key under which load_memory_variables places the retrieved memories."""

    input_key: Optional[str] = None
    """Key of the input to use as the query; resolved from the inputs when None."""

    return_docs: bool = False
    """If True, return the retrieved documents themselves instead of joined text."""

    @property
    def memory_variables(self) -> List[str]:
        """Keys that load_memory_variables emits (only ``memory_key``)."""
        emitted_keys = [self.memory_key]
        return emitted_keys

    def _get_prompt_input_key(self, inputs: Dict[str, Any]) -> str:
        """Return the configured input key, or derive one from ``inputs``."""
        if self.input_key is not None:
            return self.input_key
        # No explicit key configured: defer to the shared helper, passing the
        # memory variables along with the inputs.
        return get_prompt_input_key(inputs, self.memory_variables)

    def load_memory_variables(
        self, inputs: Dict[str, Any]
    ) -> Dict[str, Union[List[Document], str]]:
        """Query the retriever with the prompt input and return the result.

        The value stored under ``memory_key`` is the list of retrieved
        documents when ``return_docs`` is set, otherwise their page contents
        joined with newlines.
        """
        prompt_key = self._get_prompt_input_key(inputs)
        search_text = inputs[prompt_key]
        retrieved = self.retriever.get_relevant_documents(search_text)
        memories: Union[List[Document], str] = (
            retrieved
            if self.return_docs
            else "\n".join(doc.page_content for doc in retrieved)
        )
        return {self.memory_key: memories}

    def _form_documents(
        self, inputs: Dict[str, Any], outputs: Dict[str, str]
    ) -> List[Document]:
        """Build a single document describing the current conversation turn."""
        # Drop the memory key so the stored document holds only this turn,
        # not the previously loaded history.
        current_turn = {
            name: value for name, value in inputs.items() if name != self.memory_key
        }
        # Inputs come first, then outputs, one "key: value" line each.
        lines = [f"{k}: {v}" for k, v in [*current_turn.items(), *outputs.items()]]
        return [Document(page_content="\n".join(lines))]

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
        """Format this turn as a document and add it through the retriever."""
        self.retriever.add_documents(self._form_documents(inputs, outputs))

    def clear(self) -> None:
        """Do nothing; this memory has nothing to clear."""
        return None