forked from Archives/langchain
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
102 lines
3.5 KiB
Python
from typing import Any, Dict, List, Optional
|
|
|
|
node_properties_query = """
|
|
CALL apoc.meta.data()
|
|
YIELD label, other, elementType, type, property
|
|
WHERE NOT type = "RELATIONSHIP" AND elementType = "node"
|
|
WITH label AS nodeLabels, collect({property:property, type:type}) AS properties
|
|
RETURN {labels: nodeLabels, properties: properties} AS output
|
|
|
|
"""
|
|
|
|
# Cypher returning one row per relationship type, shaped as
# {type: <relationship type>, properties: [{property, type}, ...]}.
# Built from APOC ``apoc.meta.data()`` rows whose elementType is
# "relationship"; for those rows ``label`` carries the relationship type,
# hence the ``relType`` alias (the returned map is unchanged).
rel_properties_query = """
CALL apoc.meta.data()
YIELD label, other, elementType, type, property
WHERE NOT type = "RELATIONSHIP" AND elementType = "relationship"
WITH label AS relType, collect({property:property, type:type}) AS properties
RETURN {type: relType, properties: properties} AS output
"""
|
|
|
|
rel_query = """
|
|
CALL apoc.meta.data()
|
|
YIELD label, other, elementType, type, property
|
|
WHERE type = "RELATIONSHIP" AND elementType = "node"
|
|
UNWIND other AS other_node
|
|
RETURN "(:" + label + ")-[:" + property + "]->(:" + toString(other_node) + ")" AS output
|
|
"""
|
|
|
|
|
|
class Neo4jGraph:
    """Neo4j wrapper for graph operations.

    On construction the wrapper opens a driver, verifies connectivity and
    credentials, and fills ``schema`` with a text description of node
    properties, relationship properties and relationship patterns gathered
    through the APOC ``apoc.meta.data()`` procedure.
    """

    def __init__(
        self, url: str, username: str, password: str, database: str = "neo4j"
    ) -> None:
        """Create a new Neo4j graph wrapper instance.

        Args:
            url: Connection URI passed to ``neo4j.GraphDatabase.driver``.
            username: Username for basic authentication.
            password: Password for basic authentication.
            database: Database that every session is opened against.

        Raises:
            ValueError: If the ``neo4j`` package is missing, the server is
                unreachable, the credentials are rejected, or the APOC
                metadata procedure cannot be used.
        """
        try:
            import neo4j
        except ImportError as e:
            raise ValueError(
                "Could not import neo4j python package. "
                "Please install it with `pip install neo4j`."
            ) from e

        self._driver = neo4j.GraphDatabase.driver(url, auth=(username, password))
        self._database = database
        # Text description of the graph schema; populated by refresh_schema().
        self.schema = ""

        try:
            self._verify_and_load_schema()
        except Exception:
            # Construction failed: release the driver instead of leaking it.
            self._driver.close()
            raise

    def _verify_and_load_schema(self) -> None:
        """Check connectivity/credentials, then populate ``self.schema``.

        Raises:
            ValueError: With a user-facing hint when the server is
                unreachable, authentication fails, or APOC is unusable.
        """
        from neo4j.exceptions import AuthError, ClientError, ServiceUnavailable

        # Verify connection
        try:
            self._driver.verify_connectivity()
        except ServiceUnavailable as e:
            raise ValueError(
                "Could not connect to Neo4j database. "
                "Please ensure that the url is correct"
            ) from e
        except AuthError as e:
            raise ValueError(
                "Could not connect to Neo4j database. "
                "Please ensure that the username and password are correct"
            ) from e

        # Set schema
        try:
            self.refresh_schema()
        except ClientError as e:
            raise ValueError(
                "Could not use APOC procedures. "
                "Please ensure the APOC plugin is installed in Neo4j and that "
                "'apoc.meta.data()' is allowed in Neo4j configuration "
            ) from e

    @property
    def get_schema(self) -> str:
        """Returns the schema of the Neo4j database"""
        return self.schema

    def query(
        self, query: str, params: Optional[Dict[str, Any]] = None
    ) -> List[Dict[str, Any]]:
        """Run a Cypher statement and return every record as a dict.

        Args:
            query: Cypher statement to execute.
            params: Query parameters; ``None`` (the default) means none.

        Returns:
            One ``dict`` per result record, as produced by ``Record.data()``.

        Raises:
            ValueError: If the server reports a Cypher syntax error.
        """
        from neo4j.exceptions import CypherSyntaxError

        # Fresh dict per call instead of a shared mutable default argument.
        params = {} if params is None else params
        with self._driver.session(database=self._database) as session:
            try:
                data = session.run(query, params)
                # Consume the records while the session is still open.
                return [r.data() for r in data]
            except CypherSyntaxError as e:
                raise ValueError(
                    f"Generated Cypher Statement is not valid\n{e}"
                ) from e

    def refresh_schema(self) -> None:
        """Refresh ``self.schema`` from the database's APOC metadata.

        Runs the module-level node-property, relationship-property and
        relationship-pattern queries and formats their ``output`` columns
        into a single text description.
        """
        node_properties = [el["output"] for el in self.query(node_properties_query)]
        relationships_properties = [
            el["output"] for el in self.query(rel_properties_query)
        ]
        relationships = [el["output"] for el in self.query(rel_query)]

        self.schema = (
            "\n"
            "Node properties are the following:\n"
            f"{node_properties}\n"
            "Relationship properties are the following:\n"
            f"{relationships_properties}\n"
            "The relationships are the following:\n"
            f"{relationships}\n"
        )
|