From aeeda370aa65fd1317240bd0a4b0d035475a80a4 Mon Sep 17 00:00:00 2001 From: Tomaz Bratanic Date: Mon, 24 Jun 2024 12:05:31 -0700 Subject: [PATCH] Sanitize backticks from neo4j labels and types for import (#23367) --- .../langchain_community/graphs/neo4j_graph.py | 15 ++++++-- .../integration_tests/graphs/test_neo4j.py | 35 +++++++++++++++++++ 2 files changed, 47 insertions(+), 3 deletions(-) diff --git a/libs/community/langchain_community/graphs/neo4j_graph.py b/libs/community/langchain_community/graphs/neo4j_graph.py index cd2791d646..4897348a7b 100644 --- a/libs/community/langchain_community/graphs/neo4j_graph.py +++ b/libs/community/langchain_community/graphs/neo4j_graph.py @@ -287,6 +287,10 @@ def _format_schema(schema: Dict, is_enhanced: bool) -> str: ) +def _remove_backticks(text: str) -> str: + return text.replace("`", "") + + class Neo4jGraph(GraphStore): """Neo4j database wrapper for various graph operations. @@ -571,6 +575,9 @@ class Neo4jGraph(GraphStore): document.source.page_content.encode("utf-8") ).hexdigest() + # Remove backticks from node types + for node in document.nodes: + node.type = _remove_backticks(node.type) # Import nodes self.query( node_import_query, @@ -586,10 +593,12 @@ class Neo4jGraph(GraphStore): "data": [ { "source": el.source.id, - "source_label": el.source.type, + "source_label": _remove_backticks(el.source.type), "target": el.target.id, - "target_label": el.target.type, - "type": el.type.replace(" ", "_").upper(), + "target_label": _remove_backticks(el.target.type), + "type": _remove_backticks( + el.type.replace(" ", "_").upper() + ), "properties": el.properties, } for el in document.relationships diff --git a/libs/community/tests/integration_tests/graphs/test_neo4j.py b/libs/community/tests/integration_tests/graphs/test_neo4j.py index a519b43070..761cbc95e4 100644 --- a/libs/community/tests/integration_tests/graphs/test_neo4j.py +++ b/libs/community/tests/integration_tests/graphs/test_neo4j.py @@ -25,6 +25,20 @@ test_data = [ ) ] +test_data_backticks = [ + GraphDocument( + nodes=[Node(id="foo", type="foo`"), Node(id="bar", type="`bar")], + relationships=[ + Relationship( + source=Node(id="foo", type="f`oo"), + target=Node(id="bar", type="ba`r"), + type="`REL`", + ) + ], + source=Document(page_content="source document"), + ) +] + def test_cypher_return_correct_schema() -> None: """Test that chain returns direct results.""" @@ -363,3 +377,24 @@ def test_enhanced_schema_exception() -> None: # remove metadata portion of schema del graph.structured_schema["metadata"] assert graph.structured_schema == expected_output + + +def test_backticks() -> None: + """Test that backticks are correctly removed.""" + url = os.environ.get("NEO4J_URI") + username = os.environ.get("NEO4J_USERNAME") + password = os.environ.get("NEO4J_PASSWORD") + assert url is not None + assert username is not None + assert password is not None + + graph = Neo4jGraph(url=url, username=username, password=password) + graph.query("MATCH (n) DETACH DELETE n") + graph.add_graph_documents(test_data_backticks) + nodes = graph.query("MATCH (n) RETURN labels(n) AS labels ORDER BY n.id") + rels = graph.query("MATCH ()-[r]->() RETURN type(r) AS type") + expected_nodes = [{"labels": ["bar"]}, {"labels": ["foo"]}] + expected_rels = [{"type": "REL"}] + + assert nodes == expected_nodes + assert rels == expected_rels