Generative Characters (#2859)

Add a time-weighted memory retriever and a notebook that approximates a
Generative Agent from https://arxiv.org/pdf/2304.03442.pdf


The "daily plan" components are removed for now since they are less
useful without a virtual world, but the memory is an interesting
component to build off.

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
vowelparrot 2023-04-16 21:41:00 -07:00 committed by GitHub
parent a9310a3e8b
commit 99c0382209
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 1771 additions and 3 deletions


@@ -0,0 +1,213 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"id": "a90b7557",
"metadata": {},
"source": [
"# Time Weighted VectorStore Retriever\n",
"\n",
"This retriever uses a combination of semantic similarity and recency.\n",
"\n",
"The algorithm for scoring them is:\n",
"\n",
"```\n",
"semantic_similarity + (1.0 - decay_rate) ** hours_passed\n",
"```\n",
"\n",
"Notably, hours_passed refers to the hours passed since the object in the retriever **was last accessed**, not since it was created. This means that frequently accessed objects remain \"fresh.\""
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "f22cc96b",
"metadata": {},
"outputs": [],
"source": [
"import faiss\n",
"\n",
"from datetime import datetime, timedelta\n",
"from langchain.docstore import InMemoryDocstore\n",
"from langchain.embeddings import OpenAIEmbeddings\n",
"from langchain.retrievers import TimeWeightedVectorStoreRetriever\n",
"from langchain.schema import Document\n",
"from langchain.vectorstores import FAISS\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "6af7ea6b",
"metadata": {},
"source": [
"## Low Decay Rate\n",
"\n",
"A low decay rate (in this, to be extreme, we will set close to 0) means memories will be \"remembered\" for longer. A decay rate of 0 means memories never be forgotten, making this retriever equivalent to the vector lookup."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "c10e7696",
"metadata": {},
"outputs": [],
"source": [
"# Define your embedding model\n",
"embeddings_model = OpenAIEmbeddings()\n",
"# Initialize the vectorstore as empty\n",
"embedding_size = 1536\n",
"index = faiss.IndexFlatL2(embedding_size)\n",
"vectorstore = FAISS(embeddings_model.embed_query, index, InMemoryDocstore({}), {})\n",
"retriever = TimeWeightedVectorStoreRetriever(vectorstore=vectorstore, decay_rate=.0000000000000000000000001, k=1) "
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "86dbadb9",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['129ba56b-7e7f-480b-83b3-8138a7f5db4a']"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"yesterday = datetime.now() - timedelta(days=1)\n",
"retriever.add_documents([Document(page_content=\"hello world\", metadata={\"last_accessed_at\": yesterday})])\n",
"retriever.add_documents([Document(page_content=\"hello foo\")])"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "a580be32",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[Document(page_content='hello foo', metadata={'last_accessed_at': datetime.datetime(2023, 4, 16, 15, 46, 43, 860748), 'created_at': datetime.datetime(2023, 4, 16, 15, 46, 14, 469670), 'buffer_idx': 1})]"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# \"Hello World\" is returned first because it is most salient, and the decay rate is close to 0., meaning it's still recent enough\n",
"retriever.get_relevant_documents(\"hello world\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "ca056896",
"metadata": {},
"source": [
"## High Decay Rate\n",
"\n",
"With a high decay factor (e.g., several 9's), the recency score quickly goes to 0! If you set this all the way to 1, recency is 0 for all objects, once again making this equivalent to a vector lookup.\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "dc37669b",
"metadata": {},
"outputs": [],
"source": [
"# Define your embedding model\n",
"embeddings_model = OpenAIEmbeddings()\n",
"# Initialize the vectorstore as empty\n",
"embedding_size = 1536\n",
"index = faiss.IndexFlatL2(embedding_size)\n",
"vectorstore = FAISS(embeddings_model.embed_query, index, InMemoryDocstore({}), {})\n",
"retriever = TimeWeightedVectorStoreRetriever(vectorstore=vectorstore, decay_rate=.999, k=1) "
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "fa284384",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['8fff7ef8-3a30-40f3-b42e-b8d5c7850863']"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"yesterday = datetime.now() - timedelta(days=1)\n",
"retriever.add_documents([Document(page_content=\"hello world\", metadata={\"last_accessed_at\": yesterday})])\n",
"retriever.add_documents([Document(page_content=\"hello foo\")])"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "7558f94d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[Document(page_content='hello foo', metadata={'last_accessed_at': datetime.datetime(2023, 4, 16, 15, 46, 17, 646927), 'created_at': datetime.datetime(2023, 4, 16, 15, 46, 14, 469670), 'buffer_idx': 1})]"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# \"Hello Foo\" is returned first because \"hello world\" is mostly forgotten\n",
"retriever.get_relevant_documents(\"hello world\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bf6d8c90",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.2"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
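
The two decay regimes in the notebook can be checked by hand. The sketch below is a standalone re-implementation of the scoring rule in plain Python, not the library code. The names `combined_score` and `ranking` are hypothetical, and the similarity values (1.0 for the exact match, 0.8 for "hello foo") are assumed for illustration rather than measured from an embedding model.

```python
def combined_score(similarity: float, hours_passed: float, decay_rate: float) -> float:
    """Scoring rule from the notebook: similarity plus an exponential recency term."""
    return similarity + (1.0 - decay_rate) ** hours_passed


# Assumed similarities to the query "hello world" (illustrative values only).
SIM_WORLD, SIM_FOO = 1.0, 0.8


def ranking(decay_rate: float) -> list:
    """Rank the two notebook documents for a given decay rate."""
    scores = {
        # "hello world" was last accessed a day ago, "hello foo" just now.
        "hello world": combined_score(SIM_WORLD, 24.0, decay_rate),
        "hello foo": combined_score(SIM_FOO, 0.0, decay_rate),
    }
    return sorted(scores, key=scores.get, reverse=True)


low = ranking(1e-25)   # recency is ~1 for both, so similarity decides
high = ranking(0.999)  # recency of the day-old document collapses to ~0
```

Under these assumed similarities, the near-zero decay rate ranks "hello world" first, and the 0.999 decay rate ranks "hello foo" first, because the day-old document keeps almost none of its recency bonus.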

File diff suppressed because it is too large


@@ -6,6 +6,9 @@ from langchain.retrievers.pinecone_hybrid_search import PineconeHybridSearchRetr
 from langchain.retrievers.remote_retriever import RemoteLangChainRetriever
 from langchain.retrievers.svm import SVMRetriever
 from langchain.retrievers.tfidf import TFIDFRetriever
+from langchain.retrievers.time_weighted_retriever import (
+    TimeWeightedVectorStoreRetriever,
+)
 from langchain.retrievers.weaviate_hybrid_search import WeaviateHybridSearchRetriever

 __all__ = [
@@ -17,5 +20,6 @@ __all__ = [
     "TFIDFRetriever",
     "WeaviateHybridSearchRetriever",
     "DataberryRetriever",
+    "TimeWeightedVectorStoreRetriever",
     "SVMRetriever",
 ]


@@ -1,6 +1,6 @@
 from __future__ import annotations

-from typing import List
+from typing import List, Optional

 import aiohttp
 import requests
@@ -13,8 +13,8 @@ class ChatGPTPluginRetriever(BaseRetriever, BaseModel):
     url: str
     bearer_token: str
     top_k: int = 3
-    filter: dict | None = None
-    aiosession: aiohttp.ClientSession | None = None
+    filter: Optional[dict] = None
+    aiosession: Optional[aiohttp.ClientSession] = None

     class Config:
         """Configuration for this pydantic object."""


@@ -0,0 +1,138 @@
"""Retriever that combines embedding similarity with recency in retrieving values."""
from copy import deepcopy
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple

from pydantic import BaseModel, Field

from langchain.schema import BaseRetriever, Document
from langchain.vectorstores.base import VectorStore


def _get_hours_passed(time: datetime, ref_time: datetime) -> float:
    """Get the hours passed between two datetime objects."""
    return (time - ref_time).total_seconds() / 3600


class TimeWeightedVectorStoreRetriever(BaseRetriever, BaseModel):
    """Retriever combining embedding similarity with recency."""

    vectorstore: VectorStore
    """The vectorstore to store documents and determine salience."""

    search_kwargs: dict = Field(default_factory=lambda: dict(k=100))
    """Keyword arguments to pass to the vectorstore similarity search."""

    # TODO: abstract as a queue
    memory_stream: List[Document] = Field(default_factory=list)
    """The memory_stream of documents to search through."""

    decay_rate: float = Field(default=0.01)
    """The exponential decay factor used as (1.0-decay_rate)**(hrs_passed)."""

    k: int = 4
    """The maximum number of documents to retrieve in a given call."""

    other_score_keys: List[str] = []
    """Other keys in the metadata to factor into the score, e.g. 'importance'."""

    default_salience: Optional[float] = None
    """The salience to assign memories not retrieved from the vector store.

    None assigns no salience to documents not fetched from the vector store.
    """

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    def _get_combined_score(
        self,
        document: Document,
        vector_relevance: Optional[float],
        current_time: datetime,
    ) -> float:
        """Return the combined score for a document."""
        hours_passed = _get_hours_passed(
            current_time,
            document.metadata["last_accessed_at"],
        )
        score = (1.0 - self.decay_rate) ** hours_passed
        for key in self.other_score_keys:
            if key in document.metadata:
                score += document.metadata[key]
        if vector_relevance is not None:
            score += vector_relevance
        return score

    def get_salient_docs(self, query: str) -> Dict[int, Tuple[Document, float]]:
        """Return documents that are salient to the query."""
        docs_and_scores: List[Tuple[Document, float]]
        docs_and_scores = self.vectorstore.similarity_search_with_relevance_scores(
            query, **self.search_kwargs
        )
        results = {}
        for fetched_doc, relevance in docs_and_scores:
            buffer_idx = fetched_doc.metadata["buffer_idx"]
            doc = self.memory_stream[buffer_idx]
            results[buffer_idx] = (doc, relevance)
        return results

    def get_relevant_documents(self, query: str) -> List[Document]:
        """Return documents that are relevant to the query."""
        current_time = datetime.now()
        docs_and_scores = {
            doc.metadata["buffer_idx"]: (doc, self.default_salience)
            for doc in self.memory_stream[-self.k :]
        }
        # If a doc is considered salient, update the salience score
        docs_and_scores.update(self.get_salient_docs(query))
        rescored_docs = [
            (doc, self._get_combined_score(doc, relevance, current_time))
            for doc, relevance in docs_and_scores.values()
        ]
        rescored_docs.sort(key=lambda x: x[1], reverse=True)
        result = []
        # Ensure frequently accessed memories aren't forgotten
        current_time = datetime.now()
        for doc, _ in rescored_docs[: self.k]:
            # TODO: Update vector store doc once `update` method is exposed.
            buffered_doc = self.memory_stream[doc.metadata["buffer_idx"]]
            buffered_doc.metadata["last_accessed_at"] = current_time
            result.append(buffered_doc)
        return result

    async def aget_relevant_documents(self, query: str) -> List[Document]:
        """Return documents that are relevant to the query."""
        raise NotImplementedError

    def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
        """Add documents to vectorstore."""
        current_time = kwargs.get("current_time", datetime.now())
        # Avoid mutating input documents
        dup_docs = [deepcopy(d) for d in documents]
        for i, doc in enumerate(dup_docs):
            if "last_accessed_at" not in doc.metadata:
                doc.metadata["last_accessed_at"] = current_time
            if "created_at" not in doc.metadata:
                doc.metadata["created_at"] = current_time
            doc.metadata["buffer_idx"] = len(self.memory_stream) + i
        self.memory_stream.extend(dup_docs)
        return self.vectorstore.add_documents(dup_docs, **kwargs)

    async def aadd_documents(
        self, documents: List[Document], **kwargs: Any
    ) -> List[str]:
        """Add documents to vectorstore."""
        current_time = kwargs.get("current_time", datetime.now())
        # Avoid mutating input documents
        dup_docs = [deepcopy(d) for d in documents]
        for i, doc in enumerate(dup_docs):
            if "last_accessed_at" not in doc.metadata:
                doc.metadata["last_accessed_at"] = current_time
            if "created_at" not in doc.metadata:
                doc.metadata["created_at"] = current_time
            doc.metadata["buffer_idx"] = len(self.memory_stream) + i
        self.memory_stream.extend(dup_docs)
        return await self.vectorstore.aadd_documents(dup_docs, **kwargs)
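
The retrieval flow of `get_relevant_documents` can be traced without a vector store. The following is a simplified sketch under stated assumptions: plain dicts stand in for `Document`, the hypothetical `vector_hits` argument replaces the vector-store query, `other_score_keys` is ignored, and a missing relevance counts as 0. It is meant to illustrate the merge, rescore, and refresh steps, not to mirror the class exactly.

```python
from datetime import datetime, timedelta
from typing import Dict, List, Optional


def retrieve(
    memory_stream: List[dict],
    vector_hits: Dict[int, float],  # assumed shape: buffer_idx -> relevance
    now: datetime,
    k: int = 2,
    decay_rate: float = 0.01,
) -> List[dict]:
    # Start from the k most recently added memories, with no vector relevance.
    candidates: Dict[int, Optional[float]] = {
        m["buffer_idx"]: None for m in memory_stream[-k:]
    }
    # Vector-store hits overwrite the placeholder relevance.
    candidates.update(vector_hits)

    def score(idx: int) -> float:
        last_access = memory_stream[idx]["last_accessed_at"]
        hours = (now - last_access).total_seconds() / 3600
        recency = (1.0 - decay_rate) ** hours
        relevance = candidates[idx]
        return recency + (relevance if relevance is not None else 0.0)

    top = sorted(candidates, key=score, reverse=True)[:k]
    for idx in top:
        # Returned memories are marked as freshly accessed.
        memory_stream[idx]["last_accessed_at"] = now
    return [memory_stream[idx] for idx in top]


now = datetime(2023, 4, 16, 12, 0)
stream = [
    {"text": "old but relevant", "buffer_idx": 0,
     "last_accessed_at": now - timedelta(hours=48)},
    {"text": "recent, unrelated", "buffer_idx": 1,
     "last_accessed_at": now - timedelta(hours=1)},
    {"text": "newest, unrelated", "buffer_idx": 2, "last_accessed_at": now},
]
result = retrieve(stream, vector_hits={0: 0.9}, now=now, k=2)
```

In this example the two-day-old memory still ranks first because its relevance outweighs the lost recency, and returning it resets its `last_accessed_at`, which is how frequently retrieved memories stay fresh.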


@@ -0,0 +1,163 @@
"""Tests for the time-weighted retriever class."""

from datetime import datetime
from typing import Any, Iterable, List, Optional, Tuple, Type

import pytest

from langchain.embeddings.base import Embeddings
from langchain.retrievers.time_weighted_retriever import (
    TimeWeightedVectorStoreRetriever,
    _get_hours_passed,
)
from langchain.schema import Document
from langchain.vectorstores.base import VectorStore


def _get_example_memories(k: int = 4) -> List[Document]:
    return [
        Document(
            page_content="foo",
            metadata={
                "buffer_idx": i,
                "last_accessed_at": datetime(2023, 4, 14, 12, 0),
            },
        )
        for i in range(k)
    ]


class MockVectorStore(VectorStore):
    """Mock invalid vector store."""

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Run more texts through the embeddings and add to the vectorstore.

        Args:
            texts: Iterable of strings to add to the vectorstore.
            metadatas: Optional list of metadatas associated with the texts.
            kwargs: vectorstore specific parameters

        Returns:
            List of ids from adding the texts into the vectorstore.
        """
        return list(texts)

    async def aadd_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Run more texts through the embeddings and add to the vectorstore."""
        raise NotImplementedError

    def similarity_search(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> List[Document]:
        """Return docs most similar to query."""
        return []

    @classmethod
    def from_documents(
        cls: Type["MockVectorStore"],
        documents: List[Document],
        embedding: Embeddings,
        **kwargs: Any,
    ) -> "MockVectorStore":
        """Return VectorStore initialized from documents and embeddings."""
        texts = [d.page_content for d in documents]
        metadatas = [d.metadata for d in documents]
        return cls.from_texts(texts, embedding, metadatas=metadatas, **kwargs)

    @classmethod
    def from_texts(
        cls: Type["MockVectorStore"],
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        **kwargs: Any,
    ) -> "MockVectorStore":
        """Return VectorStore initialized from texts and embeddings."""
        return cls()

    def _similarity_search_with_relevance_scores(
        self,
        query: str,
        k: int = 4,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Return docs and similarity scores, normalized on a scale from 0 to 1.

        0 is dissimilar, 1 is most similar.
        """
        return [(doc, 0.5) for doc in _get_example_memories()]


@pytest.fixture
def time_weighted_retriever() -> TimeWeightedVectorStoreRetriever:
    vectorstore = MockVectorStore()
    return TimeWeightedVectorStoreRetriever(
        vectorstore=vectorstore, memory_stream=_get_example_memories()
    )


def test__get_hours_passed() -> None:
    time1 = datetime(2023, 4, 14, 14, 30)
    time2 = datetime(2023, 4, 14, 12, 0)
    expected_hours_passed = 2.5
    hours_passed = _get_hours_passed(time1, time2)
    assert hours_passed == expected_hours_passed


def test_get_combined_score(
    time_weighted_retriever: TimeWeightedVectorStoreRetriever,
) -> None:
    document = Document(
        page_content="Test document",
        metadata={"last_accessed_at": datetime(2023, 4, 14, 12, 0)},
    )
    vector_salience = 0.7
    expected_hours_passed = 2.5
    current_time = datetime(2023, 4, 14, 14, 30)
    combined_score = time_weighted_retriever._get_combined_score(
        document, vector_salience, current_time
    )
    expected_score = (
        1.0 - time_weighted_retriever.decay_rate
    ) ** expected_hours_passed + vector_salience
    assert combined_score == pytest.approx(expected_score)


def test_get_salient_docs(
    time_weighted_retriever: TimeWeightedVectorStoreRetriever,
) -> None:
    query = "Test query"
    docs_and_scores = time_weighted_retriever.get_salient_docs(query)
    assert isinstance(docs_and_scores, dict)


def test_get_relevant_documents(
    time_weighted_retriever: TimeWeightedVectorStoreRetriever,
) -> None:
    query = "Test query"
    relevant_documents = time_weighted_retriever.get_relevant_documents(query)
    assert isinstance(relevant_documents, list)


def test_add_documents(
    time_weighted_retriever: TimeWeightedVectorStoreRetriever,
) -> None:
    documents = [Document(page_content="test_add_documents document")]
    added_documents = time_weighted_retriever.add_documents(documents)
    assert isinstance(added_documents, list)
    assert len(added_documents) == 1
    assert (
        time_weighted_retriever.memory_stream[-1].page_content
        == documents[0].page_content
    )
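
The bookkeeping that `test_add_documents` exercises can also be sketched without langchain. The snippet below is an illustrative stand-in, not the retriever's code: `add_memories` is a hypothetical helper and plain dicts replace `Document`. It shows the three effects worth asserting on: inputs are copied rather than mutated, missing timestamps are filled in, and `buffer_idx` continues from the current length of the stream.

```python
from copy import deepcopy
from datetime import datetime
from typing import List


def add_memories(memory_stream: List[dict], new_docs: List[dict], now: datetime) -> None:
    """Append copies of new_docs, stamping timestamps and buffer positions."""
    copies = [deepcopy(d) for d in new_docs]  # leave the caller's objects untouched
    base = len(memory_stream)
    for i, doc in enumerate(copies):
        meta = doc.setdefault("metadata", {})
        meta.setdefault("last_accessed_at", now)  # keep a caller-supplied value
        meta.setdefault("created_at", now)
        meta["buffer_idx"] = base + i  # index the copy will have after extend()
    memory_stream.extend(copies)


stream: List[dict] = []
original = {"page_content": "hello world", "metadata": {}}
add_memories(stream, [original], datetime(2023, 4, 16, 12, 0))
add_memories(
    stream,
    [{"page_content": "a"}, {"page_content": "b"}],
    datetime(2023, 4, 16, 13, 0),
)
```

After both calls the stream holds three entries with `buffer_idx` 0, 1 and 2, while `original` still has empty metadata.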