Harrison/virtual time (#4658)

Co-authored-by: ifsheldon <39153080+ifsheldon@users.noreply.github.com>
Co-authored-by: maple.liang <maple.liang@gempoll.com>
This commit is contained in:
Harrison Chase 2023-05-14 10:29:17 -07:00 committed by GitHub
parent f2f2aced6d
commit 243886be93
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 148 additions and 38 deletions

View File

@ -70,7 +70,7 @@
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"['5c9f7c06-c9eb-45f2-aea5-efce5fb9f2bd']" "['d7f85756-2371-4bdf-9140-052780a0f9b3']"
] ]
}, },
"execution_count": 3, "execution_count": 3,
@ -93,7 +93,7 @@
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"[Document(page_content='hello world', metadata={'last_accessed_at': datetime.datetime(2023, 4, 16, 22, 9, 1, 966261), 'created_at': datetime.datetime(2023, 4, 16, 22, 9, 0, 374683), 'buffer_idx': 0})]" "[Document(page_content='hello world', metadata={'last_accessed_at': datetime.datetime(2023, 5, 13, 21, 0, 27, 678341), 'created_at': datetime.datetime(2023, 5, 13, 21, 0, 27, 279596), 'buffer_idx': 0})]"
] ]
}, },
"execution_count": 4, "execution_count": 4,
@ -177,10 +177,51 @@
"retriever.get_relevant_documents(\"hello world\")" "retriever.get_relevant_documents(\"hello world\")"
] ]
}, },
{
"cell_type": "markdown",
"id": "32e0131e",
"metadata": {},
"source": [
"## Virtual Time\n",
"\n",
"Using some utils in LangChain, you can mock out the time component"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "da080d40",
"metadata": {},
"outputs": [],
"source": [
"from langchain.utils import mock_now\n",
"import datetime"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "7c7deff1",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[Document(page_content='hello world', metadata={'last_accessed_at': MockDateTime(2011, 2, 3, 10, 11), 'created_at': datetime.datetime(2023, 5, 13, 21, 0, 27, 279596), 'buffer_idx': 0})]\n"
]
}
],
"source": [
"# Notice the last access time is that date time\n",
"with mock_now(datetime.datetime(2011, 2, 3, 10, 11)):\n",
" print(retriever.get_relevant_documents(\"hello world\"))"
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "bf6d8c90", "id": "c78d367d",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [] "source": []

View File

@ -88,7 +88,9 @@ Relevant context:
q2 = f"{entity_name} is {entity_action}" q2 = f"{entity_name} is {entity_action}"
return self.chain(prompt=prompt).run(q1=q1, queries=[q1, q2]).strip() return self.chain(prompt=prompt).run(q1=q1, queries=[q1, q2]).strip()
def _generate_reaction(self, observation: str, suffix: str) -> str: def _generate_reaction(
self, observation: str, suffix: str, now: Optional[datetime] = None
) -> str:
"""React to a given observation or dialogue act.""" """React to a given observation or dialogue act."""
prompt = PromptTemplate.from_template( prompt = PromptTemplate.from_template(
"{agent_summary_description}" "{agent_summary_description}"
@ -101,9 +103,13 @@ Relevant context:
+ "\n\n" + "\n\n"
+ suffix + suffix
) )
agent_summary_description = self.get_summary() agent_summary_description = self.get_summary(now=now)
relevant_memories_str = self.summarize_related_memories(observation) relevant_memories_str = self.summarize_related_memories(observation)
current_time_str = datetime.now().strftime("%B %d, %Y, %I:%M %p") current_time_str = (
datetime.now().strftime("%B %d, %Y, %I:%M %p")
if now is None
else now.strftime("%B %d, %Y, %I:%M %p")
)
kwargs: Dict[str, Any] = dict( kwargs: Dict[str, Any] = dict(
agent_summary_description=agent_summary_description, agent_summary_description=agent_summary_description,
current_time=current_time_str, current_time=current_time_str,
@ -121,7 +127,9 @@ Relevant context:
def _clean_response(self, text: str) -> str: def _clean_response(self, text: str) -> str:
return re.sub(f"^{self.name} ", "", text.strip()).strip() return re.sub(f"^{self.name} ", "", text.strip()).strip()
def generate_reaction(self, observation: str) -> Tuple[bool, str]: def generate_reaction(
self, observation: str, now: Optional[datetime] = None
) -> Tuple[bool, str]:
"""React to a given observation.""" """React to a given observation."""
call_to_action_template = ( call_to_action_template = (
"Should {agent_name} react to the observation, and if so," "Should {agent_name} react to the observation, and if so,"
@ -130,14 +138,17 @@ Relevant context:
+ "\notherwise, write:\nREACT: {agent_name}'s reaction (if anything)." + "\notherwise, write:\nREACT: {agent_name}'s reaction (if anything)."
+ "\nEither do nothing, react, or say something but not both.\n\n" + "\nEither do nothing, react, or say something but not both.\n\n"
) )
full_result = self._generate_reaction(observation, call_to_action_template) full_result = self._generate_reaction(
observation, call_to_action_template, now=now
)
result = full_result.strip().split("\n")[0] result = full_result.strip().split("\n")[0]
# AAA # AAA
self.memory.save_context( self.memory.save_context(
{}, {},
{ {
self.memory.add_memory_key: f"{self.name} observed " self.memory.add_memory_key: f"{self.name} observed "
f"{observation} and reacted by {result}" f"{observation} and reacted by {result}",
self.memory.now_key: now,
}, },
) )
if "REACT:" in result: if "REACT:" in result:
@ -149,14 +160,18 @@ Relevant context:
else: else:
return False, result return False, result
def generate_dialogue_response(self, observation: str) -> Tuple[bool, str]: def generate_dialogue_response(
self, observation: str, now: Optional[datetime] = None
) -> Tuple[bool, str]:
"""React to a given observation.""" """React to a given observation."""
call_to_action_template = ( call_to_action_template = (
"What would {agent_name} say? To end the conversation, write:" "What would {agent_name} say? To end the conversation, write:"
' GOODBYE: "what to say". Otherwise to continue the conversation,' ' GOODBYE: "what to say". Otherwise to continue the conversation,'
' write: SAY: "what to say next"\n\n' ' write: SAY: "what to say next"\n\n'
) )
full_result = self._generate_reaction(observation, call_to_action_template) full_result = self._generate_reaction(
observation, call_to_action_template, now=now
)
result = full_result.strip().split("\n")[0] result = full_result.strip().split("\n")[0]
if "GOODBYE:" in result: if "GOODBYE:" in result:
farewell = self._clean_response(result.split("GOODBYE:")[-1]) farewell = self._clean_response(result.split("GOODBYE:")[-1])
@ -164,7 +179,8 @@ Relevant context:
{}, {},
{ {
self.memory.add_memory_key: f"{self.name} observed " self.memory.add_memory_key: f"{self.name} observed "
f"{observation} and said {farewell}" f"{observation} and said {farewell}",
self.memory.now_key: now,
}, },
) )
return False, f"{self.name} said {farewell}" return False, f"{self.name} said {farewell}"
@ -174,7 +190,8 @@ Relevant context:
{}, {},
{ {
self.memory.add_memory_key: f"{self.name} observed " self.memory.add_memory_key: f"{self.name} observed "
f"{observation} and said {response_text}" f"{observation} and said {response_text}",
self.memory.now_key: now,
}, },
) )
return True, f"{self.name} said {response_text}" return True, f"{self.name} said {response_text}"
@ -203,9 +220,11 @@ Relevant context:
.strip() .strip()
) )
def get_summary(self, force_refresh: bool = False) -> str: def get_summary(
self, force_refresh: bool = False, now: Optional[datetime] = None
) -> str:
"""Return a descriptive summary of the agent.""" """Return a descriptive summary of the agent."""
current_time = datetime.now() current_time = datetime.now() if now is None else now
since_refresh = (current_time - self.last_refreshed).seconds since_refresh = (current_time - self.last_refreshed).seconds
if ( if (
not self.summary not self.summary
@ -221,10 +240,13 @@ Relevant context:
+ f"\n{self.summary}" + f"\n{self.summary}"
) )
def get_full_header(self, force_refresh: bool = False) -> str: def get_full_header(
self, force_refresh: bool = False, now: Optional[datetime] = None
) -> str:
"""Return a full header of the agent's status, summary, and current time.""" """Return a full header of the agent's status, summary, and current time."""
summary = self.get_summary(force_refresh=force_refresh) now = datetime.now() if now is None else now
current_time_str = datetime.now().strftime("%B %d, %Y, %I:%M %p") summary = self.get_summary(force_refresh=force_refresh, now=now)
current_time_str = now.strftime("%B %d, %Y, %I:%M %p")
return ( return (
f"{summary}\nIt is {current_time_str}.\n{self.name}'s status: {self.status}" f"{summary}\nIt is {current_time_str}.\n{self.name}'s status: {self.status}"
) )

View File

@ -1,5 +1,6 @@
import logging import logging
import re import re
from datetime import datetime
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from langchain import LLMChain from langchain import LLMChain
@ -7,6 +8,7 @@ from langchain.base_language import BaseLanguageModel
from langchain.prompts import PromptTemplate from langchain.prompts import PromptTemplate
from langchain.retrievers import TimeWeightedVectorStoreRetriever from langchain.retrievers import TimeWeightedVectorStoreRetriever
from langchain.schema import BaseMemory, Document from langchain.schema import BaseMemory, Document
from langchain.utils import mock_now
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -44,6 +46,7 @@ class GenerativeAgentMemory(BaseMemory):
relevant_memories_key: str = "relevant_memories" relevant_memories_key: str = "relevant_memories"
relevant_memories_simple_key: str = "relevant_memories_simple" relevant_memories_simple_key: str = "relevant_memories_simple"
most_recent_memories_key: str = "most_recent_memories" most_recent_memories_key: str = "most_recent_memories"
now_key: str = "now"
reflecting: bool = False reflecting: bool = False
def chain(self, prompt: PromptTemplate) -> LLMChain: def chain(self, prompt: PromptTemplate) -> LLMChain:
@ -68,7 +71,9 @@ class GenerativeAgentMemory(BaseMemory):
result = self.chain(prompt).run(observations=observation_str) result = self.chain(prompt).run(observations=observation_str)
return self._parse_list(result) return self._parse_list(result)
def _get_insights_on_topic(self, topic: str) -> List[str]: def _get_insights_on_topic(
self, topic: str, now: Optional[datetime] = None
) -> List[str]:
"""Generate 'insights' on a topic of reflection, based on pertinent memories.""" """Generate 'insights' on a topic of reflection, based on pertinent memories."""
prompt = PromptTemplate.from_template( prompt = PromptTemplate.from_template(
"Statements about {topic}\n" "Statements about {topic}\n"
@ -76,7 +81,7 @@ class GenerativeAgentMemory(BaseMemory):
+ "What 5 high-level insights can you infer from the above statements?" + "What 5 high-level insights can you infer from the above statements?"
+ " (example format: insight (because of 1, 5, 3))" + " (example format: insight (because of 1, 5, 3))"
) )
related_memories = self.fetch_memories(topic) related_memories = self.fetch_memories(topic, now=now)
related_statements = "\n".join( related_statements = "\n".join(
[ [
f"{i+1}. {memory.page_content}" f"{i+1}. {memory.page_content}"
@ -89,16 +94,16 @@ class GenerativeAgentMemory(BaseMemory):
# TODO: Parse the connections between memories and insights # TODO: Parse the connections between memories and insights
return self._parse_list(result) return self._parse_list(result)
def pause_to_reflect(self) -> List[str]: def pause_to_reflect(self, now: Optional[datetime] = None) -> List[str]:
"""Reflect on recent observations and generate 'insights'.""" """Reflect on recent observations and generate 'insights'."""
if self.verbose: if self.verbose:
logger.info("Character is reflecting") logger.info("Character is reflecting")
new_insights = [] new_insights = []
topics = self._get_topics_of_reflection() topics = self._get_topics_of_reflection()
for topic in topics: for topic in topics:
insights = self._get_insights_on_topic(topic) insights = self._get_insights_on_topic(topic, now=now)
for insight in insights: for insight in insights:
self.add_memory(insight) self.add_memory(insight, now=now)
new_insights.extend(insights) new_insights.extend(insights)
return new_insights return new_insights
@ -122,14 +127,16 @@ class GenerativeAgentMemory(BaseMemory):
else: else:
return 0.0 return 0.0
def add_memory(self, memory_content: str) -> List[str]: def add_memory(
self, memory_content: str, now: Optional[datetime] = None
) -> List[str]:
"""Add an observation or memory to the agent's memory.""" """Add an observation or memory to the agent's memory."""
importance_score = self._score_memory_importance(memory_content) importance_score = self._score_memory_importance(memory_content)
self.aggregate_importance += importance_score self.aggregate_importance += importance_score
document = Document( document = Document(
page_content=memory_content, metadata={"importance": importance_score} page_content=memory_content, metadata={"importance": importance_score}
) )
result = self.memory_retriever.add_documents([document]) result = self.memory_retriever.add_documents([document], current_time=now)
# After an agent has processed a certain amount of memories (as measured by # After an agent has processed a certain amount of memories (as measured by
# aggregate importance), it is time to reflect on recent events to add # aggregate importance), it is time to reflect on recent events to add
@ -140,14 +147,20 @@ class GenerativeAgentMemory(BaseMemory):
and not self.reflecting and not self.reflecting
): ):
self.reflecting = True self.reflecting = True
self.pause_to_reflect() self.pause_to_reflect(now=now)
# Hack to clear the importance from reflection # Hack to clear the importance from reflection
self.aggregate_importance = 0.0 self.aggregate_importance = 0.0
self.reflecting = False self.reflecting = False
return result return result
def fetch_memories(self, observation: str) -> List[Document]: def fetch_memories(
self, observation: str, now: Optional[datetime] = None
) -> List[Document]:
"""Fetch related memories.""" """Fetch related memories."""
if now is not None:
with mock_now(now):
return self.memory_retriever.get_relevant_documents(observation)
else:
return self.memory_retriever.get_relevant_documents(observation) return self.memory_retriever.get_relevant_documents(observation)
def format_memories_detail(self, relevant_memories: List[Document]) -> str: def format_memories_detail(self, relevant_memories: List[Document]) -> str:
@ -183,9 +196,10 @@ class GenerativeAgentMemory(BaseMemory):
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]: def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
"""Return key-value pairs given the text input to the chain.""" """Return key-value pairs given the text input to the chain."""
queries = inputs.get(self.queries_key) queries = inputs.get(self.queries_key)
now = inputs.get(self.now_key)
if queries is not None: if queries is not None:
relevant_memories = [ relevant_memories = [
mem for query in queries for mem in self.fetch_memories(query) mem for query in queries for mem in self.fetch_memories(query, now=now)
] ]
return { return {
self.relevant_memories_key: self.format_memories_detail( self.relevant_memories_key: self.format_memories_detail(
@ -205,12 +219,13 @@ class GenerativeAgentMemory(BaseMemory):
} }
return {} return {}
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None: def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, Any]) -> None:
"""Save the context of this model run to memory.""" """Save the context of this model run to memory."""
# TODO: fix the save memory key # TODO: fix the save memory key
mem = outputs.get(self.add_memory_key) mem = outputs.get(self.add_memory_key)
now = outputs.get(self.now_key)
if mem: if mem:
self.add_memory(mem) self.add_memory(mem, now=now)
def clear(self) -> None: def clear(self) -> None:
"""Clear memory contents.""" """Clear memory contents."""

View File

@ -1,6 +1,6 @@
"""Retriever that combines embedding similarity with recency in retrieving values.""" """Retriever that combines embedding similarity with recency in retrieving values."""
import datetime
from copy import deepcopy from copy import deepcopy
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@ -9,7 +9,7 @@ from langchain.schema import BaseRetriever, Document
from langchain.vectorstores.base import VectorStore from langchain.vectorstores.base import VectorStore
def _get_hours_passed(time: datetime, ref_time: datetime) -> float: def _get_hours_passed(time: datetime.datetime, ref_time: datetime.datetime) -> float:
"""Get the hours passed between two datetime objects.""" """Get the hours passed between two datetime objects."""
return (time - ref_time).total_seconds() / 3600 return (time - ref_time).total_seconds() / 3600
@ -51,7 +51,7 @@ class TimeWeightedVectorStoreRetriever(BaseRetriever, BaseModel):
self, self,
document: Document, document: Document,
vector_relevance: Optional[float], vector_relevance: Optional[float],
current_time: datetime, current_time: datetime.datetime,
) -> float: ) -> float:
"""Return the combined score for a document.""" """Return the combined score for a document."""
hours_passed = _get_hours_passed( hours_passed = _get_hours_passed(
@ -82,7 +82,7 @@ class TimeWeightedVectorStoreRetriever(BaseRetriever, BaseModel):
def get_relevant_documents(self, query: str) -> List[Document]: def get_relevant_documents(self, query: str) -> List[Document]:
"""Return documents that are relevant to the query.""" """Return documents that are relevant to the query."""
current_time = datetime.now() current_time = datetime.datetime.now()
docs_and_scores = { docs_and_scores = {
doc.metadata["buffer_idx"]: (doc, self.default_salience) doc.metadata["buffer_idx"]: (doc, self.default_salience)
for doc in self.memory_stream[-self.k :] for doc in self.memory_stream[-self.k :]
@ -96,7 +96,6 @@ class TimeWeightedVectorStoreRetriever(BaseRetriever, BaseModel):
rescored_docs.sort(key=lambda x: x[1], reverse=True) rescored_docs.sort(key=lambda x: x[1], reverse=True)
result = [] result = []
# Ensure frequently accessed memories aren't forgotten # Ensure frequently accessed memories aren't forgotten
current_time = datetime.now()
for doc, _ in rescored_docs[: self.k]: for doc, _ in rescored_docs[: self.k]:
# TODO: Update vector store doc once `update` method is exposed. # TODO: Update vector store doc once `update` method is exposed.
buffered_doc = self.memory_stream[doc.metadata["buffer_idx"]] buffered_doc = self.memory_stream[doc.metadata["buffer_idx"]]
@ -110,7 +109,7 @@ class TimeWeightedVectorStoreRetriever(BaseRetriever, BaseModel):
def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]: def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
"""Add documents to vectorstore.""" """Add documents to vectorstore."""
current_time = kwargs.get("current_time", datetime.now()) current_time = kwargs.get("current_time", datetime.datetime.now())
# Avoid mutating input documents # Avoid mutating input documents
dup_docs = [deepcopy(d) for d in documents] dup_docs = [deepcopy(d) for d in documents]
for i, doc in enumerate(dup_docs): for i, doc in enumerate(dup_docs):
@ -126,7 +125,7 @@ class TimeWeightedVectorStoreRetriever(BaseRetriever, BaseModel):
self, documents: List[Document], **kwargs: Any self, documents: List[Document], **kwargs: Any
) -> List[str]: ) -> List[str]:
"""Add documents to vectorstore.""" """Add documents to vectorstore."""
current_time = kwargs.get("current_time", datetime.now()) current_time = kwargs.get("current_time", datetime.datetime.now())
# Avoid mutating input documents # Avoid mutating input documents
dup_docs = [deepcopy(d) for d in documents] dup_docs = [deepcopy(d) for d in documents]
for i, doc in enumerate(dup_docs): for i, doc in enumerate(dup_docs):

View File

@ -1,4 +1,6 @@
"""Generic utility functions.""" """Generic utility functions."""
import contextlib
import datetime
import os import os
from typing import Any, Callable, Dict, Optional, Tuple from typing import Any, Callable, Dict, Optional, Tuple
@ -78,3 +80,34 @@ def stringify_dict(data: dict) -> str:
for key, value in data.items(): for key, value in data.items():
text += key + ": " + stringify_value(value) + "\n" text += key + ": " + stringify_value(value) + "\n"
return text return text
@contextlib.contextmanager
def mock_now(dt_value): # type: ignore
"""Context manager for mocking out datetime.now() in unit tests.
Example:
with mock_now(datetime.datetime(2011, 2, 3, 10, 11)):
assert datetime.datetime.now() == datetime.datetime(2011, 2, 3, 10, 11)
"""
class MockDateTime(datetime.datetime):
@classmethod
def now(cls): # type: ignore
# Create a copy of dt_value.
return datetime.datetime(
dt_value.year,
dt_value.month,
dt_value.day,
dt_value.hour,
dt_value.minute,
dt_value.second,
dt_value.microsecond,
dt_value.tzinfo,
)
real_datetime = datetime.datetime
datetime.datetime = MockDateTime
try:
yield datetime.datetime
finally:
datetime.datetime = real_datetime