You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
langchain/langchain/memory/entity.py

239 lines
7.6 KiB
Python

import logging
from abc import ABC, abstractmethod
from itertools import islice
from typing import Any, Dict, Iterable, List, Optional
from pydantic import Field
from langchain.chains.llm import LLMChain
from langchain.memory.chat_memory import BaseChatMemory
from langchain.memory.prompt import (
ENTITY_EXTRACTION_PROMPT,
ENTITY_SUMMARIZATION_PROMPT,
)
from langchain.memory.utils import get_prompt_input_key
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel, BaseMessage, get_buffer_string
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
class BaseEntityStore(ABC):
    """Abstract interface for a keyed store of entity values.

    Every method below is abstract; a concrete store has to provide all five
    before it can be instantiated.
    """

    @abstractmethod
    def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
        """Look up the value held for ``key``; ``default`` is the fallback."""

    @abstractmethod
    def set(self, key: str, value: Optional[str]) -> None:
        """Record ``value`` under ``key``."""

    @abstractmethod
    def delete(self, key: str) -> None:
        """Remove the value held for ``key``."""

    @abstractmethod
    def exists(self, key: str) -> bool:
        """Report whether ``key`` is present in the store."""

    @abstractmethod
    def clear(self) -> None:
        """Remove every entity from the store."""
class InMemoryEntityStore(BaseEntityStore):
    """Basic in-memory entity store.

    Entities are kept in a plain ``dict`` owned by the instance, so two
    separate stores never see each other's entries.
    """

    store: Dict[str, Optional[str]]

    def __init__(self) -> None:
        """Create an empty store."""
        super().__init__()
        # A ``{}`` default at class level would be a single dict shared by
        # every instance (the base class is a plain ABC, so nothing copies
        # the default); build the mapping per instance instead.
        self.store = {}

    def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
        """Return the value stored for ``key``, or ``default`` if absent."""
        return self.store.get(key, default)

    def set(self, key: str, value: Optional[str]) -> None:
        """Store ``value`` under ``key``, replacing any previous value."""
        self.store[key] = value

    def delete(self, key: str) -> None:
        """Remove ``key`` from the store.

        Raises:
            KeyError: if ``key`` is not present.
        """
        del self.store[key]

    def exists(self, key: str) -> bool:
        """Return ``True`` if ``key`` is present in the store."""
        return key in self.store

    def clear(self) -> None:
        """Remove every entity from the store."""
        self.store.clear()
class RedisEntityStore(BaseEntityStore):
    """Entity store persisted in Redis.

    Keys are namespaced as ``<key_prefix>:<session_id>:<entity>``. Values are
    written with ``ttl`` as their expiry (one day by default), and every read
    passes ``recall_ttl`` (three days by default) as the new expiry.
    """

    redis_client: Any
    session_id: str = "default"
    key_prefix: str = "memory_store"
    ttl: Optional[int] = 60 * 60 * 24
    recall_ttl: Optional[int] = 60 * 60 * 24 * 3

    def __init__(
        self,
        session_id: str = "default",
        url: str = "redis://localhost:6379/0",
        key_prefix: str = "memory_store",
        ttl: Optional[int] = 60 * 60 * 24,
        recall_ttl: Optional[int] = 60 * 60 * 24 * 3,
        *args: Any,
        **kwargs: Any,
    ):
        """Build a client from ``url`` and remember the key namespace and TTLs.

        Raises:
            ValueError: if the ``redis`` package cannot be imported.
        """
        try:
            import redis
        except ImportError:
            raise ValueError(
                "Could not import redis python package. "
                "Please install it with `pip install redis`."
            )

        super().__init__(*args, **kwargs)

        try:
            client = redis.Redis.from_url(url=url, decode_responses=True)
        except redis.exceptions.ConnectionError as conn_err:
            # Only logged: ``redis_client`` is left unassigned in this case.
            logger.error(conn_err)
        else:
            self.redis_client = client

        self.session_id = session_id
        self.key_prefix = key_prefix
        self.ttl = ttl
        # A falsy ``recall_ttl`` falls back to the write TTL.
        self.recall_ttl = recall_ttl if recall_ttl else ttl

    @property
    def full_key_prefix(self) -> str:
        """Namespace placed in front of every entity key for this session."""
        return f"{self.key_prefix}:{self.session_id}"

    def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
        """Fetch the value for ``key`` and refresh its expiry.

        Falls back to ``default`` and then to ``""`` when Redis returns
        nothing truthy.
        """
        stored = self.redis_client.getex(
            f"{self.full_key_prefix}:{key}", ex=self.recall_ttl
        )
        res = stored or default or ""
        logger.debug(f"REDIS MEM get '{self.full_key_prefix}:{key}': '{res}'")
        return res

    def set(self, key: str, value: Optional[str]) -> None:
        """Write ``value`` under ``key``; a falsy value deletes the key."""
        if value:
            self.redis_client.set(f"{self.full_key_prefix}:{key}", value, ex=self.ttl)
            logger.debug(
                f"REDIS MEM set '{self.full_key_prefix}:{key}': '{value}' EX {self.ttl}"
            )
        else:
            self.delete(key)

    def delete(self, key: str) -> None:
        """Delete the namespaced entry for ``key``."""
        self.redis_client.delete(f"{self.full_key_prefix}:{key}")

    def exists(self, key: str) -> bool:
        """Return ``True`` when Redis reports exactly one match for ``key``."""
        hits = self.redis_client.exists(f"{self.full_key_prefix}:{key}")
        return hits == 1

    def clear(self) -> None:
        """Delete every key in this session's namespace, 500 keys per call."""
        matching_keys: Iterable[Any] = self.redis_client.scan_iter(
            f"{self.full_key_prefix}:*"
        )
        key_iter = iter(matching_keys)
        while True:
            # Pull the next batch lazily so the scan is never fully materialised.
            keybatch = list(islice(key_iter, 500))
            if not keybatch:
                break
            self.redis_client.delete(*keybatch)
class ConversationEntityMemory(BaseChatMemory):
    """Entity extractor & summarizer to memory.

    ``load_memory_variables`` asks the LLM which entities the current input
    mentions and returns whatever the entity store holds for them;
    ``save_context`` then asks the LLM for an updated summary of each of those
    entities and writes it back to the store.
    """

    human_prefix: str = "Human"
    ai_prefix: str = "AI"
    llm: BaseLanguageModel
    entity_extraction_prompt: BasePromptTemplate = ENTITY_EXTRACTION_PROMPT
    entity_summarization_prompt: BasePromptTemplate = ENTITY_SUMMARIZATION_PROMPT
    # Entities found by the latest load_memory_variables call; save_context
    # iterates over them.
    entity_cache: List[str] = []
    # The last ``k * 2`` buffered messages are included in the prompts.
    k: int = 3
    chat_history_key: str = "history"
    entity_store: BaseEntityStore = Field(default_factory=InMemoryEntityStore)

    @property
    def buffer(self) -> List[BaseMessage]:
        """Messages currently held by the underlying chat memory."""
        return self.chat_memory.messages

    @property
    def memory_variables(self) -> List[str]:
        """Will always return list of memory variables.

        :meta private:
        """
        return ["entities", self.chat_history_key]

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Extract entities from the input and return them with recent history.

        Args:
            inputs: Chain inputs; the prompt input is read from ``input_key``
                when set, otherwise from the key picked by
                ``get_prompt_input_key``.

        Returns:
            A dict with ``chat_history_key`` mapped to the recent history (a
            message list when ``return_messages`` is set, otherwise a string)
            and ``"entities"`` mapped to ``{entity: stored summary}``.
        """
        chain = LLMChain(llm=self.llm, prompt=self.entity_extraction_prompt)
        if self.input_key is None:
            prompt_input_key = get_prompt_input_key(inputs, self.memory_variables)
        else:
            prompt_input_key = self.input_key
        recent_messages = self.buffer[-self.k * 2 :]
        buffer_string = get_buffer_string(
            recent_messages,
            human_prefix=self.human_prefix,
            ai_prefix=self.ai_prefix,
        )
        output = chain.predict(
            history=buffer_string,
            input=inputs[prompt_input_key],
        )
        if output.strip() == "NONE":
            entities: List[str] = []
        else:
            # Drop blank fragments (trailing comma, empty reply) so that no
            # entity is ever stored or summarised under an empty name.
            stripped_names = (w.strip() for w in output.split(","))
            entities = [name for name in stripped_names if name]
        entity_summaries = {
            entity: self.entity_store.get(entity, "") for entity in entities
        }
        # Remembered so save_context knows which summaries to refresh.
        self.entity_cache = entities
        buffer: Any = recent_messages if self.return_messages else buffer_string
        return {
            self.chat_history_key: buffer,
            "entities": entity_summaries,
        }

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
        """Save context from this conversation to buffer.

        After the exchange is stored, every entity in ``entity_cache`` gets a
        fresh summary from the LLM, which is written to the entity store.
        """
        # Store the exchange first so the history below already includes it.
        super().save_context(inputs, outputs)
        if self.input_key is None:
            prompt_input_key = get_prompt_input_key(inputs, self.memory_variables)
        else:
            prompt_input_key = self.input_key
        buffer_string = get_buffer_string(
            self.buffer[-self.k * 2 :],
            human_prefix=self.human_prefix,
            ai_prefix=self.ai_prefix,
        )
        input_data = inputs[prompt_input_key]
        chain = LLMChain(llm=self.llm, prompt=self.entity_summarization_prompt)
        for entity in self.entity_cache:
            existing_summary = self.entity_store.get(entity, "")
            output = chain.predict(
                summary=existing_summary,
                entity=entity,
                history=buffer_string,
                input=input_data,
            )
            self.entity_store.set(entity, output.strip())

    def clear(self) -> None:
        """Clear memory contents."""
        self.chat_memory.clear()
        # Forget the cached entities too; otherwise the next save_context
        # would re-summarise entities from the conversation just cleared.
        self.entity_cache = []
        self.entity_store.clear()