From 58a93f88dac5e2e15c5a9005c262196ed273cdf0 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Thu, 6 Apr 2023 22:54:38 -0700 Subject: [PATCH] Harrison/entity store (#2525) Co-authored-by: Alex Iribarren --- langchain/memory/__init__.py | 8 +- langchain/memory/entity.py | 148 +++++++++++++++++++++++++++++++++-- 2 files changed, 149 insertions(+), 7 deletions(-) diff --git a/langchain/memory/__init__.py b/langchain/memory/__init__.py index 5799e734b1..aa56cc6c0f 100644 --- a/langchain/memory/__init__.py +++ b/langchain/memory/__init__.py @@ -7,7 +7,11 @@ from langchain.memory.chat_message_histories.dynamodb import DynamoDBChatMessage from langchain.memory.chat_message_histories.in_memory import ChatMessageHistory from langchain.memory.chat_message_histories.redis import RedisChatMessageHistory from langchain.memory.combined import CombinedMemory -from langchain.memory.entity import ConversationEntityMemory +from langchain.memory.entity import ( + ConversationEntityMemory, + InMemoryEntityStore, + RedisEntityStore, +) from langchain.memory.kg import ConversationKGMemory from langchain.memory.readonly import ReadOnlySharedMemory from langchain.memory.simple import SimpleMemory @@ -23,6 +27,8 @@ __all__ = [ "ConversationSummaryBufferMemory", "ConversationKGMemory", "ConversationEntityMemory", + "InMemoryEntityStore", + "RedisEntityStore", "ConversationSummaryMemory", "ChatMessageHistory", "ConversationStringBufferMemory", diff --git a/langchain/memory/entity.py b/langchain/memory/entity.py index 95aac811a2..8863c076fc 100644 --- a/langchain/memory/entity.py +++ b/langchain/memory/entity.py @@ -1,4 +1,9 @@ -from typing import Any, Dict, List, Optional +import logging +from abc import ABC, abstractmethod +from itertools import islice +from typing import Any, Dict, Iterable, List, Optional + +from pydantic import Field from langchain.chains.llm import LLMChain from langchain.memory.chat_memory import BaseChatMemory @@ -10,6 +15,137 @@ from langchain.memory.utils import get_prompt_input_key from langchain.prompts.base import BasePromptTemplate from langchain.schema import BaseLanguageModel, BaseMessage, get_buffer_string +logger = logging.getLogger(__name__) + + +class BaseEntityStore(ABC): + @abstractmethod + def get(self, key: str, default: Optional[str] = None) -> Optional[str]: + """Get entity value from store.""" + pass + + @abstractmethod + def set(self, key: str, value: Optional[str]) -> None: + """Set entity value in store.""" + pass + + @abstractmethod + def delete(self, key: str) -> None: + """Delete entity value from store.""" + pass + + @abstractmethod + def exists(self, key: str) -> bool: + """Check if entity exists in store.""" + pass + + @abstractmethod + def clear(self) -> None: + """Delete all entities from store.""" + pass + + +class InMemoryEntityStore(BaseEntityStore): + """Basic in-memory entity store.""" + + store: Dict[str, Optional[str]] = {} + + def get(self, key: str, default: Optional[str] = None) -> Optional[str]: + return self.store.get(key, default) + + def set(self, key: str, value: Optional[str]) -> None: + self.store[key] = value + + def delete(self, key: str) -> None: + del self.store[key] + + def exists(self, key: str) -> bool: + return key in self.store + + def clear(self) -> None: + return self.store.clear() + + +class RedisEntityStore(BaseEntityStore): + """Redis-backed Entity store. Entities get a TTL of 1 day by default, and + that TTL is extended by 3 days every time the entity is read back. + """ + + redis_client: Any + session_id: str = "default" + key_prefix: str = "memory_store" + ttl: Optional[int] = 60 * 60 * 24 + recall_ttl: Optional[int] = 60 * 60 * 24 * 3 + + def __init__( + self, + session_id: str = "default", + url: str = "redis://localhost:6379/0", + key_prefix: str = "memory_store", + ttl: Optional[int] = 60 * 60 * 24, + recall_ttl: Optional[int] = 60 * 60 * 24 * 3, + *args: Any, + **kwargs: Any, + ): + try: + import redis + except ImportError: + raise ValueError( + "Could not import redis python package. " + "Please install it with `pip install redis`." + ) + + super().__init__(*args, **kwargs) + + try: + self.redis_client = redis.Redis.from_url(url=url, decode_responses=True) + except redis.exceptions.ConnectionError as error: + logger.error(error) + + self.session_id = session_id + self.key_prefix = key_prefix + self.ttl = ttl + self.recall_ttl = recall_ttl or ttl + + @property + def full_key_prefix(self) -> str: + return f"{self.key_prefix}:{self.session_id}" + + def get(self, key: str, default: Optional[str] = None) -> Optional[str]: + res = ( + self.redis_client.getex(f"{self.full_key_prefix}:{key}", ex=self.recall_ttl) + or default + or "" + ) + logger.debug(f"REDIS MEM get '{self.full_key_prefix}:{key}': '{res}'") + return res + + def set(self, key: str, value: Optional[str]) -> None: + if not value: + return self.delete(key) + self.redis_client.set(f"{self.full_key_prefix}:{key}", value, ex=self.ttl) + logger.debug( + f"REDIS MEM set '{self.full_key_prefix}:{key}': '{value}' EX {self.ttl}" + ) + + def delete(self, key: str) -> None: + self.redis_client.delete(f"{self.full_key_prefix}:{key}") + + def exists(self, key: str) -> bool: + return self.redis_client.exists(f"{self.full_key_prefix}:{key}") == 1 + + def clear(self) -> None: + # iterate a list in batches of size batch_size + def batched(iterable: Iterable[Any], batch_size: int) -> Iterable[Any]: + iterator = iter(iterable) + while batch := list(islice(iterator, batch_size)): + yield batch + + for keybatch in batched( + self.redis_client.scan_iter(f"{self.full_key_prefix}:*"), 500 + ): + self.redis_client.delete(*keybatch) + class ConversationEntityMemory(BaseChatMemory): """Entity extractor & summarizer to memory.""" @@ -19,10 +155,10 @@ class ConversationEntityMemory(BaseChatMemory): llm: BaseLanguageModel entity_extraction_prompt: BasePromptTemplate = ENTITY_EXTRACTION_PROMPT entity_summarization_prompt: BasePromptTemplate = ENTITY_SUMMARIZATION_PROMPT - store: Dict[str, Optional[str]] = {} entity_cache: List[str] = [] k: int = 3 chat_history_key: str = "history" + entity_store: BaseEntityStore = Field(default_factory=InMemoryEntityStore) @property def buffer(self) -> List[BaseMessage]: @@ -58,7 +194,7 @@ class ConversationEntityMemory(BaseChatMemory): entities = [w.strip() for w in output.split(",")] entity_summaries = {} for entity in entities: - entity_summaries[entity] = self.store.get(entity, "") + entity_summaries[entity] = self.entity_store.get(entity, "") self.entity_cache = entities if self.return_messages: buffer: Any = self.buffer[-self.k * 2 :] @@ -87,16 +223,16 @@ class ConversationEntityMemory(BaseChatMemory): chain = LLMChain(llm=self.llm, prompt=self.entity_summarization_prompt) for entity in self.entity_cache: - existing_summary = self.store.get(entity, "") + existing_summary = self.entity_store.get(entity, "") output = chain.predict( summary=existing_summary, entity=entity, history=buffer_string, input=input_data, ) - self.store[entity] = output.strip() + self.entity_store.set(entity, output.strip()) def clear(self) -> None: """Clear memory contents.""" self.chat_memory.clear() - self.store = {} + self.entity_store.clear()