Harrison/entity store (#2525)

Co-authored-by: Alex Iribarren <alex.iribarren@gmail.com>
doc
Harrison Chase 1 year ago committed by GitHub
parent aa439ac2ff
commit 58a93f88da
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -7,7 +7,11 @@ from langchain.memory.chat_message_histories.dynamodb import DynamoDBChatMessage
from langchain.memory.chat_message_histories.in_memory import ChatMessageHistory
from langchain.memory.chat_message_histories.redis import RedisChatMessageHistory
from langchain.memory.combined import CombinedMemory
from langchain.memory.entity import ConversationEntityMemory
from langchain.memory.entity import (
ConversationEntityMemory,
InMemoryEntityStore,
RedisEntityStore,
)
from langchain.memory.kg import ConversationKGMemory
from langchain.memory.readonly import ReadOnlySharedMemory
from langchain.memory.simple import SimpleMemory
@ -23,6 +27,8 @@ __all__ = [
"ConversationSummaryBufferMemory",
"ConversationKGMemory",
"ConversationEntityMemory",
"InMemoryEntityStore",
"RedisEntityStore",
"ConversationSummaryMemory",
"ChatMessageHistory",
"ConversationStringBufferMemory",

@ -1,4 +1,9 @@
from typing import Any, Dict, List, Optional
import logging
from abc import ABC, abstractmethod
from itertools import islice
from typing import Any, Dict, Iterable, List, Optional
from pydantic import Field
from langchain.chains.llm import LLMChain
from langchain.memory.chat_memory import BaseChatMemory
@ -10,6 +15,137 @@ from langchain.memory.utils import get_prompt_input_key
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel, BaseMessage, get_buffer_string
logger = logging.getLogger(__name__)
class BaseEntityStore(ABC):
@abstractmethod
def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
"""Get entity value from store."""
pass
@abstractmethod
def set(self, key: str, value: Optional[str]) -> None:
"""Set entity value in store."""
pass
@abstractmethod
def delete(self, key: str) -> None:
"""Delete entity value from store."""
pass
@abstractmethod
def exists(self, key: str) -> bool:
"""Check if entity exists in store."""
pass
@abstractmethod
def clear(self) -> None:
"""Delete all entities from store."""
pass
class InMemoryEntityStore(BaseEntityStore):
"""Basic in-memory entity store."""
store: Dict[str, Optional[str]] = {}
def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
return self.store.get(key, default)
def set(self, key: str, value: Optional[str]) -> None:
self.store[key] = value
def delete(self, key: str) -> None:
del self.store[key]
def exists(self, key: str) -> bool:
return key in self.store
def clear(self) -> None:
return self.store.clear()
class RedisEntityStore(BaseEntityStore):
"""Redis-backed Entity store. Entities get a TTL of 1 day by default, and
that TTL is extended by 3 days every time the entity is read back.
"""
redis_client: Any
session_id: str = "default"
key_prefix: str = "memory_store"
ttl: Optional[int] = 60 * 60 * 24
recall_ttl: Optional[int] = 60 * 60 * 24 * 3
def __init__(
self,
session_id: str = "default",
url: str = "redis://localhost:6379/0",
key_prefix: str = "memory_store",
ttl: Optional[int] = 60 * 60 * 24,
recall_ttl: Optional[int] = 60 * 60 * 24 * 3,
*args: Any,
**kwargs: Any,
):
try:
import redis
except ImportError:
raise ValueError(
"Could not import redis python package. "
"Please install it with `pip install redis`."
)
super().__init__(*args, **kwargs)
try:
self.redis_client = redis.Redis.from_url(url=url, decode_responses=True)
except redis.exceptions.ConnectionError as error:
logger.error(error)
self.session_id = session_id
self.key_prefix = key_prefix
self.ttl = ttl
self.recall_ttl = recall_ttl or ttl
@property
def full_key_prefix(self) -> str:
return f"{self.key_prefix}:{self.session_id}"
def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
res = (
self.redis_client.getex(f"{self.full_key_prefix}:{key}", ex=self.recall_ttl)
or default
or ""
)
logger.debug(f"REDIS MEM get '{self.full_key_prefix}:{key}': '{res}'")
return res
def set(self, key: str, value: Optional[str]) -> None:
if not value:
return self.delete(key)
self.redis_client.set(f"{self.full_key_prefix}:{key}", value, ex=self.ttl)
logger.debug(
f"REDIS MEM set '{self.full_key_prefix}:{key}': '{value}' EX {self.ttl}"
)
def delete(self, key: str) -> None:
self.redis_client.delete(f"{self.full_key_prefix}:{key}")
def exists(self, key: str) -> bool:
return self.redis_client.exists(f"{self.full_key_prefix}:{key}") == 1
def clear(self) -> None:
# iterate a list in batches of size batch_size
def batched(iterable: Iterable[Any], batch_size: int) -> Iterable[Any]:
iterator = iter(iterable)
while batch := list(islice(iterator, batch_size)):
yield batch
for keybatch in batched(
self.redis_client.scan_iter(f"{self.full_key_prefix}:*"), 500
):
self.redis_client.delete(*keybatch)
class ConversationEntityMemory(BaseChatMemory):
"""Entity extractor & summarizer to memory."""
@ -19,10 +155,10 @@ class ConversationEntityMemory(BaseChatMemory):
llm: BaseLanguageModel
entity_extraction_prompt: BasePromptTemplate = ENTITY_EXTRACTION_PROMPT
entity_summarization_prompt: BasePromptTemplate = ENTITY_SUMMARIZATION_PROMPT
store: Dict[str, Optional[str]] = {}
entity_cache: List[str] = []
k: int = 3
chat_history_key: str = "history"
entity_store: BaseEntityStore = Field(default_factory=InMemoryEntityStore)
@property
def buffer(self) -> List[BaseMessage]:
@ -58,7 +194,7 @@ class ConversationEntityMemory(BaseChatMemory):
entities = [w.strip() for w in output.split(",")]
entity_summaries = {}
for entity in entities:
entity_summaries[entity] = self.store.get(entity, "")
entity_summaries[entity] = self.entity_store.get(entity, "")
self.entity_cache = entities
if self.return_messages:
buffer: Any = self.buffer[-self.k * 2 :]
@ -87,16 +223,16 @@ class ConversationEntityMemory(BaseChatMemory):
chain = LLMChain(llm=self.llm, prompt=self.entity_summarization_prompt)
for entity in self.entity_cache:
existing_summary = self.store.get(entity, "")
existing_summary = self.entity_store.get(entity, "")
output = chain.predict(
summary=existing_summary,
entity=entity,
history=buffer_string,
input=input_data,
)
self.store[entity] = output.strip()
self.entity_store.set(entity, output.strip())
def clear(self) -> None:
"""Clear memory contents."""
self.chat_memory.clear()
self.store = {}
self.entity_store.clear()

Loading…
Cancel
Save