# langchain/libs/community/tests/integration_tests/graphs/test_age_graph.py
#
# 338 lines | 11 KiB | Python
# (file-viewer header residue: "Raw Normal View History")

import os
import re
import unittest
from typing import Any, Dict
from langchain_core.documents import Document
from langchain_community.graphs.age_graph import AGEGraph
from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship
# Shared fixture for the add_graph_documents tests: one graph document holding
# a "foo" node and a "bar" node joined by a single REL edge (foo -> bar), plus
# the source document it was derived from.
test_data = [
    GraphDocument(
        nodes=[Node(id=name, type=name) for name in ("foo", "bar")],
        relationships=[
            Relationship(
                source=Node(id="foo", type="foo"),
                target=Node(id="bar", type="bar"),
                type="REL",
            ),
        ],
        source=Document(page_content="source document"),
    ),
]
class TestAGEGraph(unittest.TestCase):
    """Integration tests for ``AGEGraph``; they need a reachable database.

    Connection settings are read from the environment:

    * ``AGE_PGSQL_DB``, ``AGE_PGSQL_USER``, ``AGE_PGSQL_PASSWORD`` -- required;
      a test fails if any of them is unset.
    * ``AGE_PGSQL_HOST`` (default ``localhost``) and ``AGE_PGSQL_PORT``
      (default ``5432``).
    * ``AGE_GRAPH_NAME`` (default ``age_test``).

    Every test begins by deleting all nodes of that graph, so the variables
    should point at a disposable database.
    """

    # Maps each mandatory connection key to the environment variable that
    # supplies it, so a failing check can name the variable that is missing.
    _REQUIRED_ENV = {
        "database": "AGE_PGSQL_DB",
        "user": "AGE_PGSQL_USER",
        "password": "AGE_PGSQL_PASSWORD",
    }

    # Three labelled nodes and two REL_TYPE relationships leaving LabelA; only
    # LabelA and the second relationship carry a property.
    _LABELLED_FIXTURE = """
        CREATE (la:LabelA {property_a: 'a'})
        CREATE (lb:LabelB)
        CREATE (lc:LabelC)
        MERGE (la)-[:REL_TYPE]-> (lb)
        MERGE (la)-[:REL_TYPE {rel_prop: 'abc'}]-> (lc)
        """

    # Counts the nodes per label set, ordered by label.
    _LABEL_COUNT_QUERY = (
        "MATCH (n) RETURN labels(n) AS label, count(*) AS count ORDER BY labels(n)"
    )

    @staticmethod
    def _squash(text: str) -> str:
        """Return ``text`` with all whitespace removed, for layout-free compares."""
        return re.sub(r"\s", "", text)

    def _connect_to_empty_graph(self) -> "AGEGraph":
        """Build the connection config, connect, and delete every node.

        Returns:
            An ``AGEGraph`` bound to the configured graph, emptied with
            ``MATCH (n) DETACH DELETE n``.

        Raises:
            AssertionError: if a required environment variable is unset
                (reported by unittest as a test failure).
        """
        conf = {key: os.getenv(env) for key, env in self._REQUIRED_ENV.items()}
        conf["host"] = os.getenv("AGE_PGSQL_HOST", "localhost")
        # getenv defaults must be strings; the value is converted afterwards.
        conf["port"] = int(os.getenv("AGE_PGSQL_PORT", "5432"))

        for key, env in self._REQUIRED_ENV.items():
            self.assertIsNotNone(
                conf[key], f"environment variable {env} must be set"
            )

        graph = AGEGraph(os.getenv("AGE_GRAPH_NAME", "age_test"), conf)
        graph.query("MATCH (n) DETACH DELETE n")
        return graph

    def test_node_properties(self) -> None:
        """Node property extraction reports one entry per node label."""
        graph = self._connect_to_empty_graph()
        graph.query(self._LABELLED_FIXTURE)

        # Query the labels directly instead of going through refresh_schema().
        n_labels, _ = graph._get_labels()
        node_properties = graph._get_node_properties(n_labels)

        expected_node_properties = [
            {
                "properties": [{"property": "property_a", "type": "STRING"}],
                "labels": "LabelA",
            },
            {
                "properties": [],
                "labels": "LabelB",
            },
            {
                "properties": [],
                "labels": "LabelC",
            },
        ]
        # The result is sorted by label before comparing with the sorted list.
        self.assertEqual(
            sorted(node_properties, key=lambda x: x["labels"]),
            expected_node_properties,
        )

    def test_edge_properties(self) -> None:
        """Edge property extraction reports the properties of REL_TYPE."""
        graph = self._connect_to_empty_graph()
        graph.query(self._LABELLED_FIXTURE)

        # Query the labels directly instead of going through refresh_schema().
        _, e_labels = graph._get_labels()
        relationships_properties = graph._get_edge_properties(e_labels)

        expected_relationships_properties = [
            {
                "type": "REL_TYPE",
                "properties": [{"property": "rel_prop", "type": "STRING"}],
            }
        ]
        self.assertEqual(relationships_properties, expected_relationships_properties)

    def test_relationships(self) -> None:
        """Triple extraction lists each (start, type, end) label combination."""
        graph = self._connect_to_empty_graph()
        graph.query(self._LABELLED_FIXTURE)

        # Query the labels directly instead of going through refresh_schema().
        _, e_labels = graph._get_labels()
        relationships = graph._get_triples(e_labels)

        expected_relationships = [
            {"start": "LabelA", "type": "REL_TYPE", "end": "LabelB"},
            {"start": "LabelA", "type": "REL_TYPE", "end": "LabelC"},
        ]
        # The result is sorted by end label before comparing.
        self.assertEqual(
            sorted(relationships, key=lambda x: x["end"]), expected_relationships
        )

    def test_add_documents(self) -> None:
        """Adding graph documents creates one node per document node."""
        graph = self._connect_to_empty_graph()

        # Creates two nodes and a relationship.
        graph.add_graph_documents(test_data)

        output = graph.query(self._LABEL_COUNT_QUERY)
        self.assertEqual(
            output, [{"label": ["bar"], "count": 1}, {"label": ["foo"], "count": 1}]
        )

    def test_add_documents_source(self) -> None:
        """With ``include_source=True`` a ``Document`` node is stored as well."""
        graph = self._connect_to_empty_graph()

        # Creates two nodes and a relationship, plus the source document node.
        graph.add_graph_documents(test_data, include_source=True)

        output = graph.query(self._LABEL_COUNT_QUERY)
        expected = [
            {"label": ["bar"], "count": 1},
            {"label": ["Document"], "count": 1},
            {"label": ["foo"], "count": 1},
        ]
        self.assertEqual(output, expected)

    def test_get_schema(self) -> None:
        """The cached schema changes only when ``refresh_schema`` is called."""
        graph = self._connect_to_empty_graph()
        graph.refresh_schema()

        expected = """
            Node properties are the following:
            []
            Relationship properties are the following:
            []
            The relationships are the following:
            []
            """
        expected_structured: Dict[str, Any] = {
            "node_props": {},
            "rel_props": {},
            "relationships": [],
            "metadata": {},
        }

        # check that works on empty schema
        self.assertEqual(self._squash(graph.get_schema), self._squash(expected))
        self.assertEqual(graph.get_structured_schema, expected_structured)

        # Create two nodes and a relationship
        graph.query(
            """
            MERGE (a:a {id: 1})-[b:b {id: 2}]-> (c:c {id: 3})
            """
        )

        # check that schema doesn't update without refresh
        self.assertEqual(self._squash(graph.get_schema), self._squash(expected))
        self.assertEqual(graph.get_structured_schema, expected_structured)

        # two possible orderings of node props
        expected_possibilities = [
            """
            Node properties are the following:
            [
                {'properties': [{'property': 'id', 'type': 'INTEGER'}], 'labels': 'a'},
                {'properties': [{'property': 'id', 'type': 'INTEGER'}], 'labels': 'c'}
            ]
            Relationship properties are the following:
            [
                {'properties': [{'property': 'id', 'type': 'INTEGER'}], 'type': 'b'}
            ]
            The relationships are the following:
            [
                '(:`a`)-[:`b`]->(:`c`)'
            ]
            """,
            """
            Node properties are the following:
            [
                {'properties': [{'property': 'id', 'type': 'INTEGER'}], 'labels': 'c'},
                {'properties': [{'property': 'id', 'type': 'INTEGER'}], 'labels': 'a'}
            ]
            Relationship properties are the following:
            [
                {'properties': [{'property': 'id', 'type': 'INTEGER'}], 'type': 'b'}
            ]
            The relationships are the following:
            [
                '(:`a`)-[:`b`]->(:`c`)'
            ]
            """,
        ]
        expected_structured2: Dict[str, Any] = {
            "node_props": {
                "a": [{"property": "id", "type": "INTEGER"}],
                "c": [{"property": "id", "type": "INTEGER"}],
            },
            "rel_props": {"b": [{"property": "id", "type": "INTEGER"}]},
            "relationships": [{"start": "a", "type": "b", "end": "c"}],
            "metadata": {},
        }

        graph.refresh_schema()

        # check that schema is refreshed
        self.assertIn(
            self._squash(graph.get_schema),
            [self._squash(x) for x in expected_possibilities],
        )
        self.assertEqual(graph.get_structured_schema, expected_structured2)