From 01c2f27ffa87c4225d4f6ae2e89887e1424990dc Mon Sep 17 00:00:00 2001 From: Katarina Supe <61758502+katarinasupe@users.noreply.github.com> Date: Mon, 22 Jan 2024 20:33:28 +0100 Subject: [PATCH] community[patch]: Update Memgraph support (#16360) - **Description:** I removed one of the two queries to the database and left just one, whose result is formatted afterward into the other type of schema (avoiding a second call to the DB) - **Issue:** / - **Dependencies:** / - **Twitter handle:** @supe_katarina --- .../graphs/memgraph_graph.py | 45 ++++++++++---- .../integration_tests/graphs/test_memgraph.py | 62 +++++++++++++++++++ 2 files changed, 96 insertions(+), 11 deletions(-) create mode 100644 libs/community/tests/integration_tests/graphs/test_memgraph.py diff --git a/libs/community/langchain_community/graphs/memgraph_graph.py b/libs/community/langchain_community/graphs/memgraph_graph.py index 2df4612a2c..34e9f7145b 100644 --- a/libs/community/langchain_community/graphs/memgraph_graph.py +++ b/libs/community/langchain_community/graphs/memgraph_graph.py @@ -1,12 +1,6 @@ from langchain_community.graphs.neo4j_graph import Neo4jGraph SCHEMA_QUERY = """ -CALL llm_util.schema("prompt_ready") -YIELD * -RETURN * -""" - -RAW_SCHEMA_QUERY = """ CALL llm_util.schema("raw") YIELD * RETURN * @@ -39,10 +33,39 @@ class MemgraphGraph(Neo4jGraph): Refreshes the Memgraph graph schema information. 
""" - db_schema = self.query(SCHEMA_QUERY)[0].get("schema") - assert db_schema is not None - self.schema = db_schema - - db_structured_schema = self.query(RAW_SCHEMA_QUERY)[0].get("schema") + db_structured_schema = self.query(SCHEMA_QUERY)[0].get("schema") assert db_structured_schema is not None self.structured_schema = db_structured_schema + + # Format node properties + formatted_node_props = [] + + for node_name, properties in db_structured_schema["node_props"].items(): + formatted_node_props.append( + f"Node name: '{node_name}', Node properties: {properties}" + ) + + # Format relationship properties + formatted_rel_props = [] + for rel_name, properties in db_structured_schema["rel_props"].items(): + formatted_rel_props.append( + f"Relationship name: '{rel_name}', " + f"Relationship properties: {properties}" + ) + + # Format relationships + formatted_rels = [ + f"(:{rel['start']})-[:{rel['type']}]->(:{rel['end']})" + for rel in db_structured_schema["relationships"] + ] + + self.schema = "\n".join( + [ + "Node properties are the following:", + *formatted_node_props, + "Relationship properties are the following:", + *formatted_rel_props, + "The relationships are the following:", + *formatted_rels, + ] + ) diff --git a/libs/community/tests/integration_tests/graphs/test_memgraph.py b/libs/community/tests/integration_tests/graphs/test_memgraph.py new file mode 100644 index 0000000000..663f974d3f --- /dev/null +++ b/libs/community/tests/integration_tests/graphs/test_memgraph.py @@ -0,0 +1,62 @@ +import os + +from langchain_community.graphs import MemgraphGraph + + +def test_cypher_return_correct_schema() -> None: + """Test that the graph schema returned by Memgraph is correct.""" + url = os.environ.get("MEMGRAPH_URI", "bolt://localhost:7687") + username = os.environ.get("MEMGRAPH_USERNAME", "") + password = os.environ.get("MEMGRAPH_PASSWORD", "") + assert url is not None + assert username is not None + assert password is not None + + graph = MemgraphGraph( + url=url, + username=username, 
+ password=password, + ) + # Delete all nodes in the graph + graph.query("MATCH (n) DETACH DELETE n") + # Create three nodes and two relationships + graph.query( + """ + CREATE (la:LabelA {property_a: 'a'}) + CREATE (lb:LabelB) + CREATE (lc:LabelC) + MERGE (la)-[:REL_TYPE]-> (lb) + MERGE (la)-[:REL_TYPE {rel_prop: 'abc'}]-> (lc) + """ + ) + # Refresh schema information + graph.refresh_schema() + relationships = graph.query( + "CALL llm_util.schema('raw') YIELD schema " + "WITH schema.relationships AS relationships " + "UNWIND relationships AS relationship " + "RETURN relationship['start'] AS start, " + "relationship['type'] AS type, " + "relationship['end'] AS end " + "ORDER BY start, type, end;" + ) + + node_props = graph.query( + "CALL llm_util.schema('raw') YIELD schema " + "WITH schema.node_props AS nodes " + "WITH nodes['LabelA'] AS properties " + "UNWIND properties AS property " + "RETURN property['property'] AS prop, " + "property['type'] AS type " + "ORDER BY prop ASC;" + ) + + expected_relationships = [ + {"start": "LabelA", "type": "REL_TYPE", "end": "LabelB"}, + {"start": "LabelA", "type": "REL_TYPE", "end": "LabelC"}, + ] + + expected_node_props = [{"prop": "property_a", "type": "str"}] + + assert relationships == expected_relationships + assert node_props == expected_node_props