diff --git a/langchain/graphs/neo4j_graph.py b/langchain/graphs/neo4j_graph.py index e56d125c..c4450043 100644 --- a/langchain/graphs/neo4j_graph.py +++ b/langchain/graphs/neo4j_graph.py @@ -21,7 +21,8 @@ rel_query = """ CALL apoc.meta.data() YIELD label, other, elementType, type, property WHERE type = "RELATIONSHIP" AND elementType = "node" -RETURN "(:" + label + ")-[:" + property + "]->(:" + toString(other[0]) + ")" AS output +UNWIND other AS other_node +RETURN "(:" + label + ")-[:" + property + "]->(:" + toString(other_node) + ")" AS output """ diff --git a/tests/integration_tests/chains/test_graph_database.py b/tests/integration_tests/chains/test_graph_database.py index 2c6133be..8e7aaf1c 100644 --- a/tests/integration_tests/chains/test_graph_database.py +++ b/tests/integration_tests/chains/test_graph_database.py @@ -4,6 +4,11 @@ import os from langchain.chains.graph_qa.cypher import GraphCypherQAChain from langchain.chains.loading import load_chain from langchain.graphs import Neo4jGraph +from langchain.graphs.neo4j_graph import ( + node_properties_query, + rel_properties_query, + rel_query, +) from langchain.llms.openai import OpenAI @@ -171,11 +176,62 @@ def test_cypher_return_direct() -> None: assert output == expected_output +def test_cypher_return_correct_schema() -> None: + """Test that chain returns direct results.""" + url = os.environ.get("NEO4J_URL") + username = os.environ.get("NEO4J_USERNAME") + password = os.environ.get("NEO4J_PASSWORD") + assert url is not None + assert username is not None + assert password is not None + + graph = Neo4jGraph( + url=url, + username=username, + password=password, + ) + # Delete all nodes in the graph + graph.query("MATCH (n) DETACH DELETE n") + # Create two nodes and a relationship + graph.query( + """ + CREATE (la:LabelA {property_a: 'a'}) + CREATE (lb:LabelB) + CREATE (lc:LabelC) + MERGE (la)-[:REL_TYPE]-> (lb) + MERGE (la)-[:REL_TYPE {rel_prop: 'abc'}]-> (lc) + """ + ) + # Refresh schema information + graph.refresh_schema() + + node_properties = graph.query(node_properties_query) + relationships_properties = graph.query(rel_properties_query) + relationships = graph.query(rel_query) + + expected_node_properties = [ + { + "properties": [{"property": "property_a", "type": "STRING"}], + "labels": "LabelA", + } + ] + expected_relationships_properties = [ + {"type": "REL_TYPE", "properties": [{"property": "rel_prop", "type": "STRING"}]} + ] + expected_relationships = [ + "(:LabelA)-[:REL_TYPE]->(:LabelB)", + "(:LabelA)-[:REL_TYPE]->(:LabelC)", + ] + + assert node_properties == expected_node_properties + assert relationships_properties == expected_relationships_properties + assert relationships == expected_relationships + + def test_cypher_save_load() -> None: """Test saving and loading.""" FILE_PATH = "cypher.yaml" - url = os.environ.get("NEO4J_URL") username = os.environ.get("NEO4J_USERNAME") password = os.environ.get("NEO4J_PASSWORD") @@ -188,7 +244,6 @@ def test_cypher_save_load() -> None: username=username, password=password, ) - chain = GraphCypherQAChain.from_llm( OpenAI(temperature=0), graph=graph, return_direct=True )