mirror of
https://github.com/hwchase17/langchain
synced 2024-11-18 09:25:54 +00:00
ed58eeb9c5
Moved the following modules to new package langchain-community in a backwards compatible fashion: ``` mv langchain/langchain/adapters community/langchain_community mv langchain/langchain/callbacks community/langchain_community/callbacks mv langchain/langchain/chat_loaders community/langchain_community mv langchain/langchain/chat_models community/langchain_community mv langchain/langchain/document_loaders community/langchain_community mv langchain/langchain/docstore community/langchain_community mv langchain/langchain/document_transformers community/langchain_community mv langchain/langchain/embeddings community/langchain_community mv langchain/langchain/graphs community/langchain_community mv langchain/langchain/llms community/langchain_community mv langchain/langchain/memory/chat_message_histories community/langchain_community mv langchain/langchain/retrievers community/langchain_community mv langchain/langchain/storage community/langchain_community mv langchain/langchain/tools community/langchain_community mv langchain/langchain/utilities community/langchain_community mv langchain/langchain/vectorstores community/langchain_community mv langchain/langchain/agents/agent_toolkits community/langchain_community mv langchain/langchain/cache.py community/langchain_community mv langchain/langchain/adapters community/langchain_community mv langchain/langchain/callbacks community/langchain_community/callbacks mv langchain/langchain/chat_loaders community/langchain_community mv langchain/langchain/chat_models community/langchain_community mv langchain/langchain/document_loaders community/langchain_community mv langchain/langchain/docstore community/langchain_community mv langchain/langchain/document_transformers community/langchain_community mv langchain/langchain/embeddings community/langchain_community mv langchain/langchain/graphs community/langchain_community mv langchain/langchain/llms community/langchain_community mv langchain/langchain/memory/chat_message_histories community/langchain_community mv langchain/langchain/retrievers community/langchain_community mv langchain/langchain/storage community/langchain_community mv langchain/langchain/tools community/langchain_community mv langchain/langchain/utilities community/langchain_community mv langchain/langchain/vectorstores community/langchain_community mv langchain/langchain/agents/agent_toolkits community/langchain_community mv langchain/langchain/cache.py community/langchain_community ``` Moved the following to core ``` mv langchain/langchain/utils/json_schema.py core/langchain_core/utils mv langchain/langchain/utils/html.py core/langchain_core/utils mv langchain/langchain/utils/strings.py core/langchain_core/utils cat langchain/langchain/utils/env.py >> core/langchain_core/utils/env.py rm langchain/langchain/utils/env.py ``` See .scripts/community_split/script_integrations.sh for all changes
113 lines
4.2 KiB
Python
113 lines
4.2 KiB
Python
from typing import List, Optional, Union
|
|
|
|
from langchain_core.chat_history import BaseChatMessageHistory
|
|
from langchain_core.messages import BaseMessage, messages_from_dict
|
|
from langchain_core.utils import get_from_env
|
|
|
|
|
|
class Neo4jChatMessageHistory(BaseChatMessageHistory):
|
|
"""Chat message history stored in a Neo4j database."""
|
|
|
|
def __init__(
|
|
self,
|
|
session_id: Union[str, int],
|
|
url: Optional[str] = None,
|
|
username: Optional[str] = None,
|
|
password: Optional[str] = None,
|
|
database: str = "neo4j",
|
|
node_label: str = "Session",
|
|
window: int = 3,
|
|
):
|
|
try:
|
|
import neo4j
|
|
except ImportError:
|
|
raise ValueError(
|
|
"Could not import neo4j python package. "
|
|
"Please install it with `pip install neo4j`."
|
|
)
|
|
|
|
# Make sure session id is not null
|
|
if not session_id:
|
|
raise ValueError("Please ensure that the session_id parameter is provided")
|
|
|
|
url = get_from_env("url", "NEO4J_URI", url)
|
|
username = get_from_env("username", "NEO4J_USERNAME", username)
|
|
password = get_from_env("password", "NEO4J_PASSWORD", password)
|
|
database = get_from_env("database", "NEO4J_DATABASE", database)
|
|
|
|
self._driver = neo4j.GraphDatabase.driver(url, auth=(username, password))
|
|
self._database = database
|
|
self._session_id = session_id
|
|
self._node_label = node_label
|
|
self._window = window
|
|
|
|
# Verify connection
|
|
try:
|
|
self._driver.verify_connectivity()
|
|
except neo4j.exceptions.ServiceUnavailable:
|
|
raise ValueError(
|
|
"Could not connect to Neo4j database. "
|
|
"Please ensure that the url is correct"
|
|
)
|
|
except neo4j.exceptions.AuthError:
|
|
raise ValueError(
|
|
"Could not connect to Neo4j database. "
|
|
"Please ensure that the username and password are correct"
|
|
)
|
|
# Create session node
|
|
self._driver.execute_query(
|
|
f"MERGE (s:`{self._node_label}` {{id:$session_id}})",
|
|
{"session_id": self._session_id},
|
|
).summary
|
|
|
|
@property
|
|
def messages(self) -> List[BaseMessage]: # type: ignore
|
|
"""Retrieve the messages from Neo4j"""
|
|
query = (
|
|
f"MATCH (s:`{self._node_label}`)-[:LAST_MESSAGE]->(last_message) "
|
|
"WHERE s.id = $session_id MATCH p=(last_message)<-[:NEXT*0.."
|
|
f"{self._window*2}]-() WITH p, length(p) AS length "
|
|
"ORDER BY length DESC LIMIT 1 UNWIND reverse(nodes(p)) AS node "
|
|
"RETURN {data:{content: node.content}, type:node.type} AS result"
|
|
)
|
|
records, _, _ = self._driver.execute_query(
|
|
query, {"session_id": self._session_id}
|
|
)
|
|
|
|
messages = messages_from_dict([el["result"] for el in records])
|
|
return messages
|
|
|
|
def add_message(self, message: BaseMessage) -> None:
|
|
"""Append the message to the record in Neo4j"""
|
|
query = (
|
|
f"MATCH (s:`{self._node_label}`) WHERE s.id = $session_id "
|
|
"OPTIONAL MATCH (s)-[lm:LAST_MESSAGE]->(last_message) "
|
|
"CREATE (s)-[:LAST_MESSAGE]->(new:Message) "
|
|
"SET new += {type:$type, content:$content} "
|
|
"WITH new, lm, last_message WHERE last_message IS NOT NULL "
|
|
"CREATE (last_message)-[:NEXT]->(new) "
|
|
"DELETE lm"
|
|
)
|
|
self._driver.execute_query(
|
|
query,
|
|
{
|
|
"type": message.type,
|
|
"content": message.content,
|
|
"session_id": self._session_id,
|
|
},
|
|
).summary
|
|
|
|
def clear(self) -> None:
|
|
"""Clear session memory from Neo4j"""
|
|
query = (
|
|
f"MATCH (s:`{self._node_label}`)-[:LAST_MESSAGE]->(last_message) "
|
|
"WHERE s.id = $session_id MATCH p=(last_message)<-[:NEXT]-() "
|
|
"WITH p, length(p) AS length ORDER BY length DESC LIMIT 1 "
|
|
"UNWIND nodes(p) as node DETACH DELETE node;"
|
|
)
|
|
self._driver.execute_query(query, {"session_id": self._session_id}).summary
|
|
|
|
def __del__(self) -> None:
|
|
if self._driver:
|
|
self._driver.close()
|