mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
261 lines
9.4 KiB
Python
261 lines
9.4 KiB
Python
|
import re
|
||
|
from collections import namedtuple
|
||
|
from typing import Any, Dict, List, Optional, Tuple
|
||
|
|
||
|
# One allowed relationship in the graph schema:
# (left_node label) -[relation type]-> (right_node label).
Schema = namedtuple("Schema", ["left_node", "relation", "right_node"])


class CypherQueryCorrector:
    """
    Used to correct relationship direction in generated Cypher statements.
    This code is copied from the winner's submission to the Cypher competition:
    https://github.com/sakusaku-rich/cypher-direction-competition
    """

    # "{...}" property map inside a node or relationship.
    property_pattern = re.compile(r"\{.+?\}")
    # Any parenthesised span; used to collect node declarations.
    node_pattern = re.compile(r"\(.+?\)")
    # A single "(node)<-[rel]->(node)" hop. Groups 1 and 6 are the two nodes,
    # groups 3-5 are the relationship, groups 2 and 7 are nested property maps.
    path_pattern = re.compile(
        r"(\([^\,\(\)]*?(\{.+\})?[^\,\(\)]*?\))(<?-)(\[.*?\])?(->?)(\([^\,\(\)]*?(\{.+\})?[^\,\(\)]*?\))"
    )
    # Splits one hop into the named parts left_node / relation / right_node.
    node_relation_node_pattern = re.compile(
        r"(\()+(?P<left_node>[^()]*?)\)(?P<relation>.*?)\((?P<right_node>[^()]*?)(\))+"
    )
    # Extracts the type list ("A|B") from the bracketed part of a relationship.
    relation_type_pattern = re.compile(r":(?P<relation_type>.+?)?(\{.+\})?]")

    def __init__(self, schemas: List[Schema]):
        """
        Args:
            schemas: list of schemas
        """
        self.schemas = schemas

    def clean_node(self, node: str) -> str:
        """
        Strip property maps, parentheses and surrounding whitespace from a node.

        Args:
            node: node in string format

        Returns:
            The remaining "variable:Label..." text (possibly empty).
        """
        node = re.sub(self.property_pattern, "", node)
        node = node.replace("(", "")
        node = node.replace(")", "")
        node = node.strip()
        return node

    def detect_node_variables(self, query: str) -> Dict[str, List[str]]:
        """
        Map every node variable found in the query to the labels declared for it.

        Args:
            query: cypher query

        Returns:
            Dictionary of variable name -> list of labels, accumulated over all
            occurrences of that variable in the query.
        """
        nodes = re.findall(self.node_pattern, query)
        nodes = [self.clean_node(node) for node in nodes]
        res: Dict[str, Any] = {}
        for node in nodes:
            # Skip nodes that are empty once cleaned, e.g. "({name: 'x'})".
            # (The original compared the split *list* to "", which never matched.)
            if node == "":
                continue
            parts = node.split(":")
            variable = parts[0]
            # NOTE(review): anonymous nodes such as "(:Person)" have an empty
            # variable name, so the labels of all anonymous nodes in the query
            # end up pooled under the "" key.
            if variable not in res:
                res[variable] = []
            res[variable] += parts[1:]
        return res

    def extract_paths(self, query: str) -> List[str]:
        """
        Collect every single-hop "(node)-[rel]-(node)" path in the query.

        Chained patterns are returned hop by hop: the right node of one hop is
        reused as the left node of the next.

        Args:
            query: cypher query

        Returns:
            List of path strings in order of appearance.
        """
        paths: List[str] = []
        idx = 0
        # Search forward from ``idx`` rather than locating the matched text with
        # ``query.find(path)``: the latter returns the *first* occurrence in the
        # whole query, so a path repeated later in the query reset ``idx`` and
        # the loop never terminated.
        while matched := self.path_pattern.search(query, idx):
            paths.append(matched.group(0))
            # Resume at the right-hand node (group 6) so it can start the next
            # hop; it always lies after the match start, so ``idx`` advances.
            idx = matched.start(6)
        return paths

    def judge_direction(self, relation: str) -> str:
        """
        Classify the arrow direction of a relationship.

        Args:
            relation: relation in string format

        Returns:
            "INCOMING" if it starts with "<", "OUTGOING" if it ends with ">"
            (this wins when both are present), otherwise "BIDIRECTIONAL".
        """
        direction = "BIDIRECTIONAL"
        # startswith/endswith are safe on an empty string, unlike indexing.
        if relation.startswith("<"):
            direction = "INCOMING"
        if relation.endswith(">"):
            direction = "OUTGOING"
        return direction

    def extract_node_variable(self, part: str) -> Optional[str]:
        """
        Return the variable name of a node, or None if it has none.

        Args:
            part: node in string format
        """
        part = part.lstrip("(").rstrip(")")
        idx = part.find(":")
        if idx != -1:
            part = part[:idx]
        return None if part == "" else part

    def detect_labels(
        self, str_node: str, node_variable_dict: Dict[str, Any]
    ) -> List[str]:
        """
        Determine the labels of a node appearing in a path.

        Args:
            str_node: node in string format
            node_variable_dict: dictionary of node variables

        Returns:
            The labels recorded for the node's variable when it is known;
            otherwise, for an anonymous node, the labels written inline;
            otherwise an empty list.
        """
        splitted_node = str_node.split(":")
        variable = splitted_node[0]
        labels = []
        if variable in node_variable_dict:
            labels = node_variable_dict[variable]
        elif variable == "" and len(splitted_node) > 1:
            labels = splitted_node[1:]
        return labels

    def verify_schema(
        self,
        from_node_labels: List[str],
        relation_types: List[str],
        to_node_labels: List[str],
    ) -> bool:
        """
        Check whether at least one schema is compatible with the given hop.

        An empty list for any argument means "no constraint" on that part.

        Args:
            from_node_labels: labels of the from node
            relation_types: types of the relation
            to_node_labels: labels of the to node

        Returns:
            True if some schema matches all non-empty constraints.
        """
        valid_schemas = self.schemas
        if from_node_labels:
            # Back-ticks only quote the identifier; they are not part of it.
            from_node_labels = [label.strip("`") for label in from_node_labels]
            valid_schemas = [
                schema for schema in valid_schemas if schema[0] in from_node_labels
            ]
        if to_node_labels:
            to_node_labels = [label.strip("`") for label in to_node_labels]
            valid_schemas = [
                schema for schema in valid_schemas if schema[2] in to_node_labels
            ]
        if relation_types:
            relation_types = [rel_type.strip("`") for rel_type in relation_types]
            valid_schemas = [
                schema for schema in valid_schemas if schema[1] in relation_types
            ]
        return bool(valid_schemas)

    def detect_relation_types(self, str_relation: str) -> Tuple[str, List[str]]:
        """
        Extract direction and relationship types from a relationship string.

        Args:
            str_relation: relation in string format

        Returns:
            Tuple of (direction, list of relation types). The list is empty when
            the relationship declares no type.
        """
        relation_direction = self.judge_direction(str_relation)
        relation_type = self.relation_type_pattern.search(str_relation)
        if relation_type is None or relation_type.group("relation_type") is None:
            return relation_direction, []
        relation_types = [
            t.strip().strip("!")
            for t in relation_type.group("relation_type").split("|")
        ]
        return relation_direction, relation_types

    def correct_query(self, query: str) -> str:
        """
        Flip relationship directions in the query so they agree with the schema.

        Args:
            query: cypher query

        Returns:
            The (possibly rewritten) query, or an empty string when some hop
            matches the schema in neither direction.
        """
        node_variable_dict = self.detect_node_variables(query)
        paths = self.extract_paths(query)
        for path in paths:
            original_path = path
            start_idx = 0
            while start_idx < len(path):
                match_res = re.match(self.node_relation_node_pattern, path[start_idx:])
                if match_res is None:
                    break
                start_idx += match_res.start()
                match_dict = match_res.groupdict()
                left_node_labels = self.detect_labels(
                    match_dict["left_node"], node_variable_dict
                )
                right_node_labels = self.detect_labels(
                    match_dict["right_node"], node_variable_dict
                )
                # "+ 4" accounts for the two pairs of parentheses around the
                # left and right node; the slice below is end-inclusive.
                end_idx = (
                    start_idx
                    + 4
                    + len(match_dict["left_node"])
                    + len(match_dict["relation"])
                    + len(match_dict["right_node"])
                )
                original_partial_path = original_path[start_idx : end_idx + 1]
                relation_direction, relation_types = self.detect_relation_types(
                    match_dict["relation"]
                )

                # Relationships whose type text contains "*" are left untouched.
                if any("*" in rel_type for rel_type in relation_types):
                    start_idx += (
                        len(match_dict["left_node"]) + len(match_dict["relation"]) + 2
                    )
                    continue

                if relation_direction == "OUTGOING":
                    is_legal = self.verify_schema(
                        left_node_labels, relation_types, right_node_labels
                    )
                    if not is_legal:
                        # Written direction is invalid: try the reverse.
                        is_legal = self.verify_schema(
                            right_node_labels, relation_types, left_node_labels
                        )
                        if is_legal:
                            # "-[..]->"  becomes  "<-[..]-"
                            corrected_relation = "<" + match_dict["relation"][:-1]
                            corrected_partial_path = original_partial_path.replace(
                                match_dict["relation"], corrected_relation
                            )
                            query = query.replace(
                                original_partial_path, corrected_partial_path
                            )
                        else:
                            return ""
                elif relation_direction == "INCOMING":
                    is_legal = self.verify_schema(
                        right_node_labels, relation_types, left_node_labels
                    )
                    if not is_legal:
                        # Written direction is invalid: try the reverse.
                        is_legal = self.verify_schema(
                            left_node_labels, relation_types, right_node_labels
                        )
                        if is_legal:
                            # "<-[..]-"  becomes  "-[..]->"
                            corrected_relation = match_dict["relation"][1:] + ">"
                            corrected_partial_path = original_partial_path.replace(
                                match_dict["relation"], corrected_relation
                            )
                            query = query.replace(
                                original_partial_path, corrected_partial_path
                            )
                        else:
                            return ""
                else:
                    # Undirected: valid if the schema allows either orientation.
                    is_legal = self.verify_schema(
                        left_node_labels, relation_types, right_node_labels
                    )
                    is_legal |= self.verify_schema(
                        right_node_labels, relation_types, left_node_labels
                    )
                    if not is_legal:
                        return ""

                # Move to the opening parenthesis of the right node so it is
                # matched as the left node of the next hop.
                start_idx += (
                    len(match_dict["left_node"]) + len(match_dict["relation"]) + 2
                )
        return query

    def __call__(self, query: str) -> str:
        """Correct the query to make it valid.

        If a relationship in the query cannot be matched against the schema in
        either direction, an empty string is returned.

        Args:
            query: cypher query
        """
        return self.correct_query(query)