mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
cb6e5e56c2
**Description:** Implemented a GraphStore class for the Apache AGE graph database. **Dependencies:** Depends on psycopg2. Unit and integration tests are included. Formatting and linting have been run. --------- Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com> Co-authored-by: Bagatur <baskaryan@gmail.com>
338 lines
11 KiB
Python
338 lines
11 KiB
Python
import os
|
|
import re
|
|
import unittest
|
|
from typing import Any, Dict
|
|
|
|
from langchain_core.documents import Document
|
|
|
|
from langchain_community.graphs.age_graph import AGEGraph
|
|
from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship
|
|
|
|
# Shared fixture for the add_graph_documents tests: a single graph document
# holding a "foo" node, a "bar" node and one REL edge from foo to bar, all
# attributed to one source Document.
test_data = [
    GraphDocument(
        source=Document(page_content="source document"),
        nodes=[Node(id=name, type=name) for name in ("foo", "bar")],
        relationships=[
            Relationship(
                type="REL",
                source=Node(id="foo", type="foo"),
                target=Node(id="bar", type="bar"),
            )
        ],
    )
]
|
|
|
|
|
|
class TestAGEGraph(unittest.TestCase):
    """Integration tests for ``AGEGraph`` that run against a live database.

    Connection settings come from the environment. ``AGE_PGSQL_DB``,
    ``AGE_PGSQL_USER`` and ``AGE_PGSQL_PASSWORD`` are required (a test fails
    if any of them is unset). ``AGE_PGSQL_HOST`` (default ``localhost``),
    ``AGE_PGSQL_PORT`` (default ``5432``) and ``AGE_GRAPH_NAME`` (default
    ``age_test``) are optional.

    Every test starts by deleting all nodes in the target graph, so the graph
    named by ``AGE_GRAPH_NAME`` must not hold data worth keeping.
    """

    # Fixture shared by the node-property, edge-property and triple tests:
    # three labelled nodes and two REL_TYPE edges leaving LabelA, of which
    # only the edge to LabelC carries a property.
    _LABELLED_FIXTURE_QUERY = """
        CREATE (la:LabelA {property_a: 'a'})
        CREATE (lb:LabelB)
        CREATE (lc:LabelC)
        MERGE (la)-[:REL_TYPE]-> (lb)
        MERGE (la)-[:REL_TYPE {rel_prop: 'abc'}]-> (lc)
        """

    def _connect_to_empty_graph(self) -> AGEGraph:
        """Open the configured graph and remove every node from it.

        Returns:
            An ``AGEGraph`` bound to the graph named by ``AGE_GRAPH_NAME``,
            emptied with ``MATCH (n) DETACH DELETE n``.

        Fails the calling test if the database name, user or password is
        missing from the environment.
        """
        conf = {
            "database": os.getenv("AGE_PGSQL_DB"),
            "user": os.getenv("AGE_PGSQL_USER"),
            "password": os.getenv("AGE_PGSQL_PASSWORD"),
            "host": os.getenv("AGE_PGSQL_HOST", "localhost"),
            "port": int(os.getenv("AGE_PGSQL_PORT", "5432")),
        }

        # Check the required settings up front so a missing variable is
        # reported by name instead of as a connection error.
        self.assertIsNotNone(conf["database"])
        self.assertIsNotNone(conf["user"])
        self.assertIsNotNone(conf["password"])

        graph_name = os.getenv("AGE_GRAPH_NAME", "age_test")
        graph = AGEGraph(graph_name, conf)

        # Delete all nodes (and their edges) left over from earlier tests.
        graph.query("MATCH (n) DETACH DELETE n")
        return graph

    def test_node_properties(self) -> None:
        """Node properties are reported per label, including empty labels."""
        graph = self._connect_to_empty_graph()

        # Create three nodes and two relationships.
        graph.query(self._LABELLED_FIXTURE_QUERY)

        # Only the node labels are needed here; edge labels are ignored.
        n_labels, _ = graph._get_labels()
        node_properties = graph._get_node_properties(n_labels)

        expected_node_properties = [
            {
                "properties": [{"property": "property_a", "type": "STRING"}],
                "labels": "LabelA",
            },
            {
                "properties": [],
                "labels": "LabelB",
            },
            {
                "properties": [],
                "labels": "LabelC",
            },
        ]

        # Sort by label before comparing so the result order does not matter.
        self.assertEqual(
            sorted(node_properties, key=lambda x: x["labels"]),
            expected_node_properties,
        )

    def test_edge_properties(self) -> None:
        """Edge properties are reported once per relationship type."""
        graph = self._connect_to_empty_graph()

        # Create three nodes and two relationships.
        graph.query(self._LABELLED_FIXTURE_QUERY)

        # Only the edge labels are needed here; node labels are ignored.
        _, e_labels = graph._get_labels()
        relationships_properties = graph._get_edge_properties(e_labels)

        expected_relationships_properties = [
            {
                "type": "REL_TYPE",
                "properties": [{"property": "rel_prop", "type": "STRING"}],
            }
        ]

        self.assertEqual(relationships_properties, expected_relationships_properties)

    def test_relationships(self) -> None:
        """Each (start label, edge type, end label) triple is discovered."""
        graph = self._connect_to_empty_graph()

        # Create three nodes and two relationships.
        graph.query(self._LABELLED_FIXTURE_QUERY)

        # Only the edge labels are needed here; node labels are ignored.
        _, e_labels = graph._get_labels()
        relationships = graph._get_triples(e_labels)

        expected_relationships = [
            {"start": "LabelA", "type": "REL_TYPE", "end": "LabelB"},
            {"start": "LabelA", "type": "REL_TYPE", "end": "LabelC"},
        ]

        # Sort by end label before comparing so result order does not matter.
        self.assertEqual(
            sorted(relationships, key=lambda x: x["end"]), expected_relationships
        )

    def test_add_documents(self) -> None:
        """Adding graph documents creates one node per document node."""
        graph = self._connect_to_empty_graph()

        # Create two nodes and a relationship
        graph.add_graph_documents(test_data)
        output = graph.query(
            "MATCH (n) RETURN labels(n) AS label, count(*) AS count ORDER BY labels(n)"
        )
        self.assertEqual(
            output, [{"label": ["bar"], "count": 1}, {"label": ["foo"], "count": 1}]
        )

    def test_add_documents_source(self) -> None:
        """With ``include_source=True`` a ``Document`` node is stored as well."""
        graph = self._connect_to_empty_graph()

        # Create two nodes and a relationship, plus the source document node
        graph.add_graph_documents(test_data, include_source=True)
        output = graph.query(
            "MATCH (n) RETURN labels(n) AS label, count(*) AS count ORDER BY labels(n)"
        )

        expected = [
            {"label": ["bar"], "count": 1},
            {"label": ["Document"], "count": 1},
            {"label": ["foo"], "count": 1},
        ]
        self.assertEqual(output, expected)

    def test_get_schema(self) -> None:
        """The cached schema changes only when ``refresh_schema`` is called."""
        graph = self._connect_to_empty_graph()

        graph.refresh_schema()

        expected = """
            Node properties are the following:
            []
            Relationship properties are the following:
            []
            The relationships are the following:
            []
            """
        # check that works on empty schema
        # (whitespace is stripped on both sides, so only the text matters)
        self.assertEqual(
            re.sub(r"\s", "", graph.get_schema), re.sub(r"\s", "", expected)
        )

        expected_structured: Dict[str, Any] = {
            "node_props": {},
            "rel_props": {},
            "relationships": [],
            "metadata": {},
        }

        self.assertEqual(graph.get_structured_schema, expected_structured)

        # Create two nodes and a relationship
        graph.query(
            """
            MERGE (a:a {id: 1})-[b:b {id: 2}]-> (c:c {id: 3})
            """
        )

        # check that schema doesn't update without refresh
        self.assertEqual(
            re.sub(r"\s", "", graph.get_schema), re.sub(r"\s", "", expected)
        )
        self.assertEqual(graph.get_structured_schema, expected_structured)

        # two possible orderings of node props
        expected_possibilities = [
            """
            Node properties are the following:
            [
            {'properties': [{'property': 'id', 'type': 'INTEGER'}], 'labels': 'a'},
            {'properties': [{'property': 'id', 'type': 'INTEGER'}], 'labels': 'c'}
            ]
            Relationship properties are the following:
            [
            {'properties': [{'property': 'id', 'type': 'INTEGER'}], 'type': 'b'}
            ]
            The relationships are the following:
            [
            '(:`a`)-[:`b`]->(:`c`)'
            ]
            """,
            """
            Node properties are the following:
            [
            {'properties': [{'property': 'id', 'type': 'INTEGER'}], 'labels': 'c'},
            {'properties': [{'property': 'id', 'type': 'INTEGER'}], 'labels': 'a'}
            ]
            Relationship properties are the following:
            [
            {'properties': [{'property': 'id', 'type': 'INTEGER'}], 'type': 'b'}
            ]
            The relationships are the following:
            [
            '(:`a`)-[:`b`]->(:`c`)'
            ]
            """,
        ]

        expected_structured2 = {
            "node_props": {
                "a": [{"property": "id", "type": "INTEGER"}],
                "c": [{"property": "id", "type": "INTEGER"}],
            },
            "rel_props": {"b": [{"property": "id", "type": "INTEGER"}]},
            "relationships": [{"start": "a", "type": "b", "end": "c"}],
            "metadata": {},
        }

        graph.refresh_schema()

        # check that schema is refreshed
        self.assertIn(
            re.sub(r"\s", "", graph.get_schema),
            [re.sub(r"\s", "", x) for x in expected_possibilities],
        )
        self.assertEqual(graph.get_structured_schema, expected_structured2)