mirror of https://github.com/hwchase17/langchain
Neo4j chat message history (#13008)
parent
bf8cf7e042
commit
0dc4ab0be1
@ -0,0 +1,76 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "91c6a7ef",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Neo4j\n",
|
||||
"\n",
|
||||
"[Neo4j](https://en.wikipedia.org/wiki/Neo4j) is an open-source graph database management system, renowned for its efficient management of highly connected data. Unlike traditional databases that store data in tables, Neo4j uses a graph structure with nodes, edges, and properties to represent and store data. This design allows for high-performance queries on complex data relationships.\n",
|
||||
"\n",
|
||||
"This notebook goes over how to use `Neo4j` to store chat message history."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d15e3302",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.memory import Neo4jChatMessageHistory\n",
|
||||
"\n",
|
||||
"history = Neo4jChatMessageHistory(\n",
|
||||
" url=\"bolt://localhost:7687\",\n",
|
||||
" username=\"neo4j\",\n",
|
||||
" password=\"password\",\n",
|
||||
" session_id=\"session_id_1\"\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"history.add_user_message(\"hi!\")\n",
|
||||
"\n",
|
||||
"history.add_ai_message(\"whats up?\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "64fc465e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"history.messages"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "8af285f8",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.8"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
@ -0,0 +1,112 @@
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from langchain.schema import BaseChatMessageHistory
|
||||
from langchain.schema.messages import BaseMessage, messages_from_dict
|
||||
from langchain.utils import get_from_env
|
||||
|
||||
|
||||
class Neo4jChatMessageHistory(BaseChatMessageHistory):
|
||||
"""Chat message history stored in a Neo4j database."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_id: Union[str, int],
|
||||
url: Optional[str] = None,
|
||||
username: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
database: str = "neo4j",
|
||||
node_label: str = "Session",
|
||||
window: int = 3,
|
||||
):
|
||||
try:
|
||||
import neo4j
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import neo4j python package. "
|
||||
"Please install it with `pip install neo4j`."
|
||||
)
|
||||
|
||||
# Make sure session id is not null
|
||||
if not session_id:
|
||||
raise ValueError("Please ensure that the session_id parameter is provided")
|
||||
|
||||
url = get_from_env("url", "NEO4J_URI", url)
|
||||
username = get_from_env("username", "NEO4J_USERNAME", username)
|
||||
password = get_from_env("password", "NEO4J_PASSWORD", password)
|
||||
database = get_from_env("database", "NEO4J_DATABASE", database)
|
||||
|
||||
self._driver = neo4j.GraphDatabase.driver(url, auth=(username, password))
|
||||
self._database = database
|
||||
self._session_id = session_id
|
||||
self._node_label = node_label
|
||||
self._window = window
|
||||
|
||||
# Verify connection
|
||||
try:
|
||||
self._driver.verify_connectivity()
|
||||
except neo4j.exceptions.ServiceUnavailable:
|
||||
raise ValueError(
|
||||
"Could not connect to Neo4j database. "
|
||||
"Please ensure that the url is correct"
|
||||
)
|
||||
except neo4j.exceptions.AuthError:
|
||||
raise ValueError(
|
||||
"Could not connect to Neo4j database. "
|
||||
"Please ensure that the username and password are correct"
|
||||
)
|
||||
# Create session node
|
||||
self._driver.execute_query(
|
||||
f"MERGE (s:`{self._node_label}` {{id:$session_id}})",
|
||||
{"session_id": self._session_id},
|
||||
).summary
|
||||
|
||||
@property
|
||||
def messages(self) -> List[BaseMessage]: # type: ignore
|
||||
"""Retrieve the messages from Neo4j"""
|
||||
query = (
|
||||
f"MATCH (s:`{self._node_label}`)-[:LAST_MESSAGE]->(last_message) "
|
||||
"WHERE s.id = $session_id MATCH p=(last_message)<-[:NEXT*0.."
|
||||
f"{self._window*2}]-() WITH p, length(p) AS length "
|
||||
"ORDER BY length DESC LIMIT 1 UNWIND reverse(nodes(p)) AS node "
|
||||
"RETURN {data:{content: node.content}, type:node.type} AS result"
|
||||
)
|
||||
records, _, _ = self._driver.execute_query(
|
||||
query, {"session_id": self._session_id}
|
||||
)
|
||||
|
||||
messages = messages_from_dict([el["result"] for el in records])
|
||||
return messages
|
||||
|
||||
def add_message(self, message: BaseMessage) -> None:
|
||||
"""Append the message to the record in Neo4j"""
|
||||
query = (
|
||||
f"MATCH (s:`{self._node_label}`) WHERE s.id = $session_id "
|
||||
"OPTIONAL MATCH (s)-[lm:LAST_MESSAGE]->(last_message) "
|
||||
"CREATE (s)-[:LAST_MESSAGE]->(new:Message) "
|
||||
"SET new += {type:$type, content:$content} "
|
||||
"WITH new, lm, last_message WHERE last_message IS NOT NULL "
|
||||
"CREATE (last_message)-[:NEXT]->(new) "
|
||||
"DELETE lm"
|
||||
)
|
||||
self._driver.execute_query(
|
||||
query,
|
||||
{
|
||||
"type": message.type,
|
||||
"content": message.content,
|
||||
"session_id": self._session_id,
|
||||
},
|
||||
).summary
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear session memory from Neo4j"""
|
||||
query = (
|
||||
f"MATCH (s:`{self._node_label}`)-[:LAST_MESSAGE]->(last_message) "
|
||||
"WHERE s.id = $session_id MATCH p=(last_message)<-[:NEXT]-() "
|
||||
"WITH p, length(p) AS length ORDER BY length DESC LIMIT 1 "
|
||||
"UNWIND nodes(p) as node DETACH DELETE node;"
|
||||
)
|
||||
self._driver.execute_query(query, {"session_id": self._session_id}).summary
|
||||
|
||||
def __del__(self) -> None:
|
||||
if self._driver:
|
||||
self._driver.close()
|
@ -0,0 +1,30 @@
|
||||
import json
|
||||
|
||||
from langchain.memory import ConversationBufferMemory
|
||||
from langchain.memory.chat_message_histories import Neo4jChatMessageHistory
|
||||
from langchain.schema.messages import _message_to_dict
|
||||
|
||||
|
||||
def test_memory_with_message_store() -> None:
|
||||
"""Test the memory with a message store."""
|
||||
# setup MongoDB as a message store
|
||||
message_history = Neo4jChatMessageHistory(session_id="test-session")
|
||||
memory = ConversationBufferMemory(
|
||||
memory_key="baz", chat_memory=message_history, return_messages=True
|
||||
)
|
||||
|
||||
# add some messages
|
||||
memory.chat_memory.add_ai_message("This is me, the AI")
|
||||
memory.chat_memory.add_user_message("This is me, the human")
|
||||
|
||||
# get the message history from the memory store and turn it into a json
|
||||
messages = memory.chat_memory.messages
|
||||
messages_json = json.dumps([_message_to_dict(msg) for msg in messages])
|
||||
|
||||
assert "This is me, the AI" in messages_json
|
||||
assert "This is me, the human" in messages_json
|
||||
|
||||
# remove the record from Azure Cosmos DB, so the next test run won't pick it up
|
||||
memory.chat_memory.clear()
|
||||
|
||||
assert memory.chat_memory.messages == []
|
Loading…
Reference in New Issue