Filtering graph schema for Cypher generation (#10577)

Sometimes you don't want the LLM to be aware of the whole graph schema,
and want it to ignore parts of the graph when it is constructing Cypher
statements.
pull/11036/head
Tomaz Bratanic 10 months ago committed by GitHub
parent 89ef440c14
commit 0625ab7a9e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -135,7 +135,7 @@
}
],
"source": [
"print(graph.get_schema)"
"print(graph.schema)"
]
},
{
@ -510,13 +510,54 @@
"chain.run(\"Who played in Top Gun?\")"
]
},
{
"cell_type": "markdown",
"id": "eefea16b-508f-4552-8942-9d5063ed7d37",
"metadata": {},
"source": [
"# Ignore specified node and relationship types\n",
"You can use `include_types` or `exclude_types` to ignore parts of the graph schema when generating Cypher statements."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "48ff7cf8-18a3-43d7-8cb1-c1b91744608d",
"execution_count": 18,
"id": "a20fa21e-fb85-41c4-aac0-53fb25e34604",
"metadata": {},
"outputs": [],
"source": []
"source": [
"chain = GraphCypherQAChain.from_llm(\n",
" graph=graph,\n",
" cypher_llm=ChatOpenAI(temperature=0, model=\"gpt-3.5-turbo\"),\n",
" qa_llm=ChatOpenAI(temperature=0, model=\"gpt-3.5-turbo-16k\"),\n",
" verbose=True,\n",
" exclude_types=['Movie']\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "3ad7f6b8-543e-46e4-a3b2-40fa3e66e895",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Node properties are the following: \n",
" {'Actor': [{'property': 'name', 'type': 'STRING'}]}\n",
"Relationships properties are the following: \n",
" {}\n",
"Relationships are: \n",
"[]\n"
]
}
],
"source": [
"# Inspect graph schema\n",
"print(chain.graph_schema)"
]
}
],
"metadata": {

@ -187,7 +187,7 @@
"metadata": {},
"outputs": [],
"source": [
"print(graph.get_schema)"
"print(graph.schema)"
]
},
{
@ -687,7 +687,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
"version": "3.8.8"
}
},
"nbformat": 4,

@ -34,12 +34,54 @@ def extract_cypher(text: str) -> str:
return matches[0] if matches else text
def construct_schema(
    structured_schema: Dict[str, Any],
    include_types: List[str],
    exclude_types: List[str],
) -> str:
    """Render the graph schema as text, keeping only the allowed types.

    Args:
        structured_schema: Mapping that may contain ``"node_props"`` and
            ``"rel_props"`` (dicts keyed by type name) and
            ``"relationships"`` (list of dicts with ``"start"``, ``"end"``
            and ``"type"`` entries). Missing sections are treated as empty.
        include_types: When non-empty, only these names are kept and
            ``exclude_types`` is not consulted.
        exclude_types: Names to drop when ``include_types`` is empty.

    Returns:
        A string listing the remaining node properties, relationship
        properties and relationship patterns.
    """

    def is_allowed(name: str) -> bool:
        # An allow-list, when given, takes precedence over the deny-list.
        if include_types:
            return name in include_types
        return name not in exclude_types

    node_props = {}
    for label, props in structured_schema.get("node_props", {}).items():
        if is_allowed(label):
            node_props[label] = props

    rel_props = {}
    for rel_type, props in structured_schema.get("rel_props", {}).items():
        if is_allowed(rel_type):
            rel_props[rel_type] = props

    # A relationship survives only if both endpoints and its type are allowed.
    patterns = []
    for rel in structured_schema.get("relationships", []):
        if (
            is_allowed(rel["start"])
            and is_allowed(rel["end"])
            and is_allowed(rel["type"])
        ):
            patterns.append(f"(:{rel['start']})-[:{rel['type']}]->(:{rel['end']})")

    return (
        f"Node properties are the following: \n {node_props}\n"
        f"Relationships properties are the following: \n {rel_props}"
        "\nRelationships are: \n"
        f"{patterns}"
    )
class GraphCypherQAChain(Chain):
"""Chain for question-answering against a graph by generating Cypher statements."""
graph: Neo4jGraph = Field(exclude=True)
cypher_generation_chain: LLMChain
qa_chain: LLMChain
graph_schema: str
input_key: str = "query" #: :meta private:
output_key: str = "result" #: :meta private:
top_k: int = 10
@ -79,6 +121,8 @@ class GraphCypherQAChain(Chain):
cypher_prompt: BasePromptTemplate = CYPHER_GENERATION_PROMPT,
cypher_llm: Optional[BaseLanguageModel] = None,
qa_llm: Optional[BaseLanguageModel] = None,
exclude_types: List[str] = [],
include_types: List[str] = [],
**kwargs: Any,
) -> GraphCypherQAChain:
"""Initialize from LLM."""
@ -96,7 +140,18 @@ class GraphCypherQAChain(Chain):
qa_chain = LLMChain(llm=qa_llm or llm, prompt=qa_prompt)
cypher_generation_chain = LLMChain(llm=cypher_llm or llm, prompt=cypher_prompt)
if exclude_types and include_types:
raise ValueError(
"Either `exclude_types` or `include_types` "
"can be provided, but not both"
)
graph_schema = construct_schema(
kwargs["graph"].structured_schema, include_types, exclude_types
)
return cls(
graph_schema=graph_schema,
qa_chain=qa_chain,
cypher_generation_chain=cypher_generation_chain,
**kwargs,
@ -115,7 +170,7 @@ class GraphCypherQAChain(Chain):
intermediate_steps: List = []
generated_cypher = self.cypher_generation_chain.run(
{"question": question, "schema": self.graph.get_schema}, callbacks=callbacks
{"question": question, "schema": self.graph_schema}, callbacks=callbacks
)
# Extract Cypher code if it is wrapped in backticks

@ -6,6 +6,12 @@ YIELD *
RETURN *
"""
# Cypher query that calls the `llm_util.schema` procedure with the "raw"
# argument. MemgraphGraph reads the "schema" field of the first returned row
# and stores it as its structured schema.
RAW_SCHEMA_QUERY = """
CALL llm_util.schema("raw")
YIELD *
RETURN *
"""
class MemgraphGraph(Neo4jGraph):
"""Memgraph wrapper for graph operations."""
@ -24,3 +30,7 @@ class MemgraphGraph(Neo4jGraph):
db_schema = self.query(SCHEMA_QUERY)[0].get("schema")
assert db_schema is not None
self.schema = db_schema
db_structured_schema = self.query(RAW_SCHEMA_QUERY)[0].get("schema")
assert db_structured_schema is not None
self.structured_schema = db_structured_schema

@ -24,7 +24,7 @@ CALL apoc.meta.data()
YIELD label, other, elementType, type, property
WHERE type = "RELATIONSHIP" AND elementType = "node"
UNWIND other AS other_node
RETURN "(:" + label + ")-[:" + property + "]->(:" + toString(other_node) + ")" AS output
RETURN {start: label, type: property, end: toString(other_node)} AS output
"""
@ -45,7 +45,8 @@ class Neo4jGraph:
self._driver = neo4j.GraphDatabase.driver(url, auth=(username, password))
self._database = database
self.schema = ""
self.schema: str = ""
self.structured_schema: Dict[str, Any] = {}
# Verify connection
try:
self._driver.verify_connectivity()
@ -69,11 +70,6 @@ class Neo4jGraph:
"'apoc.meta.data()' is allowed in Neo4j configuration "
)
@property
def get_schema(self) -> str:
"""Returns the schema of the Neo4j database"""
return self.schema
def query(self, query: str, params: dict = {}) -> List[Dict[str, Any]]:
"""Query Neo4j database."""
from neo4j.exceptions import CypherSyntaxError
@ -89,17 +85,22 @@ class Neo4jGraph:
"""
Refreshes the Neo4j graph schema information.
"""
node_properties = self.query(node_properties_query)
relationships_properties = self.query(rel_properties_query)
relationships = self.query(rel_query)
node_properties = [el["output"] for el in self.query(node_properties_query)]
rel_properties = [el["output"] for el in self.query(rel_properties_query)]
relationships = [el["output"] for el in self.query(rel_query)]
self.structured_schema = {
"node_props": {el["labels"]: el["properties"] for el in node_properties},
"rel_props": {el["type"]: el["properties"] for el in rel_properties},
"relationships": relationships,
}
self.schema = f"""
Node properties are the following:
{[el['output'] for el in node_properties]}
{node_properties}
Relationship properties are the following:
{[el['output'] for el in relationships_properties]}
{rel_properties}
The relationships are the following:
{[el['output'] for el in relationships]}
{[f"(:{el['start']})-[:{el['type']}]->(:{el['end']})" for el in relationships]}
"""
def add_graph_documents(

@ -211,21 +211,32 @@ def test_cypher_return_correct_schema() -> None:
expected_node_properties = [
{
"properties": [{"property": "property_a", "type": "STRING"}],
"labels": "LabelA",
"output": {
"properties": [{"property": "property_a", "type": "STRING"}],
"labels": "LabelA",
}
}
]
expected_relationships_properties = [
{"type": "REL_TYPE", "properties": [{"property": "rel_prop", "type": "STRING"}]}
{
"output": {
"type": "REL_TYPE",
"properties": [{"property": "rel_prop", "type": "STRING"}],
}
}
]
expected_relationships = [
"(:LabelA)-[:REL_TYPE]->(:LabelB)",
"(:LabelA)-[:REL_TYPE]->(:LabelC)",
{"output": {"start": "LabelA", "type": "REL_TYPE", "end": "LabelB"}},
{"output": {"start": "LabelA", "type": "REL_TYPE", "end": "LabelC"}},
]
assert node_properties == expected_node_properties
assert relationships_properties == expected_relationships_properties
assert relationships == expected_relationships
# Order is not guaranteed with Neo4j returns
assert (
sorted(relationships, key=lambda x: x["output"]["end"])
== expected_relationships
)
def test_cypher_save_load() -> None:
@ -252,3 +263,122 @@ def test_cypher_save_load() -> None:
qa_loaded = load_chain(FILE_PATH, graph=graph)
assert qa_loaded == chain
def test_exclude_types() -> None:
    """Check that `exclude_types` removes the listed types from the chain schema."""
    neo4j_url = os.environ.get("NEO4J_URL")
    neo4j_user = os.environ.get("NEO4J_USERNAME")
    neo4j_password = os.environ.get("NEO4J_PASSWORD")
    assert neo4j_url is not None
    assert neo4j_user is not None
    assert neo4j_password is not None

    graph = Neo4jGraph(
        url=neo4j_url,
        username=neo4j_user,
        password=neo4j_password,
    )
    # Start from an empty database
    graph.query("MATCH (n) DETACH DELETE n")
    # Create an Actor, a Movie and a Person joined by ACTED_IN and DIRECTED
    seed_query = (
        "CREATE (a:Actor {name:'Bruce Willis'})"
        "-[:ACTED_IN]->(:Movie {title: 'Pulp Fiction'})"
        "<-[:DIRECTED]-(p:Person {name:'John'})"
    )
    graph.query(seed_query)
    # Make the graph object pick up the freshly created schema
    graph.refresh_schema()

    excluded = ["Person", "DIRECTED"]
    chain = GraphCypherQAChain.from_llm(
        OpenAI(temperature=0), graph=graph, exclude_types=excluded
    )

    assert chain.graph_schema == (
        "Node properties are the following: \n"
        " {'Movie': [{'property': 'title', 'type': 'STRING'}], "
        "'Actor': [{'property': 'name', 'type': 'STRING'}]}\n"
        "Relationships properties are the following: \n"
        " {}\nRelationships are: \n"
        "['(:Actor)-[:ACTED_IN]->(:Movie)']"
    )
def test_include_types() -> None:
    """Check that `include_types` keeps only the listed types in the chain schema."""
    neo4j_url = os.environ.get("NEO4J_URL")
    neo4j_user = os.environ.get("NEO4J_USERNAME")
    neo4j_password = os.environ.get("NEO4J_PASSWORD")
    assert neo4j_url is not None
    assert neo4j_user is not None
    assert neo4j_password is not None

    graph = Neo4jGraph(
        url=neo4j_url,
        username=neo4j_user,
        password=neo4j_password,
    )
    # Start from an empty database
    graph.query("MATCH (n) DETACH DELETE n")
    # Create an Actor, a Movie and a Person joined by ACTED_IN and DIRECTED
    seed_query = (
        "CREATE (a:Actor {name:'Bruce Willis'})"
        "-[:ACTED_IN]->(:Movie {title: 'Pulp Fiction'})"
        "<-[:DIRECTED]-(p:Person {name:'John'})"
    )
    graph.query(seed_query)
    # Make the graph object pick up the freshly created schema
    graph.refresh_schema()

    included = ["Movie", "Actor", "ACTED_IN"]
    chain = GraphCypherQAChain.from_llm(
        OpenAI(temperature=0), graph=graph, include_types=included
    )

    assert chain.graph_schema == (
        "Node properties are the following: \n"
        " {'Movie': [{'property': 'title', 'type': 'STRING'}], "
        "'Actor': [{'property': 'name', 'type': 'STRING'}]}\n"
        "Relationships properties are the following: \n"
        " {}\nRelationships are: \n"
        "['(:Actor)-[:ACTED_IN]->(:Movie)']"
    )
def test_include_types2() -> None:
    """Check that a relationship is dropped when one endpoint label is not included."""
    neo4j_url = os.environ.get("NEO4J_URL")
    neo4j_user = os.environ.get("NEO4J_USERNAME")
    neo4j_password = os.environ.get("NEO4J_PASSWORD")
    assert neo4j_url is not None
    assert neo4j_user is not None
    assert neo4j_password is not None

    graph = Neo4jGraph(
        url=neo4j_url,
        username=neo4j_user,
        password=neo4j_password,
    )
    # Start from an empty database
    graph.query("MATCH (n) DETACH DELETE n")
    # Create an Actor, a Movie and a Person joined by ACTED_IN and DIRECTED
    seed_query = (
        "CREATE (a:Actor {name:'Bruce Willis'})"
        "-[:ACTED_IN]->(:Movie {title: 'Pulp Fiction'})"
        "<-[:DIRECTED]-(p:Person {name:'John'})"
    )
    graph.query(seed_query)
    # Make the graph object pick up the freshly created schema
    graph.refresh_schema()

    included = ["Movie", "ACTED_IN"]
    chain = GraphCypherQAChain.from_llm(
        OpenAI(temperature=0), graph=graph, include_types=included
    )

    assert chain.graph_schema == (
        "Node properties are the following: \n"
        " {'Movie': [{'property': 'title', 'type': 'STRING'}]}\n"
        "Relationships properties are the following: \n"
        " {}\nRelationships are: \n"
        "[]"
    )

@ -1,4 +1,4 @@
from langchain.chains.graph_qa.cypher import extract_cypher
from langchain.chains.graph_qa.cypher import construct_schema, extract_cypher
def test_no_backticks() -> None:
@ -13,3 +13,107 @@ def test_backticks() -> None:
query = "You can use the following query: ```MATCH (n) RETURN n```"
output = extract_cypher(query)
assert output == "MATCH (n) RETURN n"
def test_exclude_types() -> None:
    """Excluding `Person` and `DIRECTED` leaves only the Actor/Movie schema."""
    node_props = {
        "Movie": [{"property": "title", "type": "STRING"}],
        "Actor": [{"property": "name", "type": "STRING"}],
        "Person": [{"property": "name", "type": "STRING"}],
    }
    relationships = [
        {"start": "Actor", "end": "Movie", "type": "ACTED_IN"},
        {"start": "Person", "end": "Movie", "type": "DIRECTED"},
    ]
    schema = {
        "node_props": node_props,
        "rel_props": {},
        "relationships": relationships,
    }

    rendered = construct_schema(schema, [], ["Person", "DIRECTED"])

    assert rendered == (
        "Node properties are the following: \n"
        " {'Movie': [{'property': 'title', 'type': 'STRING'}], "
        "'Actor': [{'property': 'name', 'type': 'STRING'}]}\n"
        "Relationships properties are the following: \n"
        " {}\nRelationships are: \n"
        "['(:Actor)-[:ACTED_IN]->(:Movie)']"
    )
def test_include_types() -> None:
    """Including both labels and the relationship type keeps the full pattern."""
    node_props = {
        "Movie": [{"property": "title", "type": "STRING"}],
        "Actor": [{"property": "name", "type": "STRING"}],
        "Person": [{"property": "name", "type": "STRING"}],
    }
    relationships = [
        {"start": "Actor", "end": "Movie", "type": "ACTED_IN"},
        {"start": "Person", "end": "Movie", "type": "DIRECTED"},
    ]
    schema = {
        "node_props": node_props,
        "rel_props": {},
        "relationships": relationships,
    }

    rendered = construct_schema(schema, ["Movie", "Actor", "ACTED_IN"], [])

    assert rendered == (
        "Node properties are the following: \n"
        " {'Movie': [{'property': 'title', 'type': 'STRING'}], "
        "'Actor': [{'property': 'name', 'type': 'STRING'}]}\n"
        "Relationships properties are the following: \n"
        " {}\nRelationships are: \n"
        "['(:Actor)-[:ACTED_IN]->(:Movie)']"
    )
def test_include_types2() -> None:
    """Including only node labels drops every relationship pattern."""
    node_props = {
        "Movie": [{"property": "title", "type": "STRING"}],
        "Actor": [{"property": "name", "type": "STRING"}],
        "Person": [{"property": "name", "type": "STRING"}],
    }
    relationships = [
        {"start": "Actor", "end": "Movie", "type": "ACTED_IN"},
        {"start": "Person", "end": "Movie", "type": "DIRECTED"},
    ]
    schema = {
        "node_props": node_props,
        "rel_props": {},
        "relationships": relationships,
    }

    # The relationship type itself is not included, so no pattern survives.
    rendered = construct_schema(schema, ["Movie", "Actor"], [])

    assert rendered == (
        "Node properties are the following: \n"
        " {'Movie': [{'property': 'title', 'type': 'STRING'}], "
        "'Actor': [{'property': 'name', 'type': 'STRING'}]}\n"
        "Relationships properties are the following: \n"
        " {}\nRelationships are: \n"
        "[]"
    )
def test_include_types3() -> None:
    """A relationship is dropped when its type is included but an endpoint is not.

    This test previously repeated the input and expectation of
    `test_include_types` verbatim, so it added no coverage. It now includes
    the `ACTED_IN` type and the `Movie` label but leaves out `Actor`, the
    start label of the only `ACTED_IN` relationship.
    """
    structured_schema = {
        "node_props": {
            "Movie": [{"property": "title", "type": "STRING"}],
            "Actor": [{"property": "name", "type": "STRING"}],
            "Person": [{"property": "name", "type": "STRING"}],
        },
        "rel_props": {},
        "relationships": [
            {"start": "Actor", "end": "Movie", "type": "ACTED_IN"},
            {"start": "Person", "end": "Movie", "type": "DIRECTED"},
        ],
    }
    include_types = ["Movie", "ACTED_IN"]
    output = construct_schema(structured_schema, include_types, [])
    # `Actor` is filtered out, so the Actor-[:ACTED_IN]->Movie pattern must
    # disappear even though the relationship type is allowed.
    expected_schema = (
        "Node properties are the following: \n"
        " {'Movie': [{'property': 'title', 'type': 'STRING'}]}\n"
        "Relationships properties are the following: \n"
        " {}\nRelationships are: \n"
        "[]"
    )
    assert output == expected_schema

Loading…
Cancel
Save