forked from Archives/langchain
Generative Characters (#2859)
Add a time-weighted memory retriever and a notebook that approximates a Generative Agent from https://arxiv.org/pdf/2304.03442.pdf The "daily plan" components are removed for now since they are less useful without a virtual world, but the memory is an interesting component to build off. --------- Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
parent
a9310a3e8b
commit
99c0382209
@@ -0,0 +1,213 @@
{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "a90b7557",
   "metadata": {},
   "source": [
    "# Time Weighted VectorStore Retriever\n",
    "\n",
    "This retriever uses a combination of semantic similarity and recency.\n",
    "\n",
    "The algorithm for scoring documents is:\n",
    "\n",
    "```\n",
    "semantic_similarity + (1.0 - decay_rate) ** hours_passed\n",
    "```\n",
    "\n",
    "Notably, hours_passed refers to the hours passed since the object in the retriever **was last accessed**, not since it was created. This means that frequently accessed objects remain \"fresh\"."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "f22cc96b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import faiss\n",
    "\n",
    "from datetime import datetime, timedelta\n",
    "from langchain.docstore import InMemoryDocstore\n",
    "from langchain.embeddings import OpenAIEmbeddings\n",
    "from langchain.retrievers import TimeWeightedVectorStoreRetriever\n",
    "from langchain.schema import Document\n",
    "from langchain.vectorstores import FAISS\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "6af7ea6b",
   "metadata": {},
   "source": [
"## Low Decay Rate\n",
|
||||
"\n",
|
||||
"A low decay rate (in this, to be extreme, we will set close to 0) means memories will be \"remembered\" for longer. A decay rate of 0 means memories never be forgotten, making this retriever equivalent to the vector lookup."
|
||||
]
|
||||
},
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "c10e7696",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define your embedding model\n",
    "embeddings_model = OpenAIEmbeddings()\n",
    "# Initialize the vectorstore as empty\n",
    "embedding_size = 1536\n",
    "index = faiss.IndexFlatL2(embedding_size)\n",
    "vectorstore = FAISS(embeddings_model.embed_query, index, InMemoryDocstore({}), {})\n",
    "retriever = TimeWeightedVectorStoreRetriever(vectorstore=vectorstore, decay_rate=.0000000000000000000000001, k=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "86dbadb9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['129ba56b-7e7f-480b-83b3-8138a7f5db4a']"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "yesterday = datetime.now() - timedelta(days=1)\n",
    "retriever.add_documents([Document(page_content=\"hello world\", metadata={\"last_accessed_at\": yesterday})])\n",
    "retriever.add_documents([Document(page_content=\"hello foo\")])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "a580be32",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[Document(page_content='hello foo', metadata={'last_accessed_at': datetime.datetime(2023, 4, 16, 15, 46, 43, 860748), 'created_at': datetime.datetime(2023, 4, 16, 15, 46, 14, 469670), 'buffer_idx': 1})]"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
"# \"Hello World\" is returned first because it is most salient, and the decay rate is close to 0., meaning it's still recent enough\n",
|
||||
"retriever.get_relevant_documents(\"hello world\")"
|
||||
]
|
||||
},
|
||||
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "ca056896",
   "metadata": {},
   "source": [
"## High Decay Rate\n",
|
||||
"\n",
|
||||
"With a high decay factor (e.g., several 9's), the recency score quickly goes to 0! If you set this all the way to 1, recency is 0 for all objects, once again making this equivalent to a vector lookup.\n"
|
||||
]
|
||||
},
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "dc37669b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define your embedding model\n",
    "embeddings_model = OpenAIEmbeddings()\n",
    "# Initialize the vectorstore as empty\n",
    "embedding_size = 1536\n",
    "index = faiss.IndexFlatL2(embedding_size)\n",
    "vectorstore = FAISS(embeddings_model.embed_query, index, InMemoryDocstore({}), {})\n",
    "retriever = TimeWeightedVectorStoreRetriever(vectorstore=vectorstore, decay_rate=.999, k=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "fa284384",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['8fff7ef8-3a30-40f3-b42e-b8d5c7850863']"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "yesterday = datetime.now() - timedelta(days=1)\n",
    "retriever.add_documents([Document(page_content=\"hello world\", metadata={\"last_accessed_at\": yesterday})])\n",
    "retriever.add_documents([Document(page_content=\"hello foo\")])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "7558f94d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[Document(page_content='hello foo', metadata={'last_accessed_at': datetime.datetime(2023, 4, 16, 15, 46, 17, 646927), 'created_at': datetime.datetime(2023, 4, 16, 15, 46, 14, 469670), 'buffer_idx': 1})]"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# \"Hello Foo\" is returned first because \"hello world\" is mostly forgotten\n",
    "retriever.get_relevant_documents(\"hello world\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
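The notebook's scoring rule can be exercised on its own. A minimal sketch; the helper name `time_weighted_score` is ours for illustration, not part of the library:

```python
def time_weighted_score(semantic_similarity: float, decay_rate: float, hours_passed: float) -> float:
    """Similarity plus an exponentially decaying recency bonus."""
    return semantic_similarity + (1.0 - decay_rate) ** hours_passed

# Near-zero decay rate: a day-old memory keeps almost its full recency bonus.
low = time_weighted_score(0.5, 1e-25, hours_passed=24.0)   # ~1.5
# decay_rate=.999: the recency bonus of a day-old memory is effectively zero.
high = time_weighted_score(0.5, 0.999, hours_passed=24.0)  # ~0.5
print(low, high)
```

This is why the two notebook sections behave so differently: the decay rate only changes the recency term, and the similarity term is untouched.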
1250
docs/use_cases/agents/characters.ipynb
Normal file
File diff suppressed because it is too large
@@ -6,6 +6,9 @@ from langchain.retrievers.pinecone_hybrid_search import PineconeHybridSearchRetr
from langchain.retrievers.remote_retriever import RemoteLangChainRetriever
from langchain.retrievers.svm import SVMRetriever
from langchain.retrievers.tfidf import TFIDFRetriever
from langchain.retrievers.time_weighted_retriever import (
    TimeWeightedVectorStoreRetriever,
)
from langchain.retrievers.weaviate_hybrid_search import WeaviateHybridSearchRetriever

__all__ = [
@@ -17,5 +20,6 @@ __all__ = [
    "TFIDFRetriever",
    "WeaviateHybridSearchRetriever",
    "DataberryRetriever",
    "TimeWeightedVectorStoreRetriever",
    "SVMRetriever",
]
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import List
from typing import List, Optional

import aiohttp
import requests
@@ -13,8 +13,8 @@ class ChatGPTPluginRetriever(BaseRetriever, BaseModel):
    url: str
    bearer_token: str
    top_k: int = 3
    filter: dict | None = None
    aiosession: aiohttp.ClientSession | None = None
    filter: Optional[dict] = None
    aiosession: Optional[aiohttp.ClientSession] = None

    class Config:
        """Configuration for this pydantic object."""
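The switch from `dict | None` to `Optional[dict]` above is a compatibility fix: PEP 604 `X | Y` annotations fail at runtime on Python < 3.10 when a library (such as pydantic) evaluates them, while `typing.Optional` works on older interpreters. A small illustrative sketch of the equivalent `Optional` style (the `lookup` function is ours, not from langchain):

```python
from typing import Optional

# Optional[X] is Union[X, None]; on Python >= 3.10 it is also equal to X | None.
def lookup(table: Optional[dict] = None, key: str = "k") -> Optional[str]:
    """Return table[key] if a table was given, else None."""
    if table is None:
        return None
    return table.get(key)

print(lookup())            # None
print(lookup({"k": "v"}))  # v
```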
138
langchain/retrievers/time_weighted_retriever.py
Normal file
@@ -0,0 +1,138 @@
"""Retriever that combines embedding similarity with recency in retrieving values."""
from copy import deepcopy
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple

from pydantic import BaseModel, Field

from langchain.schema import BaseRetriever, Document
from langchain.vectorstores.base import VectorStore


def _get_hours_passed(time: datetime, ref_time: datetime) -> float:
    """Get the hours passed between two datetime objects."""
    return (time - ref_time).total_seconds() / 3600


class TimeWeightedVectorStoreRetriever(BaseRetriever, BaseModel):
    """Retriever combining embedding similarity with recency."""

    vectorstore: VectorStore
    """The vectorstore to store documents and determine salience."""

    search_kwargs: dict = Field(default_factory=lambda: dict(k=100))
    """Keyword arguments to pass to the vectorstore similarity search."""

    # TODO: abstract as a queue
    memory_stream: List[Document] = Field(default_factory=list)
    """The memory_stream of documents to search through."""

    decay_rate: float = Field(default=0.01)
    """The exponential decay factor used as (1.0 - decay_rate) ** hrs_passed."""

    k: int = 4
    """The maximum number of documents to retrieve in a given call."""

    other_score_keys: List[str] = []
    """Other keys in the metadata to factor into the score, e.g. 'importance'."""

    default_salience: Optional[float] = None
    """The salience to assign memories not retrieved from the vector store.

    None assigns no salience to documents not fetched from the vector store.
    """

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    def _get_combined_score(
        self,
        document: Document,
        vector_relevance: Optional[float],
        current_time: datetime,
    ) -> float:
        """Return the combined score for a document."""
        hours_passed = _get_hours_passed(
            current_time,
            document.metadata["last_accessed_at"],
        )
        score = (1.0 - self.decay_rate) ** hours_passed
        for key in self.other_score_keys:
            if key in document.metadata:
                score += document.metadata[key]
        if vector_relevance is not None:
            score += vector_relevance
        return score

    def get_salient_docs(self, query: str) -> Dict[int, Tuple[Document, float]]:
        """Return documents that are salient to the query."""
        docs_and_scores: List[Tuple[Document, float]]
        docs_and_scores = self.vectorstore.similarity_search_with_relevance_scores(
            query, **self.search_kwargs
        )
        results = {}
        for fetched_doc, relevance in docs_and_scores:
            buffer_idx = fetched_doc.metadata["buffer_idx"]
            doc = self.memory_stream[buffer_idx]
            results[buffer_idx] = (doc, relevance)
        return results

    def get_relevant_documents(self, query: str) -> List[Document]:
        """Return documents that are relevant to the query."""
        current_time = datetime.now()
        docs_and_scores = {
            doc.metadata["buffer_idx"]: (doc, self.default_salience)
            for doc in self.memory_stream[-self.k :]
        }
        # If a doc is considered salient, update the salience score
        docs_and_scores.update(self.get_salient_docs(query))
        rescored_docs = [
            (doc, self._get_combined_score(doc, relevance, current_time))
            for doc, relevance in docs_and_scores.values()
        ]
        rescored_docs.sort(key=lambda x: x[1], reverse=True)
        result = []
        # Ensure frequently accessed memories aren't forgotten
        current_time = datetime.now()
        for doc, _ in rescored_docs[: self.k]:
            # TODO: Update vector store doc once `update` method is exposed.
            buffered_doc = self.memory_stream[doc.metadata["buffer_idx"]]
            buffered_doc.metadata["last_accessed_at"] = current_time
            result.append(buffered_doc)
        return result

    async def aget_relevant_documents(self, query: str) -> List[Document]:
        """Return documents that are relevant to the query."""
        raise NotImplementedError

    def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
        """Add documents to vectorstore."""
        current_time = kwargs.get("current_time", datetime.now())
        # Avoid mutating input documents
        dup_docs = [deepcopy(d) for d in documents]
        for i, doc in enumerate(dup_docs):
            if "last_accessed_at" not in doc.metadata:
                doc.metadata["last_accessed_at"] = current_time
            if "created_at" not in doc.metadata:
                doc.metadata["created_at"] = current_time
            doc.metadata["buffer_idx"] = len(self.memory_stream) + i
        self.memory_stream.extend(dup_docs)
        return self.vectorstore.add_documents(dup_docs, **kwargs)

    async def aadd_documents(
        self, documents: List[Document], **kwargs: Any
    ) -> List[str]:
        """Add documents to vectorstore."""
        current_time = kwargs.get("current_time", datetime.now())
        # Avoid mutating input documents
        dup_docs = [deepcopy(d) for d in documents]
        for i, doc in enumerate(dup_docs):
            if "last_accessed_at" not in doc.metadata:
                doc.metadata["last_accessed_at"] = current_time
            if "created_at" not in doc.metadata:
                doc.metadata["created_at"] = current_time
            doc.metadata["buffer_idx"] = len(self.memory_stream) + i
        self.memory_stream.extend(dup_docs)
        return await self.vectorstore.aadd_documents(dup_docs, **kwargs)
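The `add_documents` bookkeeping above (stamp `last_accessed_at` and `created_at`, assign a stable `buffer_idx`, append to the memory stream) can be modeled without any dependencies. A sketch using plain dicts in place of `Document`; `stamp_and_append` is our name for this illustration, not the shipped method:

```python
from datetime import datetime

def stamp_and_append(memory_stream: list, docs: list, current_time: datetime) -> list:
    """Stamp timestamps and a stable buffer_idx, mirroring add_documents."""
    for i, doc in enumerate(docs):
        meta = doc.setdefault("metadata", {})
        # Only fill timestamps the caller didn't provide, as the real method does.
        meta.setdefault("last_accessed_at", current_time)
        meta.setdefault("created_at", current_time)
        # buffer_idx is the document's permanent position in the stream.
        meta["buffer_idx"] = len(memory_stream) + i
    memory_stream.extend(docs)
    return memory_stream

stream: list = []
now = datetime(2023, 4, 16, 12, 0)
stamp_and_append(stream, [{"page_content": "hello world"}], now)
stamp_and_append(stream, [{"page_content": "hello foo"}], now)
print([d["metadata"]["buffer_idx"] for d in stream])  # [0, 1]
```

Because `buffer_idx` records the append position, `get_salient_docs` can map a hit from the vector store back to the canonical copy in `memory_stream` and refresh its `last_accessed_at`.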
0
tests/unit_tests/retrievers/__init__.py
Normal file
163
tests/unit_tests/retrievers/test_time_weighted_retriever.py
Normal file
@@ -0,0 +1,163 @@
"""Tests for the time-weighted retriever class."""

from datetime import datetime
from typing import Any, Iterable, List, Optional, Tuple, Type

import pytest

from langchain.embeddings.base import Embeddings
from langchain.retrievers.time_weighted_retriever import (
    TimeWeightedVectorStoreRetriever,
    _get_hours_passed,
)
from langchain.schema import Document
from langchain.vectorstores.base import VectorStore


def _get_example_memories(k: int = 4) -> List[Document]:
    return [
        Document(
            page_content="foo",
            metadata={
                "buffer_idx": i,
                "last_accessed_at": datetime(2023, 4, 14, 12, 0),
            },
        )
        for i in range(k)
    ]


class MockVectorStore(VectorStore):
    """Mock invalid vector store."""

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Run more texts through the embeddings and add to the vectorstore.

        Args:
            texts: Iterable of strings to add to the vectorstore.
            metadatas: Optional list of metadatas associated with the texts.
            kwargs: vectorstore specific parameters

        Returns:
            List of ids from adding the texts into the vectorstore.
        """
        return list(texts)

    async def aadd_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Run more texts through the embeddings and add to the vectorstore."""
        raise NotImplementedError

    def similarity_search(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> List[Document]:
        """Return docs most similar to query."""
        return []

    @classmethod
    def from_documents(
        cls: Type["MockVectorStore"],
        documents: List[Document],
        embedding: Embeddings,
        **kwargs: Any,
    ) -> "MockVectorStore":
        """Return VectorStore initialized from documents and embeddings."""
        texts = [d.page_content for d in documents]
        metadatas = [d.metadata for d in documents]
        return cls.from_texts(texts, embedding, metadatas=metadatas, **kwargs)

    @classmethod
    def from_texts(
        cls: Type["MockVectorStore"],
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        **kwargs: Any,
    ) -> "MockVectorStore":
        """Return VectorStore initialized from texts and embeddings."""
        return cls()

    def _similarity_search_with_relevance_scores(
        self,
        query: str,
        k: int = 4,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Return docs and similarity scores, normalized on a scale from 0 to 1.

        0 is dissimilar, 1 is most similar.
        """
        return [(doc, 0.5) for doc in _get_example_memories()]


@pytest.fixture
def time_weighted_retriever() -> TimeWeightedVectorStoreRetriever:
    vectorstore = MockVectorStore()
    return TimeWeightedVectorStoreRetriever(
        vectorstore=vectorstore, memory_stream=_get_example_memories()
    )


def test__get_hours_passed() -> None:
    time1 = datetime(2023, 4, 14, 14, 30)
    time2 = datetime(2023, 4, 14, 12, 0)
    expected_hours_passed = 2.5
    hours_passed = _get_hours_passed(time1, time2)
    assert hours_passed == expected_hours_passed


def test_get_combined_score(
    time_weighted_retriever: TimeWeightedVectorStoreRetriever,
) -> None:
    document = Document(
        page_content="Test document",
        metadata={"last_accessed_at": datetime(2023, 4, 14, 12, 0)},
    )
    vector_salience = 0.7
    expected_hours_passed = 2.5
    current_time = datetime(2023, 4, 14, 14, 30)
    combined_score = time_weighted_retriever._get_combined_score(
        document, vector_salience, current_time
    )
    expected_score = (
        1.0 - time_weighted_retriever.decay_rate
    ) ** expected_hours_passed + vector_salience
    assert combined_score == pytest.approx(expected_score)


def test_get_salient_docs(
    time_weighted_retriever: TimeWeightedVectorStoreRetriever,
) -> None:
    query = "Test query"
    docs_and_scores = time_weighted_retriever.get_salient_docs(query)
    assert isinstance(docs_and_scores, dict)


def test_get_relevant_documents(
    time_weighted_retriever: TimeWeightedVectorStoreRetriever,
) -> None:
    query = "Test query"
    relevant_documents = time_weighted_retriever.get_relevant_documents(query)
    assert isinstance(relevant_documents, list)


def test_add_documents(
    time_weighted_retriever: TimeWeightedVectorStoreRetriever,
) -> None:
    documents = [Document(page_content="test_add_documents document")]
    added_documents = time_weighted_retriever.add_documents(documents)
    assert isinstance(added_documents, list)
    assert len(added_documents) == 1
    assert (
        time_weighted_retriever.memory_stream[-1].page_content
        == documents[0].page_content
    )