langchain[patch],community[minor]: Migrate memory implementations to community (#20845)

Migrates memory implementations to community
Eugene Yurtsev 2024-05-02 10:46:50 -04:00 committed by GitHub
parent b5c3a04e4b
commit 3cd7fced5f
6 changed files with 373 additions and 217 deletions

View File

@@ -0,0 +1,140 @@
from typing import Any, Dict, List, Type, Union
from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import BaseMessage, SystemMessage, get_buffer_string
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import Field
from langchain_community.graphs import NetworkxEntityGraph
from langchain_community.graphs.networkx_graph import (
KnowledgeTriple,
get_entities,
parse_triples,
)
try:
from langchain.chains.llm import LLMChain
from langchain.memory.chat_memory import BaseChatMemory
from langchain.memory.prompt import (
ENTITY_EXTRACTION_PROMPT,
KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT,
)
from langchain.memory.utils import get_prompt_input_key
class ConversationKGMemory(BaseChatMemory):
"""Knowledge graph conversation memory.
Integrates with external knowledge graph to store and retrieve
information about knowledge triples in the conversation.
"""
k: int = 2
"""Number of previous utterances to include in the context."""
human_prefix: str = "Human"
ai_prefix: str = "AI"
kg: NetworkxEntityGraph = Field(default_factory=NetworkxEntityGraph)
knowledge_extraction_prompt: BasePromptTemplate = (
KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT
)
entity_extraction_prompt: BasePromptTemplate = ENTITY_EXTRACTION_PROMPT
llm: BaseLanguageModel
summary_message_cls: Type[BaseMessage] = SystemMessage
memory_key: str = "history" #: :meta private:
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""Return history buffer."""
entities = self._get_current_entities(inputs)
summary_strings = []
for entity in entities:
knowledge = self.kg.get_entity_knowledge(entity)
if knowledge:
summary = f"On {entity}: {'. '.join(knowledge)}."
summary_strings.append(summary)
context: Union[str, List]
if not summary_strings:
context = [] if self.return_messages else ""
elif self.return_messages:
context = [
self.summary_message_cls(content=text) for text in summary_strings
]
else:
context = "\n".join(summary_strings)
return {self.memory_key: context}
@property
def memory_variables(self) -> List[str]:
"""Will always return list of memory variables.
:meta private:
"""
return [self.memory_key]
def _get_prompt_input_key(self, inputs: Dict[str, Any]) -> str:
"""Get the input key for the prompt."""
if self.input_key is None:
return get_prompt_input_key(inputs, self.memory_variables)
return self.input_key
def _get_prompt_output_key(self, outputs: Dict[str, Any]) -> str:
"""Get the output key for the prompt."""
if self.output_key is None:
if len(outputs) != 1:
raise ValueError(f"One output key expected, got {outputs.keys()}")
return list(outputs.keys())[0]
return self.output_key
def get_current_entities(self, input_string: str) -> List[str]:
chain = LLMChain(llm=self.llm, prompt=self.entity_extraction_prompt)
buffer_string = get_buffer_string(
self.chat_memory.messages[-self.k * 2 :],
human_prefix=self.human_prefix,
ai_prefix=self.ai_prefix,
)
output = chain.predict(
history=buffer_string,
input=input_string,
)
return get_entities(output)
def _get_current_entities(self, inputs: Dict[str, Any]) -> List[str]:
"""Get the current entities in the conversation."""
prompt_input_key = self._get_prompt_input_key(inputs)
return self.get_current_entities(inputs[prompt_input_key])
def get_knowledge_triplets(self, input_string: str) -> List[KnowledgeTriple]:
chain = LLMChain(llm=self.llm, prompt=self.knowledge_extraction_prompt)
buffer_string = get_buffer_string(
self.chat_memory.messages[-self.k * 2 :],
human_prefix=self.human_prefix,
ai_prefix=self.ai_prefix,
)
output = chain.predict(
history=buffer_string,
input=input_string,
)
knowledge = parse_triples(output)
return knowledge
def _get_and_update_kg(self, inputs: Dict[str, Any]) -> None:
"""Get and update knowledge graph from the conversation history."""
prompt_input_key = self._get_prompt_input_key(inputs)
knowledge = self.get_knowledge_triplets(inputs[prompt_input_key])
for triple in knowledge:
self.kg.add_triple(triple)
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""Save context from this conversation to buffer."""
super().save_context(inputs, outputs)
self._get_and_update_kg(inputs)
def clear(self) -> None:
"""Clear memory contents."""
super().clear()
self.kg.clear()
except ImportError:
# Placeholder object
class ConversationKGMemory: # type: ignore[no-redef]
pass
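
For orientation, a minimal usage sketch of the migrated class (it assumes networkx is installed and that the module lands at langchain_community.memory.kg; FakeListLLM stands in for a real model with canned responses):

from langchain_community.llms.fake import FakeListLLM
from langchain_community.memory.kg import ConversationKGMemory

# Canned outputs: the first call performs triple extraction, the second entity extraction.
llm = FakeListLLM(
    responses=[
        "(Sam, is a friend of, Ada)",  # knowledge-triple extraction
        "Sam",  # entity extraction
    ]
)
memory = ConversationKGMemory(llm=llm)
memory.save_context({"input": "Sam is a friend of Ada."}, {"output": "Noted!"})
print(memory.load_memory_variables({"input": "Who is Sam?"}))
# -> {'history': 'On Sam: Sam is a friend of Ada.'}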

View File

@@ -0,0 +1,100 @@
from typing import Any, Dict, List, Optional
import requests
from langchain_core.messages import get_buffer_string
try:
# Temporarily tuck this import inside a conditional block until the
# community package can depend on langchain-core for BaseChatMemory
from langchain.memory.chat_memory import BaseChatMemory
MANAGED_URL = "https://api.getmetal.io/v1/motorhead"
class MotorheadMemory(BaseChatMemory):
"""Chat message memory backed by Motorhead service."""
url: str = MANAGED_URL
timeout: int = 3000
memory_key: str = "history"
session_id: str
context: Optional[str] = None
# Managed Params
api_key: Optional[str] = None
client_id: Optional[str] = None
def __get_headers(self) -> Dict[str, str]:
is_managed = self.url == MANAGED_URL
headers = {
"Content-Type": "application/json",
}
if is_managed and not (self.api_key and self.client_id):
raise ValueError(
"""
You must provide an API key and a client ID to use the managed
version of Motorhead. Visit https://getmetal.io
for more information.
"""
)
if is_managed and self.api_key and self.client_id:
headers["x-metal-api-key"] = self.api_key
headers["x-metal-client-id"] = self.client_id
return headers
async def init(self) -> None:
res = requests.get(
f"{self.url}/sessions/{self.session_id}/memory",
timeout=self.timeout,
headers=self.__get_headers(),
)
res_data = res.json()
res_data = res_data.get("data", res_data) # Handle Managed Version
messages = res_data.get("messages", [])
context = res_data.get("context", "NONE")
for message in reversed(messages):
if message["role"] == "AI":
self.chat_memory.add_ai_message(message["content"])
else:
self.chat_memory.add_user_message(message["content"])
if context and context != "NONE":
self.context = context
def load_memory_variables(self, values: Dict[str, Any]) -> Dict[str, Any]:
if self.return_messages:
return {self.memory_key: self.chat_memory.messages}
else:
return {self.memory_key: get_buffer_string(self.chat_memory.messages)}
@property
def memory_variables(self) -> List[str]:
return [self.memory_key]
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
input_str, output_str = self._get_input_output(inputs, outputs)
requests.post(
f"{self.url}/sessions/{self.session_id}/memory",
timeout=self.timeout,
json={
"messages": [
{"role": "Human", "content": f"{input_str}"},
{"role": "AI", "content": f"{output_str}"},
]
},
headers=self.__get_headers(),
)
super().save_context(inputs, outputs)
def delete_session(self) -> None:
"""Delete a session"""
requests.delete(f"{self.url}/sessions/{self.session_id}/memory")
except ImportError:
# Placeholder object
class MotorheadMemory: # type: ignore[no-redef]
pass
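
A hedged usage sketch for the class above, assuming a self-hosted Motorhead server (the localhost URL and session id are illustrative; the managed service needs api_key and client_id as well). Note that init() is declared async and must be awaited, even though it issues a blocking requests call:

import asyncio

from langchain_community.memory.motorhead_memory import MotorheadMemory

async def main() -> None:
    memory = MotorheadMemory(
        url="http://localhost:8080",  # self-hosted server; defaults to the managed URL
        session_id="user-123",
        memory_key="history",
    )
    await memory.init()  # hydrate chat_memory from the server
    memory.save_context({"input": "Hi there!"}, {"output": "Hello!"})
    print(memory.load_memory_variables({}))  # {'history': 'Human: Hi there!\nAI: Hello!'}

asyncio.run(main())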

View File

@@ -0,0 +1,129 @@
from __future__ import annotations
from typing import Any, Dict, Optional
from langchain_community.chat_message_histories import ZepChatMessageHistory
try:
from langchain.memory import ConversationBufferMemory
class ZepMemory(ConversationBufferMemory):
"""Persist your chain history to the Zep MemoryStore.
The number of messages returned by Zep, and the point at which the Zep server
summarizes chat histories, are configurable. See the Zep documentation for more details.
Documentation: https://docs.getzep.com
Example:
.. code-block:: python
memory = ZepMemory(
session_id=session_id, # Identifies your user or a user's session
url=ZEP_API_URL, # Your Zep server's URL
api_key=<your_api_key>, # Optional
memory_key="history", # Ensure this matches the key used in
# chain's prompt template
return_messages=True, # Does your prompt template expect a string
# or a list of Messages?
)
chain = LLMChain(memory=memory, ...)  # Configure your chain to use the ZepMemory instance
Note:
To persist metadata alongside your chat history, you will need to create a
custom Chain class that overrides the `prep_outputs` method to include the metadata
in the call to `self.memory.save_context`.
Zep - Fast, scalable building blocks for LLM Apps
=========
Zep is an open source platform for productionizing LLM apps. Go from a prototype
built in LangChain or LlamaIndex, or a custom app, to production in minutes without
rewriting code.
For server installation instructions and more, see:
https://docs.getzep.com/deployment/quickstart/
For more information on the zep-python package, see:
https://github.com/getzep/zep-python
""" # noqa: E501
chat_memory: ZepChatMessageHistory
def __init__(
self,
session_id: str,
url: str = "http://localhost:8000",
api_key: Optional[str] = None,
output_key: Optional[str] = None,
input_key: Optional[str] = None,
return_messages: bool = False,
human_prefix: str = "Human",
ai_prefix: str = "AI",
memory_key: str = "history",
):
"""Initialize ZepMemory.
Args:
session_id (str): Identifies your user or a user's session
url (str, optional): Your Zep server's URL. Defaults to
"http://localhost:8000".
api_key (Optional[str], optional): Your Zep API key. Defaults to None.
output_key (Optional[str], optional): The key to use for the output message.
Defaults to None.
input_key (Optional[str], optional): The key to use for the input message.
Defaults to None.
return_messages (bool, optional): Does your prompt template expect a string
or a list of Messages? Defaults to False,
i.e., returns a string.
human_prefix (str, optional): The prefix to use for human messages.
Defaults to "Human".
ai_prefix (str, optional): The prefix to use for AI messages.
Defaults to "AI".
memory_key (str, optional): The key to use for the memory.
Defaults to "history".
Ensure that this matches the key used in
chain's prompt template.
""" # noqa: E501
chat_message_history = ZepChatMessageHistory(
session_id=session_id,
url=url,
api_key=api_key,
)
super().__init__(
chat_memory=chat_message_history,
output_key=output_key,
input_key=input_key,
return_messages=return_messages,
human_prefix=human_prefix,
ai_prefix=ai_prefix,
memory_key=memory_key,
)
def save_context(
self,
inputs: Dict[str, Any],
outputs: Dict[str, str],
metadata: Optional[Dict[str, Any]] = None,
) -> None:
"""Save context from this conversation to buffer.
Args:
inputs (Dict[str, Any]): The inputs to the chain.
outputs (Dict[str, str]): The outputs from the chain.
metadata (Optional[Dict[str, Any]], optional): Any metadata to save with
the context. Defaults to None.
Returns:
None
"""
input_str, output_str = self._get_input_output(inputs, outputs)
self.chat_memory.add_user_message(input_str, metadata=metadata)
self.chat_memory.add_ai_message(output_str, metadata=metadata)
except ImportError:
# Placeholder object
class ZepMemory: # type: ignore[no-redef]
pass
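
The Note in the docstring above calls for a custom Chain that overrides prep_outputs; one possible sketch (ZepMetadataChain and its metadata field are assumptions for illustration, not part of this commit):

from typing import Any, Dict

from langchain.chains.llm import LLMChain

class ZepMetadataChain(LLMChain):
    """Hypothetical chain that forwards metadata to ZepMemory.save_context."""

    metadata: Dict[str, Any] = {}

    def prep_outputs(
        self,
        inputs: Dict[str, str],
        outputs: Dict[str, str],
        return_only_outputs: bool = False,
    ) -> Dict[str, str]:
        self._validate_outputs(outputs)
        if self.memory is not None:
            # ZepMemory.save_context accepts the extra `metadata` argument.
            self.memory.save_context(inputs, outputs, metadata=self.metadata)
        return outputs if return_only_outputs else {**inputs, **outputs}

# e.g. ZepMetadataChain(llm=llm, prompt=prompt, memory=memory, metadata={"user_id": "123"})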

View File

@@ -1,94 +1,3 @@
from typing import Any, Dict, List, Optional
from langchain_community.memory.motorhead_memory import MotorheadMemory
import requests
from langchain_core.messages import get_buffer_string
from langchain.memory.chat_memory import BaseChatMemory
MANAGED_URL = "https://api.getmetal.io/v1/motorhead"
# LOCAL_URL = "http://localhost:8080"
class MotorheadMemory(BaseChatMemory):
"""Chat message memory backed by Motorhead service."""
url: str = MANAGED_URL
timeout: int = 3000
memory_key: str = "history"
session_id: str
context: Optional[str] = None
# Managed Params
api_key: Optional[str] = None
client_id: Optional[str] = None
def __get_headers(self) -> Dict[str, str]:
is_managed = self.url == MANAGED_URL
headers = {
"Content-Type": "application/json",
}
if is_managed and not (self.api_key and self.client_id):
raise ValueError(
"""
You must provide an API key and a client ID to use the managed
version of Motorhead. Visit https://getmetal.io for more information.
"""
)
if is_managed and self.api_key and self.client_id:
headers["x-metal-api-key"] = self.api_key
headers["x-metal-client-id"] = self.client_id
return headers
async def init(self) -> None:
res = requests.get(
f"{self.url}/sessions/{self.session_id}/memory",
timeout=self.timeout,
headers=self.__get_headers(),
)
res_data = res.json()
res_data = res_data.get("data", res_data) # Handle Managed Version
messages = res_data.get("messages", [])
context = res_data.get("context", "NONE")
for message in reversed(messages):
if message["role"] == "AI":
self.chat_memory.add_ai_message(message["content"])
else:
self.chat_memory.add_user_message(message["content"])
if context and context != "NONE":
self.context = context
def load_memory_variables(self, values: Dict[str, Any]) -> Dict[str, Any]:
if self.return_messages:
return {self.memory_key: self.chat_memory.messages}
else:
return {self.memory_key: get_buffer_string(self.chat_memory.messages)}
@property
def memory_variables(self) -> List[str]:
return [self.memory_key]
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
input_str, output_str = self._get_input_output(inputs, outputs)
requests.post(
f"{self.url}/sessions/{self.session_id}/memory",
timeout=self.timeout,
json={
"messages": [
{"role": "Human", "content": f"{input_str}"},
{"role": "AI", "content": f"{output_str}"},
]
},
headers=self.__get_headers(),
)
super().save_context(inputs, outputs)
def delete_session(self) -> None:
"""Delete a session"""
requests.delete(f"{self.url}/sessions/{self.session_id}/memory")
__all__ = ["MotorheadMemory"]

View File

@@ -1,125 +1,3 @@
from __future__ import annotations
from langchain_community.memory.zep_memory import ZepMemory
from typing import Any, Dict, Optional
from langchain_community.chat_message_histories import ZepChatMessageHistory
from langchain.memory import ConversationBufferMemory
class ZepMemory(ConversationBufferMemory):
"""Persist your chain history to the Zep MemoryStore.
The number of messages returned by Zep, and the point at which the Zep server
summarizes chat histories, are configurable. See the Zep documentation for more details.
Documentation: https://docs.getzep.com
Example:
.. code-block:: python
memory = ZepMemory(
session_id=session_id, # Identifies your user or a user's session
url=ZEP_API_URL, # Your Zep server's URL
api_key=<your_api_key>, # Optional
memory_key="history", # Ensure this matches the key used in
# chain's prompt template
return_messages=True, # Does your prompt template expect a string
# or a list of Messages?
)
chain = LLMChain(memory=memory, ...)  # Configure your chain to use the ZepMemory instance
Note:
To persist metadata alongside your chat history, you will need to create a
custom Chain class that overrides the `prep_outputs` method to include the metadata
in the call to `self.memory.save_context`.
Zep - Fast, scalable building blocks for LLM Apps
=========
Zep is an open source platform for productionizing LLM apps. Go from a prototype
built in LangChain or LlamaIndex, or a custom app, to production in minutes without
rewriting code.
For server installation instructions and more, see:
https://docs.getzep.com/deployment/quickstart/
For more information on the zep-python package, see:
https://github.com/getzep/zep-python
"""
chat_memory: ZepChatMessageHistory
def __init__(
self,
session_id: str,
url: str = "http://localhost:8000",
api_key: Optional[str] = None,
output_key: Optional[str] = None,
input_key: Optional[str] = None,
return_messages: bool = False,
human_prefix: str = "Human",
ai_prefix: str = "AI",
memory_key: str = "history",
):
"""Initialize ZepMemory.
Args:
session_id (str): Identifies your user or a user's session
url (str, optional): Your Zep server's URL. Defaults to
"http://localhost:8000".
api_key (Optional[str], optional): Your Zep API key. Defaults to None.
output_key (Optional[str], optional): The key to use for the output message.
Defaults to None.
input_key (Optional[str], optional): The key to use for the input message.
Defaults to None.
return_messages (bool, optional): Does your prompt template expect a string
or a list of Messages? Defaults to False,
i.e., returns a string.
human_prefix (str, optional): The prefix to use for human messages.
Defaults to "Human".
ai_prefix (str, optional): The prefix to use for AI messages.
Defaults to "AI".
memory_key (str, optional): The key to use for the memory.
Defaults to "history".
Ensure that this matches the key used in
chain's prompt template.
"""
chat_message_history = ZepChatMessageHistory(
session_id=session_id,
url=url,
api_key=api_key,
)
super().__init__(
chat_memory=chat_message_history,
output_key=output_key,
input_key=input_key,
return_messages=return_messages,
human_prefix=human_prefix,
ai_prefix=ai_prefix,
memory_key=memory_key,
)
def save_context(
self,
inputs: Dict[str, Any],
outputs: Dict[str, str],
metadata: Optional[Dict[str, Any]] = None,
) -> None:
"""Save context from this conversation to buffer.
Args:
inputs (Dict[str, Any]): The inputs to the chain.
outputs (Dict[str, str]): The outputs from the chain.
metadata (Optional[Dict[str, Any]], optional): Any metadata to save with
the context. Defaults to None.
Returns:
None
"""
input_str, output_str = self._get_input_output(inputs, outputs)
self.chat_memory.add_user_message(input_str, metadata=metadata)
self.chat_memory.add_ai_message(output_str, metadata=metadata)
__all__ = ["ZepMemory"]