|
|
@ -1,4 +1,9 @@
|
|
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
import logging
|
|
|
|
|
|
|
|
from abc import ABC, abstractmethod
|
|
|
|
|
|
|
|
from itertools import islice
|
|
|
|
|
|
|
|
from typing import Any, Dict, Iterable, List, Optional
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from pydantic import Field
|
|
|
|
|
|
|
|
|
|
|
|
from langchain.chains.llm import LLMChain
|
|
|
|
from langchain.chains.llm import LLMChain
|
|
|
|
from langchain.memory.chat_memory import BaseChatMemory
|
|
|
|
from langchain.memory.chat_memory import BaseChatMemory
|
|
|
@ -10,6 +15,137 @@ from langchain.memory.utils import get_prompt_input_key
|
|
|
|
from langchain.prompts.base import BasePromptTemplate
|
|
|
|
from langchain.prompts.base import BasePromptTemplate
|
|
|
|
from langchain.schema import BaseLanguageModel, BaseMessage, get_buffer_string
|
|
|
|
from langchain.schema import BaseLanguageModel, BaseMessage, get_buffer_string
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BaseEntityStore(ABC):
|
|
|
|
|
|
|
|
@abstractmethod
|
|
|
|
|
|
|
|
def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
|
|
|
|
|
|
|
|
"""Get entity value from store."""
|
|
|
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@abstractmethod
|
|
|
|
|
|
|
|
def set(self, key: str, value: Optional[str]) -> None:
|
|
|
|
|
|
|
|
"""Set entity value in store."""
|
|
|
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@abstractmethod
|
|
|
|
|
|
|
|
def delete(self, key: str) -> None:
|
|
|
|
|
|
|
|
"""Delete entity value from store."""
|
|
|
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@abstractmethod
|
|
|
|
|
|
|
|
def exists(self, key: str) -> bool:
|
|
|
|
|
|
|
|
"""Check if entity exists in store."""
|
|
|
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@abstractmethod
|
|
|
|
|
|
|
|
def clear(self) -> None:
|
|
|
|
|
|
|
|
"""Delete all entities from store."""
|
|
|
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class InMemoryEntityStore(BaseEntityStore):
|
|
|
|
|
|
|
|
"""Basic in-memory entity store."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
store: Dict[str, Optional[str]] = {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
|
|
|
|
|
|
|
|
return self.store.get(key, default)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def set(self, key: str, value: Optional[str]) -> None:
|
|
|
|
|
|
|
|
self.store[key] = value
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def delete(self, key: str) -> None:
|
|
|
|
|
|
|
|
del self.store[key]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def exists(self, key: str) -> bool:
|
|
|
|
|
|
|
|
return key in self.store
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def clear(self) -> None:
|
|
|
|
|
|
|
|
return self.store.clear()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RedisEntityStore(BaseEntityStore):
|
|
|
|
|
|
|
|
"""Redis-backed Entity store. Entities get a TTL of 1 day by default, and
|
|
|
|
|
|
|
|
that TTL is extended by 3 days every time the entity is read back.
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
redis_client: Any
|
|
|
|
|
|
|
|
session_id: str = "default"
|
|
|
|
|
|
|
|
key_prefix: str = "memory_store"
|
|
|
|
|
|
|
|
ttl: Optional[int] = 60 * 60 * 24
|
|
|
|
|
|
|
|
recall_ttl: Optional[int] = 60 * 60 * 24 * 3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
|
|
|
|
self,
|
|
|
|
|
|
|
|
session_id: str = "default",
|
|
|
|
|
|
|
|
url: str = "redis://localhost:6379/0",
|
|
|
|
|
|
|
|
key_prefix: str = "memory_store",
|
|
|
|
|
|
|
|
ttl: Optional[int] = 60 * 60 * 24,
|
|
|
|
|
|
|
|
recall_ttl: Optional[int] = 60 * 60 * 24 * 3,
|
|
|
|
|
|
|
|
*args: Any,
|
|
|
|
|
|
|
|
**kwargs: Any,
|
|
|
|
|
|
|
|
):
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
|
|
|
import redis
|
|
|
|
|
|
|
|
except ImportError:
|
|
|
|
|
|
|
|
raise ValueError(
|
|
|
|
|
|
|
|
"Could not import redis python package. "
|
|
|
|
|
|
|
|
"Please install it with `pip install redis`."
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
super().__init__(*args, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
|
|
|
self.redis_client = redis.Redis.from_url(url=url, decode_responses=True)
|
|
|
|
|
|
|
|
except redis.exceptions.ConnectionError as error:
|
|
|
|
|
|
|
|
logger.error(error)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.session_id = session_id
|
|
|
|
|
|
|
|
self.key_prefix = key_prefix
|
|
|
|
|
|
|
|
self.ttl = ttl
|
|
|
|
|
|
|
|
self.recall_ttl = recall_ttl or ttl
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
|
|
|
def full_key_prefix(self) -> str:
|
|
|
|
|
|
|
|
return f"{self.key_prefix}:{self.session_id}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
|
|
|
|
|
|
|
|
res = (
|
|
|
|
|
|
|
|
self.redis_client.getex(f"{self.full_key_prefix}:{key}", ex=self.recall_ttl)
|
|
|
|
|
|
|
|
or default
|
|
|
|
|
|
|
|
or ""
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
logger.debug(f"REDIS MEM get '{self.full_key_prefix}:{key}': '{res}'")
|
|
|
|
|
|
|
|
return res
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def set(self, key: str, value: Optional[str]) -> None:
|
|
|
|
|
|
|
|
if not value:
|
|
|
|
|
|
|
|
return self.delete(key)
|
|
|
|
|
|
|
|
self.redis_client.set(f"{self.full_key_prefix}:{key}", value, ex=self.ttl)
|
|
|
|
|
|
|
|
logger.debug(
|
|
|
|
|
|
|
|
f"REDIS MEM set '{self.full_key_prefix}:{key}': '{value}' EX {self.ttl}"
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def delete(self, key: str) -> None:
|
|
|
|
|
|
|
|
self.redis_client.delete(f"{self.full_key_prefix}:{key}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def exists(self, key: str) -> bool:
|
|
|
|
|
|
|
|
return self.redis_client.exists(f"{self.full_key_prefix}:{key}") == 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def clear(self) -> None:
|
|
|
|
|
|
|
|
# iterate a list in batches of size batch_size
|
|
|
|
|
|
|
|
def batched(iterable: Iterable[Any], batch_size: int) -> Iterable[Any]:
|
|
|
|
|
|
|
|
iterator = iter(iterable)
|
|
|
|
|
|
|
|
while batch := list(islice(iterator, batch_size)):
|
|
|
|
|
|
|
|
yield batch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for keybatch in batched(
|
|
|
|
|
|
|
|
self.redis_client.scan_iter(f"{self.full_key_prefix}:*"), 500
|
|
|
|
|
|
|
|
):
|
|
|
|
|
|
|
|
self.redis_client.delete(*keybatch)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ConversationEntityMemory(BaseChatMemory):
|
|
|
|
class ConversationEntityMemory(BaseChatMemory):
|
|
|
|
"""Entity extractor & summarizer to memory."""
|
|
|
|
"""Entity extractor & summarizer to memory."""
|
|
|
@ -19,10 +155,10 @@ class ConversationEntityMemory(BaseChatMemory):
|
|
|
|
llm: BaseLanguageModel
|
|
|
|
llm: BaseLanguageModel
|
|
|
|
entity_extraction_prompt: BasePromptTemplate = ENTITY_EXTRACTION_PROMPT
|
|
|
|
entity_extraction_prompt: BasePromptTemplate = ENTITY_EXTRACTION_PROMPT
|
|
|
|
entity_summarization_prompt: BasePromptTemplate = ENTITY_SUMMARIZATION_PROMPT
|
|
|
|
entity_summarization_prompt: BasePromptTemplate = ENTITY_SUMMARIZATION_PROMPT
|
|
|
|
store: Dict[str, Optional[str]] = {}
|
|
|
|
|
|
|
|
entity_cache: List[str] = []
|
|
|
|
entity_cache: List[str] = []
|
|
|
|
k: int = 3
|
|
|
|
k: int = 3
|
|
|
|
chat_history_key: str = "history"
|
|
|
|
chat_history_key: str = "history"
|
|
|
|
|
|
|
|
entity_store: BaseEntityStore = Field(default_factory=InMemoryEntityStore)
|
|
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
@property
|
|
|
|
def buffer(self) -> List[BaseMessage]:
|
|
|
|
def buffer(self) -> List[BaseMessage]:
|
|
|
@ -58,7 +194,7 @@ class ConversationEntityMemory(BaseChatMemory):
|
|
|
|
entities = [w.strip() for w in output.split(",")]
|
|
|
|
entities = [w.strip() for w in output.split(",")]
|
|
|
|
entity_summaries = {}
|
|
|
|
entity_summaries = {}
|
|
|
|
for entity in entities:
|
|
|
|
for entity in entities:
|
|
|
|
entity_summaries[entity] = self.store.get(entity, "")
|
|
|
|
entity_summaries[entity] = self.entity_store.get(entity, "")
|
|
|
|
self.entity_cache = entities
|
|
|
|
self.entity_cache = entities
|
|
|
|
if self.return_messages:
|
|
|
|
if self.return_messages:
|
|
|
|
buffer: Any = self.buffer[-self.k * 2 :]
|
|
|
|
buffer: Any = self.buffer[-self.k * 2 :]
|
|
|
@ -87,16 +223,16 @@ class ConversationEntityMemory(BaseChatMemory):
|
|
|
|
chain = LLMChain(llm=self.llm, prompt=self.entity_summarization_prompt)
|
|
|
|
chain = LLMChain(llm=self.llm, prompt=self.entity_summarization_prompt)
|
|
|
|
|
|
|
|
|
|
|
|
for entity in self.entity_cache:
|
|
|
|
for entity in self.entity_cache:
|
|
|
|
existing_summary = self.store.get(entity, "")
|
|
|
|
existing_summary = self.entity_store.get(entity, "")
|
|
|
|
output = chain.predict(
|
|
|
|
output = chain.predict(
|
|
|
|
summary=existing_summary,
|
|
|
|
summary=existing_summary,
|
|
|
|
entity=entity,
|
|
|
|
entity=entity,
|
|
|
|
history=buffer_string,
|
|
|
|
history=buffer_string,
|
|
|
|
input=input_data,
|
|
|
|
input=input_data,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
self.store[entity] = output.strip()
|
|
|
|
self.entity_store.set(entity, output.strip())
|
|
|
|
|
|
|
|
|
|
|
|
def clear(self) -> None:
|
|
|
|
def clear(self) -> None:
|
|
|
|
"""Clear memory contents."""
|
|
|
|
"""Clear memory contents."""
|
|
|
|
self.chat_memory.clear()
|
|
|
|
self.chat_memory.clear()
|
|
|
|
self.store = {}
|
|
|
|
self.entity_store.clear()
|
|
|
|