From 243886be93d7d091bee8c0ebb1002182e57dc43c Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Sun, 14 May 2023 10:29:17 -0700 Subject: [PATCH] Harrison/virtual time (#4658) Co-authored-by: ifsheldon <39153080+ifsheldon@users.noreply.github.com> Co-authored-by: maple.liang --- .../examples/time_weighted_vectorstore.ipynb | 47 +++++++++++++++-- .../generative_agents/generative_agent.py | 52 +++++++++++++------ .../experimental/generative_agents/memory.py | 41 ++++++++++----- .../retrievers/time_weighted_retriever.py | 13 +++-- langchain/utils.py | 33 ++++++++++++ 5 files changed, 148 insertions(+), 38 deletions(-) diff --git a/docs/modules/indexes/retrievers/examples/time_weighted_vectorstore.ipynb b/docs/modules/indexes/retrievers/examples/time_weighted_vectorstore.ipynb index 1cf1ae02..88ec1261 100644 --- a/docs/modules/indexes/retrievers/examples/time_weighted_vectorstore.ipynb +++ b/docs/modules/indexes/retrievers/examples/time_weighted_vectorstore.ipynb @@ -70,7 +70,7 @@ { "data": { "text/plain": [ - "['5c9f7c06-c9eb-45f2-aea5-efce5fb9f2bd']" + "['d7f85756-2371-4bdf-9140-052780a0f9b3']" ] }, "execution_count": 3, @@ -93,7 +93,7 @@ { "data": { "text/plain": [ - "[Document(page_content='hello world', metadata={'last_accessed_at': datetime.datetime(2023, 4, 16, 22, 9, 1, 966261), 'created_at': datetime.datetime(2023, 4, 16, 22, 9, 0, 374683), 'buffer_idx': 0})]" + "[Document(page_content='hello world', metadata={'last_accessed_at': datetime.datetime(2023, 5, 13, 21, 0, 27, 678341), 'created_at': datetime.datetime(2023, 5, 13, 21, 0, 27, 279596), 'buffer_idx': 0})]" ] }, "execution_count": 4, @@ -177,10 +177,51 @@ "retriever.get_relevant_documents(\"hello world\")" ] }, + { + "cell_type": "markdown", + "id": "32e0131e", + "metadata": {}, + "source": [ + "## Virtual Time\n", + "\n", + "Using some utils in LangChain, you can mock out the time component" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "da080d40", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.utils import mock_now\n", + "import datetime" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "7c7deff1", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[Document(page_content='hello world', metadata={'last_accessed_at': MockDateTime(2011, 2, 3, 10, 11), 'created_at': datetime.datetime(2023, 5, 13, 21, 0, 27, 279596), 'buffer_idx': 0})]\n" + ] + } + ], + "source": [ + "# Notice the last access time is that date time\n", + "with mock_now(datetime.datetime(2011, 2, 3, 10, 11)):\n", + " print(retriever.get_relevant_documents(\"hello world\"))" + ] + }, { "cell_type": "code", "execution_count": null, - "id": "bf6d8c90", + "id": "c78d367d", "metadata": {}, "outputs": [], "source": [] diff --git a/langchain/experimental/generative_agents/generative_agent.py b/langchain/experimental/generative_agents/generative_agent.py index 64780da8..187a0d0c 100644 --- a/langchain/experimental/generative_agents/generative_agent.py +++ b/langchain/experimental/generative_agents/generative_agent.py @@ -88,7 +88,9 @@ Relevant context: q2 = f"{entity_name} is {entity_action}" return self.chain(prompt=prompt).run(q1=q1, queries=[q1, q2]).strip() - def _generate_reaction(self, observation: str, suffix: str) -> str: + def _generate_reaction( + self, observation: str, suffix: str, now: Optional[datetime] = None + ) -> str: """React to a given observation or dialogue act.""" prompt = PromptTemplate.from_template( "{agent_summary_description}" @@ -101,9 +103,13 @@ Relevant context: + "\n\n" + suffix ) - agent_summary_description = self.get_summary() + agent_summary_description = self.get_summary(now=now) relevant_memories_str = self.summarize_related_memories(observation) - current_time_str = datetime.now().strftime("%B %d, %Y, %I:%M %p") + current_time_str = ( + datetime.now().strftime("%B %d, %Y, %I:%M %p") + if now is None + else now.strftime("%B %d, %Y, %I:%M %p") + ) kwargs: Dict[str, Any] = dict( agent_summary_description=agent_summary_description, current_time=current_time_str, @@ -121,7 +127,9 @@ Relevant context: def _clean_response(self, text: str) -> str: return re.sub(f"^{self.name} ", "", text.strip()).strip() - def generate_reaction(self, observation: str) -> Tuple[bool, str]: + def generate_reaction( + self, observation: str, now: Optional[datetime] = None + ) -> Tuple[bool, str]: """React to a given observation.""" call_to_action_template = ( "Should {agent_name} react to the observation, and if so," @@ -130,14 +138,17 @@ Relevant context: + "\notherwise, write:\nREACT: {agent_name}'s reaction (if anything)." + "\nEither do nothing, react, or say something but not both.\n\n" ) - full_result = self._generate_reaction(observation, call_to_action_template) + full_result = self._generate_reaction( + observation, call_to_action_template, now=now + ) result = full_result.strip().split("\n")[0] # AAA self.memory.save_context( {}, { self.memory.add_memory_key: f"{self.name} observed " - f"{observation} and reacted by {result}" + f"{observation} and reacted by {result}", + self.memory.now_key: now, }, ) if "REACT:" in result: @@ -149,14 +160,18 @@ Relevant context: else: return False, result - def generate_dialogue_response(self, observation: str) -> Tuple[bool, str]: + def generate_dialogue_response( + self, observation: str, now: Optional[datetime] = None + ) -> Tuple[bool, str]: """React to a given observation.""" call_to_action_template = ( "What would {agent_name} say? To end the conversation, write:" ' GOODBYE: "what to say". Otherwise to continue the conversation,' ' write: SAY: "what to say next"\n\n' ) - full_result = self._generate_reaction(observation, call_to_action_template) + full_result = self._generate_reaction( + observation, call_to_action_template, now=now + ) result = full_result.strip().split("\n")[0] if "GOODBYE:" in result: farewell = self._clean_response(result.split("GOODBYE:")[-1]) @@ -164,7 +179,8 @@ Relevant context: {}, { self.memory.add_memory_key: f"{self.name} observed " - f"{observation} and said {farewell}" + f"{observation} and said {farewell}", + self.memory.now_key: now, }, ) return False, f"{self.name} said {farewell}" @@ -174,7 +190,8 @@ Relevant context: {}, { self.memory.add_memory_key: f"{self.name} observed " - f"{observation} and said {response_text}" + f"{observation} and said {response_text}", + self.memory.now_key: now, }, ) return True, f"{self.name} said {response_text}" @@ -203,9 +220,11 @@ Relevant context: .strip() ) - def get_summary(self, force_refresh: bool = False) -> str: + def get_summary( + self, force_refresh: bool = False, now: Optional[datetime] = None + ) -> str: """Return a descriptive summary of the agent.""" - current_time = datetime.now() + current_time = datetime.now() if now is None else now since_refresh = (current_time - self.last_refreshed).seconds if ( not self.summary @@ -221,10 +240,13 @@ Relevant context: + f"\n{self.summary}" ) - def get_full_header(self, force_refresh: bool = False) -> str: + def get_full_header( + self, force_refresh: bool = False, now: Optional[datetime] = None + ) -> str: """Return a full header of the agent's status, summary, and current time.""" - summary = self.get_summary(force_refresh=force_refresh) - current_time_str = datetime.now().strftime("%B %d, %Y, %I:%M %p") + now = datetime.now() if now is None else now + summary = self.get_summary(force_refresh=force_refresh, now=now) + current_time_str = now.strftime("%B %d, %Y, %I:%M %p") return ( f"{summary}\nIt is {current_time_str}.\n{self.name}'s status: {self.status}" ) diff --git a/langchain/experimental/generative_agents/memory.py b/langchain/experimental/generative_agents/memory.py index 8cfdacb7..9b9dd4bb 100644 --- a/langchain/experimental/generative_agents/memory.py +++ b/langchain/experimental/generative_agents/memory.py @@ -1,5 +1,6 @@ import logging import re +from datetime import datetime from typing import Any, Dict, List, Optional from langchain import LLMChain @@ -7,6 +8,7 @@ from langchain.base_language import BaseLanguageModel from langchain.prompts import PromptTemplate from langchain.retrievers import TimeWeightedVectorStoreRetriever from langchain.schema import BaseMemory, Document +from langchain.utils import mock_now logger = logging.getLogger(__name__) @@ -44,6 +46,7 @@ class GenerativeAgentMemory(BaseMemory): relevant_memories_key: str = "relevant_memories" relevant_memories_simple_key: str = "relevant_memories_simple" most_recent_memories_key: str = "most_recent_memories" + now_key: str = "now" reflecting: bool = False def chain(self, prompt: PromptTemplate) -> LLMChain: @@ -68,7 +71,9 @@ class GenerativeAgentMemory(BaseMemory): result = self.chain(prompt).run(observations=observation_str) return self._parse_list(result) - def _get_insights_on_topic(self, topic: str) -> List[str]: + def _get_insights_on_topic( + self, topic: str, now: Optional[datetime] = None + ) -> List[str]: """Generate 'insights' on a topic of reflection, based on pertinent memories.""" prompt = PromptTemplate.from_template( "Statements about {topic}\n" @@ -76,7 +81,7 @@ class GenerativeAgentMemory(BaseMemory): + "What 5 high-level insights can you infer from the above statements?" + " (example format: insight (because of 1, 5, 3))" ) - related_memories = self.fetch_memories(topic) + related_memories = self.fetch_memories(topic, now=now) related_statements = "\n".join( [ f"{i+1}. {memory.page_content}" @@ -89,16 +94,16 @@ class GenerativeAgentMemory(BaseMemory): # TODO: Parse the connections between memories and insights return self._parse_list(result) - def pause_to_reflect(self) -> List[str]: + def pause_to_reflect(self, now: Optional[datetime] = None) -> List[str]: """Reflect on recent observations and generate 'insights'.""" if self.verbose: logger.info("Character is reflecting") new_insights = [] topics = self._get_topics_of_reflection() for topic in topics: - insights = self._get_insights_on_topic(topic) + insights = self._get_insights_on_topic(topic, now=now) for insight in insights: - self.add_memory(insight) + self.add_memory(insight, now=now) new_insights.extend(insights) return new_insights @@ -122,14 +127,16 @@ class GenerativeAgentMemory(BaseMemory): else: return 0.0 - def add_memory(self, memory_content: str) -> List[str]: + def add_memory( + self, memory_content: str, now: Optional[datetime] = None + ) -> List[str]: """Add an observation or memory to the agent's memory.""" importance_score = self._score_memory_importance(memory_content) self.aggregate_importance += importance_score document = Document( page_content=memory_content, metadata={"importance": importance_score} ) - result = self.memory_retriever.add_documents([document]) + result = self.memory_retriever.add_documents([document], current_time=now) # After an agent has processed a certain amount of memories (as measured by # aggregate importance), it is time to reflect on recent events to add @@ -140,15 +147,21 @@ class GenerativeAgentMemory(BaseMemory): and not self.reflecting ): self.reflecting = True - self.pause_to_reflect() + self.pause_to_reflect(now=now) # Hack to clear the importance from reflection self.aggregate_importance = 0.0 self.reflecting = False return result - def fetch_memories(self, observation: str) -> List[Document]: + def fetch_memories( + self, observation: str, now: Optional[datetime] = None + ) -> List[Document]: """Fetch related memories.""" - return self.memory_retriever.get_relevant_documents(observation) + if now is not None: + with mock_now(now): + return self.memory_retriever.get_relevant_documents(observation) + else: + return self.memory_retriever.get_relevant_documents(observation) def format_memories_detail(self, relevant_memories: List[Document]) -> str: content_strs = set() @@ -183,9 +196,10 @@ class GenerativeAgentMemory(BaseMemory): def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]: """Return key-value pairs given the text input to the chain.""" queries = inputs.get(self.queries_key) + now = inputs.get(self.now_key) if queries is not None: relevant_memories = [ - mem for query in queries for mem in self.fetch_memories(query) + mem for query in queries for mem in self.fetch_memories(query, now=now) ] return { self.relevant_memories_key: self.format_memories_detail( @@ -205,12 +219,13 @@ class GenerativeAgentMemory(BaseMemory): } return {} - def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None: + def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, Any]) -> None: """Save the context of this model run to memory.""" # TODO: fix the save memory key mem = outputs.get(self.add_memory_key) + now = outputs.get(self.now_key) if mem: - self.add_memory(mem) + self.add_memory(mem, now=now) def clear(self) -> None: """Clear memory contents.""" diff --git a/langchain/retrievers/time_weighted_retriever.py b/langchain/retrievers/time_weighted_retriever.py index b3225a63..2b789137 100644 --- a/langchain/retrievers/time_weighted_retriever.py +++ b/langchain/retrievers/time_weighted_retriever.py @@ -1,6 +1,6 @@ """Retriever that combines embedding similarity with recency in retrieving values.""" +import datetime from copy import deepcopy -from datetime import datetime from typing import Any, Dict, List, Optional, Tuple from pydantic import BaseModel, Field @@ -9,7 +9,7 @@ from langchain.schema import BaseRetriever, Document from langchain.vectorstores.base import VectorStore -def _get_hours_passed(time: datetime, ref_time: datetime) -> float: +def _get_hours_passed(time: datetime.datetime, ref_time: datetime.datetime) -> float: """Get the hours passed between two datetime objects.""" return (time - ref_time).total_seconds() / 3600 @@ -51,7 +51,7 @@ class TimeWeightedVectorStoreRetriever(BaseRetriever, BaseModel): self, document: Document, vector_relevance: Optional[float], - current_time: datetime, + current_time: datetime.datetime, ) -> float: """Return the combined score for a document.""" hours_passed = _get_hours_passed( @@ -82,7 +82,7 @@ class TimeWeightedVectorStoreRetriever(BaseRetriever, BaseModel): def get_relevant_documents(self, query: str) -> List[Document]: """Return documents that are relevant to the query.""" - current_time = datetime.now() + current_time = datetime.datetime.now() docs_and_scores = { doc.metadata["buffer_idx"]: (doc, self.default_salience) for doc in self.memory_stream[-self.k :] @@ -96,7 +96,6 @@ class TimeWeightedVectorStoreRetriever(BaseRetriever, BaseModel): rescored_docs.sort(key=lambda x: x[1], reverse=True) result = [] # Ensure frequently accessed memories aren't forgotten - current_time = datetime.now() for doc, _ in rescored_docs[: self.k]: # TODO: Update vector store doc once `update` method is exposed. buffered_doc = self.memory_stream[doc.metadata["buffer_idx"]] @@ -110,7 +109,7 @@ class TimeWeightedVectorStoreRetriever(BaseRetriever, BaseModel): def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]: """Add documents to vectorstore.""" - current_time = kwargs.get("current_time", datetime.now()) + current_time = kwargs.get("current_time", datetime.datetime.now()) # Avoid mutating input documents dup_docs = [deepcopy(d) for d in documents] for i, doc in enumerate(dup_docs): @@ -126,7 +125,7 @@ class TimeWeightedVectorStoreRetriever(BaseRetriever, BaseModel): self, documents: List[Document], **kwargs: Any ) -> List[str]: """Add documents to vectorstore.""" - current_time = kwargs.get("current_time", datetime.now()) + current_time = kwargs.get("current_time", datetime.datetime.now()) # Avoid mutating input documents dup_docs = [deepcopy(d) for d in documents] for i, doc in enumerate(dup_docs): diff --git a/langchain/utils.py b/langchain/utils.py index 7420b371..0e9b79f5 100644 --- a/langchain/utils.py +++ b/langchain/utils.py @@ -1,4 +1,6 @@ """Generic utility functions.""" +import contextlib +import datetime import os from typing import Any, Callable, Dict, Optional, Tuple @@ -78,3 +80,34 @@ def stringify_dict(data: dict) -> str: for key, value in data.items(): text += key + ": " + stringify_value(value) + "\n" return text + + +@contextlib.contextmanager +def mock_now(dt_value): # type: ignore + """Context manager for mocking out datetime.now() in unit tests. + Example: + with mock_now(datetime.datetime(2011, 2, 3, 10, 11)): + assert datetime.datetime.now() == datetime.datetime(2011, 2, 3, 10, 11) + """ + + class MockDateTime(datetime.datetime): + @classmethod + def now(cls): # type: ignore + # Create a copy of dt_value. + return datetime.datetime( + dt_value.year, + dt_value.month, + dt_value.day, + dt_value.hour, + dt_value.minute, + dt_value.second, + dt_value.microsecond, + dt_value.tzinfo, + ) + + real_datetime = datetime.datetime + datetime.datetime = MockDateTime + try: + yield datetime.datetime + finally: + datetime.datetime = real_datetime