mirror of
https://github.com/hwchase17/langchain
synced 2024-11-18 09:25:54 +00:00
Harrison/virtual time (#4658)
Co-authored-by: ifsheldon <39153080+ifsheldon@users.noreply.github.com> Co-authored-by: maple.liang <maple.liang@gempoll.com>
This commit is contained in:
parent
f2f2aced6d
commit
243886be93
@ -70,7 +70,7 @@
|
|||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"text/plain": [
|
"text/plain": [
|
||||||
"['5c9f7c06-c9eb-45f2-aea5-efce5fb9f2bd']"
|
"['d7f85756-2371-4bdf-9140-052780a0f9b3']"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 3,
|
"execution_count": 3,
|
||||||
@ -93,7 +93,7 @@
|
|||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"text/plain": [
|
"text/plain": [
|
||||||
"[Document(page_content='hello world', metadata={'last_accessed_at': datetime.datetime(2023, 4, 16, 22, 9, 1, 966261), 'created_at': datetime.datetime(2023, 4, 16, 22, 9, 0, 374683), 'buffer_idx': 0})]"
|
"[Document(page_content='hello world', metadata={'last_accessed_at': datetime.datetime(2023, 5, 13, 21, 0, 27, 678341), 'created_at': datetime.datetime(2023, 5, 13, 21, 0, 27, 279596), 'buffer_idx': 0})]"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 4,
|
"execution_count": 4,
|
||||||
@ -177,10 +177,51 @@
|
|||||||
"retriever.get_relevant_documents(\"hello world\")"
|
"retriever.get_relevant_documents(\"hello world\")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "32e0131e",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Virtual Time\n",
|
||||||
|
"\n",
|
||||||
|
"Using some utils in LangChain, you can mock out the time component"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 5,
|
||||||
|
"id": "da080d40",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from langchain.utils import mock_now\n",
|
||||||
|
"import datetime"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 7,
|
||||||
|
"id": "7c7deff1",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"[Document(page_content='hello world', metadata={'last_accessed_at': MockDateTime(2011, 2, 3, 10, 11), 'created_at': datetime.datetime(2023, 5, 13, 21, 0, 27, 279596), 'buffer_idx': 0})]\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"# Notice the last access time is that date time\n",
|
||||||
|
"with mock_now(datetime.datetime(2011, 2, 3, 10, 11)):\n",
|
||||||
|
" print(retriever.get_relevant_documents(\"hello world\"))"
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"id": "bf6d8c90",
|
"id": "c78d367d",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": []
|
"source": []
|
||||||
|
@ -88,7 +88,9 @@ Relevant context:
|
|||||||
q2 = f"{entity_name} is {entity_action}"
|
q2 = f"{entity_name} is {entity_action}"
|
||||||
return self.chain(prompt=prompt).run(q1=q1, queries=[q1, q2]).strip()
|
return self.chain(prompt=prompt).run(q1=q1, queries=[q1, q2]).strip()
|
||||||
|
|
||||||
def _generate_reaction(self, observation: str, suffix: str) -> str:
|
def _generate_reaction(
|
||||||
|
self, observation: str, suffix: str, now: Optional[datetime] = None
|
||||||
|
) -> str:
|
||||||
"""React to a given observation or dialogue act."""
|
"""React to a given observation or dialogue act."""
|
||||||
prompt = PromptTemplate.from_template(
|
prompt = PromptTemplate.from_template(
|
||||||
"{agent_summary_description}"
|
"{agent_summary_description}"
|
||||||
@ -101,9 +103,13 @@ Relevant context:
|
|||||||
+ "\n\n"
|
+ "\n\n"
|
||||||
+ suffix
|
+ suffix
|
||||||
)
|
)
|
||||||
agent_summary_description = self.get_summary()
|
agent_summary_description = self.get_summary(now=now)
|
||||||
relevant_memories_str = self.summarize_related_memories(observation)
|
relevant_memories_str = self.summarize_related_memories(observation)
|
||||||
current_time_str = datetime.now().strftime("%B %d, %Y, %I:%M %p")
|
current_time_str = (
|
||||||
|
datetime.now().strftime("%B %d, %Y, %I:%M %p")
|
||||||
|
if now is None
|
||||||
|
else now.strftime("%B %d, %Y, %I:%M %p")
|
||||||
|
)
|
||||||
kwargs: Dict[str, Any] = dict(
|
kwargs: Dict[str, Any] = dict(
|
||||||
agent_summary_description=agent_summary_description,
|
agent_summary_description=agent_summary_description,
|
||||||
current_time=current_time_str,
|
current_time=current_time_str,
|
||||||
@ -121,7 +127,9 @@ Relevant context:
|
|||||||
def _clean_response(self, text: str) -> str:
|
def _clean_response(self, text: str) -> str:
|
||||||
return re.sub(f"^{self.name} ", "", text.strip()).strip()
|
return re.sub(f"^{self.name} ", "", text.strip()).strip()
|
||||||
|
|
||||||
def generate_reaction(self, observation: str) -> Tuple[bool, str]:
|
def generate_reaction(
|
||||||
|
self, observation: str, now: Optional[datetime] = None
|
||||||
|
) -> Tuple[bool, str]:
|
||||||
"""React to a given observation."""
|
"""React to a given observation."""
|
||||||
call_to_action_template = (
|
call_to_action_template = (
|
||||||
"Should {agent_name} react to the observation, and if so,"
|
"Should {agent_name} react to the observation, and if so,"
|
||||||
@ -130,14 +138,17 @@ Relevant context:
|
|||||||
+ "\notherwise, write:\nREACT: {agent_name}'s reaction (if anything)."
|
+ "\notherwise, write:\nREACT: {agent_name}'s reaction (if anything)."
|
||||||
+ "\nEither do nothing, react, or say something but not both.\n\n"
|
+ "\nEither do nothing, react, or say something but not both.\n\n"
|
||||||
)
|
)
|
||||||
full_result = self._generate_reaction(observation, call_to_action_template)
|
full_result = self._generate_reaction(
|
||||||
|
observation, call_to_action_template, now=now
|
||||||
|
)
|
||||||
result = full_result.strip().split("\n")[0]
|
result = full_result.strip().split("\n")[0]
|
||||||
# AAA
|
# AAA
|
||||||
self.memory.save_context(
|
self.memory.save_context(
|
||||||
{},
|
{},
|
||||||
{
|
{
|
||||||
self.memory.add_memory_key: f"{self.name} observed "
|
self.memory.add_memory_key: f"{self.name} observed "
|
||||||
f"{observation} and reacted by {result}"
|
f"{observation} and reacted by {result}",
|
||||||
|
self.memory.now_key: now,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
if "REACT:" in result:
|
if "REACT:" in result:
|
||||||
@ -149,14 +160,18 @@ Relevant context:
|
|||||||
else:
|
else:
|
||||||
return False, result
|
return False, result
|
||||||
|
|
||||||
def generate_dialogue_response(self, observation: str) -> Tuple[bool, str]:
|
def generate_dialogue_response(
|
||||||
|
self, observation: str, now: Optional[datetime] = None
|
||||||
|
) -> Tuple[bool, str]:
|
||||||
"""React to a given observation."""
|
"""React to a given observation."""
|
||||||
call_to_action_template = (
|
call_to_action_template = (
|
||||||
"What would {agent_name} say? To end the conversation, write:"
|
"What would {agent_name} say? To end the conversation, write:"
|
||||||
' GOODBYE: "what to say". Otherwise to continue the conversation,'
|
' GOODBYE: "what to say". Otherwise to continue the conversation,'
|
||||||
' write: SAY: "what to say next"\n\n'
|
' write: SAY: "what to say next"\n\n'
|
||||||
)
|
)
|
||||||
full_result = self._generate_reaction(observation, call_to_action_template)
|
full_result = self._generate_reaction(
|
||||||
|
observation, call_to_action_template, now=now
|
||||||
|
)
|
||||||
result = full_result.strip().split("\n")[0]
|
result = full_result.strip().split("\n")[0]
|
||||||
if "GOODBYE:" in result:
|
if "GOODBYE:" in result:
|
||||||
farewell = self._clean_response(result.split("GOODBYE:")[-1])
|
farewell = self._clean_response(result.split("GOODBYE:")[-1])
|
||||||
@ -164,7 +179,8 @@ Relevant context:
|
|||||||
{},
|
{},
|
||||||
{
|
{
|
||||||
self.memory.add_memory_key: f"{self.name} observed "
|
self.memory.add_memory_key: f"{self.name} observed "
|
||||||
f"{observation} and said {farewell}"
|
f"{observation} and said {farewell}",
|
||||||
|
self.memory.now_key: now,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
return False, f"{self.name} said {farewell}"
|
return False, f"{self.name} said {farewell}"
|
||||||
@ -174,7 +190,8 @@ Relevant context:
|
|||||||
{},
|
{},
|
||||||
{
|
{
|
||||||
self.memory.add_memory_key: f"{self.name} observed "
|
self.memory.add_memory_key: f"{self.name} observed "
|
||||||
f"{observation} and said {response_text}"
|
f"{observation} and said {response_text}",
|
||||||
|
self.memory.now_key: now,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
return True, f"{self.name} said {response_text}"
|
return True, f"{self.name} said {response_text}"
|
||||||
@ -203,9 +220,11 @@ Relevant context:
|
|||||||
.strip()
|
.strip()
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_summary(self, force_refresh: bool = False) -> str:
|
def get_summary(
|
||||||
|
self, force_refresh: bool = False, now: Optional[datetime] = None
|
||||||
|
) -> str:
|
||||||
"""Return a descriptive summary of the agent."""
|
"""Return a descriptive summary of the agent."""
|
||||||
current_time = datetime.now()
|
current_time = datetime.now() if now is None else now
|
||||||
since_refresh = (current_time - self.last_refreshed).seconds
|
since_refresh = (current_time - self.last_refreshed).seconds
|
||||||
if (
|
if (
|
||||||
not self.summary
|
not self.summary
|
||||||
@ -221,10 +240,13 @@ Relevant context:
|
|||||||
+ f"\n{self.summary}"
|
+ f"\n{self.summary}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_full_header(self, force_refresh: bool = False) -> str:
|
def get_full_header(
|
||||||
|
self, force_refresh: bool = False, now: Optional[datetime] = None
|
||||||
|
) -> str:
|
||||||
"""Return a full header of the agent's status, summary, and current time."""
|
"""Return a full header of the agent's status, summary, and current time."""
|
||||||
summary = self.get_summary(force_refresh=force_refresh)
|
now = datetime.now() if now is None else now
|
||||||
current_time_str = datetime.now().strftime("%B %d, %Y, %I:%M %p")
|
summary = self.get_summary(force_refresh=force_refresh, now=now)
|
||||||
|
current_time_str = now.strftime("%B %d, %Y, %I:%M %p")
|
||||||
return (
|
return (
|
||||||
f"{summary}\nIt is {current_time_str}.\n{self.name}'s status: {self.status}"
|
f"{summary}\nIt is {current_time_str}.\n{self.name}'s status: {self.status}"
|
||||||
)
|
)
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
|
from datetime import datetime
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from langchain import LLMChain
|
from langchain import LLMChain
|
||||||
@ -7,6 +8,7 @@ from langchain.base_language import BaseLanguageModel
|
|||||||
from langchain.prompts import PromptTemplate
|
from langchain.prompts import PromptTemplate
|
||||||
from langchain.retrievers import TimeWeightedVectorStoreRetriever
|
from langchain.retrievers import TimeWeightedVectorStoreRetriever
|
||||||
from langchain.schema import BaseMemory, Document
|
from langchain.schema import BaseMemory, Document
|
||||||
|
from langchain.utils import mock_now
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -44,6 +46,7 @@ class GenerativeAgentMemory(BaseMemory):
|
|||||||
relevant_memories_key: str = "relevant_memories"
|
relevant_memories_key: str = "relevant_memories"
|
||||||
relevant_memories_simple_key: str = "relevant_memories_simple"
|
relevant_memories_simple_key: str = "relevant_memories_simple"
|
||||||
most_recent_memories_key: str = "most_recent_memories"
|
most_recent_memories_key: str = "most_recent_memories"
|
||||||
|
now_key: str = "now"
|
||||||
reflecting: bool = False
|
reflecting: bool = False
|
||||||
|
|
||||||
def chain(self, prompt: PromptTemplate) -> LLMChain:
|
def chain(self, prompt: PromptTemplate) -> LLMChain:
|
||||||
@ -68,7 +71,9 @@ class GenerativeAgentMemory(BaseMemory):
|
|||||||
result = self.chain(prompt).run(observations=observation_str)
|
result = self.chain(prompt).run(observations=observation_str)
|
||||||
return self._parse_list(result)
|
return self._parse_list(result)
|
||||||
|
|
||||||
def _get_insights_on_topic(self, topic: str) -> List[str]:
|
def _get_insights_on_topic(
|
||||||
|
self, topic: str, now: Optional[datetime] = None
|
||||||
|
) -> List[str]:
|
||||||
"""Generate 'insights' on a topic of reflection, based on pertinent memories."""
|
"""Generate 'insights' on a topic of reflection, based on pertinent memories."""
|
||||||
prompt = PromptTemplate.from_template(
|
prompt = PromptTemplate.from_template(
|
||||||
"Statements about {topic}\n"
|
"Statements about {topic}\n"
|
||||||
@ -76,7 +81,7 @@ class GenerativeAgentMemory(BaseMemory):
|
|||||||
+ "What 5 high-level insights can you infer from the above statements?"
|
+ "What 5 high-level insights can you infer from the above statements?"
|
||||||
+ " (example format: insight (because of 1, 5, 3))"
|
+ " (example format: insight (because of 1, 5, 3))"
|
||||||
)
|
)
|
||||||
related_memories = self.fetch_memories(topic)
|
related_memories = self.fetch_memories(topic, now=now)
|
||||||
related_statements = "\n".join(
|
related_statements = "\n".join(
|
||||||
[
|
[
|
||||||
f"{i+1}. {memory.page_content}"
|
f"{i+1}. {memory.page_content}"
|
||||||
@ -89,16 +94,16 @@ class GenerativeAgentMemory(BaseMemory):
|
|||||||
# TODO: Parse the connections between memories and insights
|
# TODO: Parse the connections between memories and insights
|
||||||
return self._parse_list(result)
|
return self._parse_list(result)
|
||||||
|
|
||||||
def pause_to_reflect(self) -> List[str]:
|
def pause_to_reflect(self, now: Optional[datetime] = None) -> List[str]:
|
||||||
"""Reflect on recent observations and generate 'insights'."""
|
"""Reflect on recent observations and generate 'insights'."""
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
logger.info("Character is reflecting")
|
logger.info("Character is reflecting")
|
||||||
new_insights = []
|
new_insights = []
|
||||||
topics = self._get_topics_of_reflection()
|
topics = self._get_topics_of_reflection()
|
||||||
for topic in topics:
|
for topic in topics:
|
||||||
insights = self._get_insights_on_topic(topic)
|
insights = self._get_insights_on_topic(topic, now=now)
|
||||||
for insight in insights:
|
for insight in insights:
|
||||||
self.add_memory(insight)
|
self.add_memory(insight, now=now)
|
||||||
new_insights.extend(insights)
|
new_insights.extend(insights)
|
||||||
return new_insights
|
return new_insights
|
||||||
|
|
||||||
@ -122,14 +127,16 @@ class GenerativeAgentMemory(BaseMemory):
|
|||||||
else:
|
else:
|
||||||
return 0.0
|
return 0.0
|
||||||
|
|
||||||
def add_memory(self, memory_content: str) -> List[str]:
|
def add_memory(
|
||||||
|
self, memory_content: str, now: Optional[datetime] = None
|
||||||
|
) -> List[str]:
|
||||||
"""Add an observation or memory to the agent's memory."""
|
"""Add an observation or memory to the agent's memory."""
|
||||||
importance_score = self._score_memory_importance(memory_content)
|
importance_score = self._score_memory_importance(memory_content)
|
||||||
self.aggregate_importance += importance_score
|
self.aggregate_importance += importance_score
|
||||||
document = Document(
|
document = Document(
|
||||||
page_content=memory_content, metadata={"importance": importance_score}
|
page_content=memory_content, metadata={"importance": importance_score}
|
||||||
)
|
)
|
||||||
result = self.memory_retriever.add_documents([document])
|
result = self.memory_retriever.add_documents([document], current_time=now)
|
||||||
|
|
||||||
# After an agent has processed a certain amount of memories (as measured by
|
# After an agent has processed a certain amount of memories (as measured by
|
||||||
# aggregate importance), it is time to reflect on recent events to add
|
# aggregate importance), it is time to reflect on recent events to add
|
||||||
@ -140,14 +147,20 @@ class GenerativeAgentMemory(BaseMemory):
|
|||||||
and not self.reflecting
|
and not self.reflecting
|
||||||
):
|
):
|
||||||
self.reflecting = True
|
self.reflecting = True
|
||||||
self.pause_to_reflect()
|
self.pause_to_reflect(now=now)
|
||||||
# Hack to clear the importance from reflection
|
# Hack to clear the importance from reflection
|
||||||
self.aggregate_importance = 0.0
|
self.aggregate_importance = 0.0
|
||||||
self.reflecting = False
|
self.reflecting = False
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def fetch_memories(self, observation: str) -> List[Document]:
|
def fetch_memories(
|
||||||
|
self, observation: str, now: Optional[datetime] = None
|
||||||
|
) -> List[Document]:
|
||||||
"""Fetch related memories."""
|
"""Fetch related memories."""
|
||||||
|
if now is not None:
|
||||||
|
with mock_now(now):
|
||||||
|
return self.memory_retriever.get_relevant_documents(observation)
|
||||||
|
else:
|
||||||
return self.memory_retriever.get_relevant_documents(observation)
|
return self.memory_retriever.get_relevant_documents(observation)
|
||||||
|
|
||||||
def format_memories_detail(self, relevant_memories: List[Document]) -> str:
|
def format_memories_detail(self, relevant_memories: List[Document]) -> str:
|
||||||
@ -183,9 +196,10 @@ class GenerativeAgentMemory(BaseMemory):
|
|||||||
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
||||||
"""Return key-value pairs given the text input to the chain."""
|
"""Return key-value pairs given the text input to the chain."""
|
||||||
queries = inputs.get(self.queries_key)
|
queries = inputs.get(self.queries_key)
|
||||||
|
now = inputs.get(self.now_key)
|
||||||
if queries is not None:
|
if queries is not None:
|
||||||
relevant_memories = [
|
relevant_memories = [
|
||||||
mem for query in queries for mem in self.fetch_memories(query)
|
mem for query in queries for mem in self.fetch_memories(query, now=now)
|
||||||
]
|
]
|
||||||
return {
|
return {
|
||||||
self.relevant_memories_key: self.format_memories_detail(
|
self.relevant_memories_key: self.format_memories_detail(
|
||||||
@ -205,12 +219,13 @@ class GenerativeAgentMemory(BaseMemory):
|
|||||||
}
|
}
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, Any]) -> None:
|
||||||
"""Save the context of this model run to memory."""
|
"""Save the context of this model run to memory."""
|
||||||
# TODO: fix the save memory key
|
# TODO: fix the save memory key
|
||||||
mem = outputs.get(self.add_memory_key)
|
mem = outputs.get(self.add_memory_key)
|
||||||
|
now = outputs.get(self.now_key)
|
||||||
if mem:
|
if mem:
|
||||||
self.add_memory(mem)
|
self.add_memory(mem, now=now)
|
||||||
|
|
||||||
def clear(self) -> None:
|
def clear(self) -> None:
|
||||||
"""Clear memory contents."""
|
"""Clear memory contents."""
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
"""Retriever that combines embedding similarity with recency in retrieving values."""
|
"""Retriever that combines embedding similarity with recency in retrieving values."""
|
||||||
|
import datetime
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from datetime import datetime
|
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
@ -9,7 +9,7 @@ from langchain.schema import BaseRetriever, Document
|
|||||||
from langchain.vectorstores.base import VectorStore
|
from langchain.vectorstores.base import VectorStore
|
||||||
|
|
||||||
|
|
||||||
def _get_hours_passed(time: datetime, ref_time: datetime) -> float:
|
def _get_hours_passed(time: datetime.datetime, ref_time: datetime.datetime) -> float:
|
||||||
"""Get the hours passed between two datetime objects."""
|
"""Get the hours passed between two datetime objects."""
|
||||||
return (time - ref_time).total_seconds() / 3600
|
return (time - ref_time).total_seconds() / 3600
|
||||||
|
|
||||||
@ -51,7 +51,7 @@ class TimeWeightedVectorStoreRetriever(BaseRetriever, BaseModel):
|
|||||||
self,
|
self,
|
||||||
document: Document,
|
document: Document,
|
||||||
vector_relevance: Optional[float],
|
vector_relevance: Optional[float],
|
||||||
current_time: datetime,
|
current_time: datetime.datetime,
|
||||||
) -> float:
|
) -> float:
|
||||||
"""Return the combined score for a document."""
|
"""Return the combined score for a document."""
|
||||||
hours_passed = _get_hours_passed(
|
hours_passed = _get_hours_passed(
|
||||||
@ -82,7 +82,7 @@ class TimeWeightedVectorStoreRetriever(BaseRetriever, BaseModel):
|
|||||||
|
|
||||||
def get_relevant_documents(self, query: str) -> List[Document]:
|
def get_relevant_documents(self, query: str) -> List[Document]:
|
||||||
"""Return documents that are relevant to the query."""
|
"""Return documents that are relevant to the query."""
|
||||||
current_time = datetime.now()
|
current_time = datetime.datetime.now()
|
||||||
docs_and_scores = {
|
docs_and_scores = {
|
||||||
doc.metadata["buffer_idx"]: (doc, self.default_salience)
|
doc.metadata["buffer_idx"]: (doc, self.default_salience)
|
||||||
for doc in self.memory_stream[-self.k :]
|
for doc in self.memory_stream[-self.k :]
|
||||||
@ -96,7 +96,6 @@ class TimeWeightedVectorStoreRetriever(BaseRetriever, BaseModel):
|
|||||||
rescored_docs.sort(key=lambda x: x[1], reverse=True)
|
rescored_docs.sort(key=lambda x: x[1], reverse=True)
|
||||||
result = []
|
result = []
|
||||||
# Ensure frequently accessed memories aren't forgotten
|
# Ensure frequently accessed memories aren't forgotten
|
||||||
current_time = datetime.now()
|
|
||||||
for doc, _ in rescored_docs[: self.k]:
|
for doc, _ in rescored_docs[: self.k]:
|
||||||
# TODO: Update vector store doc once `update` method is exposed.
|
# TODO: Update vector store doc once `update` method is exposed.
|
||||||
buffered_doc = self.memory_stream[doc.metadata["buffer_idx"]]
|
buffered_doc = self.memory_stream[doc.metadata["buffer_idx"]]
|
||||||
@ -110,7 +109,7 @@ class TimeWeightedVectorStoreRetriever(BaseRetriever, BaseModel):
|
|||||||
|
|
||||||
def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
|
def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
|
||||||
"""Add documents to vectorstore."""
|
"""Add documents to vectorstore."""
|
||||||
current_time = kwargs.get("current_time", datetime.now())
|
current_time = kwargs.get("current_time", datetime.datetime.now())
|
||||||
# Avoid mutating input documents
|
# Avoid mutating input documents
|
||||||
dup_docs = [deepcopy(d) for d in documents]
|
dup_docs = [deepcopy(d) for d in documents]
|
||||||
for i, doc in enumerate(dup_docs):
|
for i, doc in enumerate(dup_docs):
|
||||||
@ -126,7 +125,7 @@ class TimeWeightedVectorStoreRetriever(BaseRetriever, BaseModel):
|
|||||||
self, documents: List[Document], **kwargs: Any
|
self, documents: List[Document], **kwargs: Any
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
"""Add documents to vectorstore."""
|
"""Add documents to vectorstore."""
|
||||||
current_time = kwargs.get("current_time", datetime.now())
|
current_time = kwargs.get("current_time", datetime.datetime.now())
|
||||||
# Avoid mutating input documents
|
# Avoid mutating input documents
|
||||||
dup_docs = [deepcopy(d) for d in documents]
|
dup_docs = [deepcopy(d) for d in documents]
|
||||||
for i, doc in enumerate(dup_docs):
|
for i, doc in enumerate(dup_docs):
|
||||||
|
@ -1,4 +1,6 @@
|
|||||||
"""Generic utility functions."""
|
"""Generic utility functions."""
|
||||||
|
import contextlib
|
||||||
|
import datetime
|
||||||
import os
|
import os
|
||||||
from typing import Any, Callable, Dict, Optional, Tuple
|
from typing import Any, Callable, Dict, Optional, Tuple
|
||||||
|
|
||||||
@ -78,3 +80,34 @@ def stringify_dict(data: dict) -> str:
|
|||||||
for key, value in data.items():
|
for key, value in data.items():
|
||||||
text += key + ": " + stringify_value(value) + "\n"
|
text += key + ": " + stringify_value(value) + "\n"
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def mock_now(dt_value): # type: ignore
|
||||||
|
"""Context manager for mocking out datetime.now() in unit tests.
|
||||||
|
Example:
|
||||||
|
with mock_now(datetime.datetime(2011, 2, 3, 10, 11)):
|
||||||
|
assert datetime.datetime.now() == datetime.datetime(2011, 2, 3, 10, 11)
|
||||||
|
"""
|
||||||
|
|
||||||
|
class MockDateTime(datetime.datetime):
|
||||||
|
@classmethod
|
||||||
|
def now(cls): # type: ignore
|
||||||
|
# Create a copy of dt_value.
|
||||||
|
return datetime.datetime(
|
||||||
|
dt_value.year,
|
||||||
|
dt_value.month,
|
||||||
|
dt_value.day,
|
||||||
|
dt_value.hour,
|
||||||
|
dt_value.minute,
|
||||||
|
dt_value.second,
|
||||||
|
dt_value.microsecond,
|
||||||
|
dt_value.tzinfo,
|
||||||
|
)
|
||||||
|
|
||||||
|
real_datetime = datetime.datetime
|
||||||
|
datetime.datetime = MockDateTime
|
||||||
|
try:
|
||||||
|
yield datetime.datetime
|
||||||
|
finally:
|
||||||
|
datetime.datetime = real_datetime
|
||||||
|
Loading…
Reference in New Issue
Block a user