mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
f92006de3c
0.2rc migrations - [x] Move memory - [x] Move remaining retrievers - [x] graph_qa chains - [x] some dependency from evaluation code potentially on math utils - [x] Move openapi chain from `langchain.chains.api.openapi` to `langchain_community.chains.openapi` - [x] Migrate `langchain.chains.ernie_functions` to `langchain_community.chains.ernie_functions` - [x] migrate `langchain/chains/llm_requests.py` to `langchain_community.chains.llm_requests` - [x] Moving `langchain_community.cross_encoders.base:BaseCrossEncoder` -> `langchain_community.retrievers.document_compressors.cross_encoder:BaseCrossEncoder` (namespace not ideal, but it needs to be moved to `langchain` to avoid circular deps) - [x] unit tests langchain -- add pytest.mark.community to some unit tests that will stay in langchain - [x] unit tests community -- move unit tests that depend on community to community - [x] mv integration tests that depend on community to community - [x] mypy checks Other todo - [x] Make deprecation warnings not noisy (need to use warn deprecated and check that things are implemented properly) - [x] Update deprecation messages with timeline for code removal (likely we actually won't be removing things until 0.4 release) -- will give people more time to transition their code. - [ ] Add information to deprecation warning to show users how to migrate their code base using langchain-cli - [ ] Remove any unnecessary requirements in langchain (e.g., is SQLAlchemy required?) --------- Co-authored-by: Erick Friis <erick@langchain.dev>
261 lines
9.4 KiB
Python
261 lines
9.4 KiB
Python
import re
|
|
from collections import namedtuple
|
|
from typing import Any, Dict, List, Optional, Tuple
|
|
|
|
# One allowed graph edge: (left_node label, relation type, right_node label).
# Fields are readable by name or by position (index 0, 1, 2).
Schema = namedtuple("Schema", "left_node relation right_node")
|
|
|
|
|
|
class CypherQueryCorrector:
    """
    Used to correct relationship direction in generated Cypher statements.

    Each ``(node)-[relation]-(node)`` hop found in a query is checked against
    the allowed ``(left_node, relation, right_node)`` schema triples. A hop
    whose arrow points the wrong way is flipped; a hop that matches no schema
    in either direction makes the whole correction return an empty string.

    This code is copied from the winner's submission to the Cypher competition:
    https://github.com/sakusaku-rich/cypher-direction-competition
    """

    # A property map such as ``{name: "Tom"}``.
    property_pattern = re.compile(r"\{.+?\}")
    # Any parenthesised span; used to collect node declarations.
    node_pattern = re.compile(r"\(.+?\)")
    # node, left arrow part, optional [relation], right arrow part, node.
    # Groups 2 and 7 are the optional property maps nested inside the nodes.
    path_pattern = re.compile(
        r"(\([^\,\(\)]*?(\{.+\})?[^\,\(\)]*?\))(<?-)(\[.*?\])?(->?)(\([^\,\(\)]*?(\{.+\})?[^\,\(\)]*?\))"
    )
    # Splits a single hop into left node / relation / right node.
    node_relation_node_pattern = re.compile(
        r"(\()+(?P<left_node>[^()]*?)\)(?P<relation>.*?)\((?P<right_node>[^()]*?)(\))+"
    )
    # Relationship type(s) after the colon inside ``[...]``.
    relation_type_pattern = re.compile(r":(?P<relation_type>.+?)?(\{.+\})?]")

    def __init__(self, schemas: List["Schema"]):
        """
        Args:
            schemas: list of schemas; each item is indexed as
                ``(left_node, relation, right_node)``.
        """
        self.schemas = schemas

    def clean_node(self, node: str) -> str:
        """
        Strip property maps, parentheses and surrounding whitespace.

        Args:
            node: node in string format

        Returns:
            The remaining ``variable:Label...`` text.
        """
        without_properties = self.property_pattern.sub("", node)
        return without_properties.replace("(", "").replace(")", "").strip()

    def detect_node_variables(self, query: str) -> Dict[str, List[str]]:
        """
        Collect the labels declared for every named node variable.

        Anonymous nodes (no variable before the colon) are skipped: they do
        not share an identity, so pooling their labels under one key would
        attribute the labels of one anonymous node to all the others.

        Args:
            query: cypher query

        Returns:
            Mapping of variable name to all labels seen for it in the query.
        """
        res: Dict[str, List[str]] = {}
        for raw_node in self.node_pattern.findall(query):
            parts = [part.strip() for part in self.clean_node(raw_node).split(":")]
            variable, labels = parts[0], parts[1:]
            if variable == "":
                continue
            res.setdefault(variable, []).extend(labels)
        return res

    def extract_paths(self, query: str) -> "List[str]":
        """
        Return every node-relation-node path in the query, in order.

        Consecutive paths may overlap: the scan resumes at the right-hand
        node of the previous match so that node can start the next hop.

        Args:
            query: cypher query

        Returns:
            List of matched path strings (duplicates are kept).
        """
        paths: List[str] = []
        idx = 0
        # Search from an explicit offset and advance using the match's own
        # position. Re-locating the text with ``query.find(path)`` would land
        # on the first occurrence and loop forever on a repeated path.
        while (matched := self.path_pattern.search(query, idx)) is not None:
            # The whole match equals groups 1, 3, 4, 5 and 6 concatenated.
            paths.append(matched.group(0))
            idx = matched.start(6)  # start of the right-hand node
        return paths

    def judge_direction(self, relation: str) -> str:
        """
        Classify the arrow direction of a relation.

        Args:
            relation: relation in string format

        Returns:
            ``"OUTGOING"`` if it ends with ``>``, else ``"INCOMING"`` if it
            starts with ``<``, else ``"BIDIRECTIONAL"`` (including for an
            empty string).
        """
        # The trailing ">" is checked first so that it wins when both
        # markers are present, as in the original implementation.
        if relation.endswith(">"):
            return "OUTGOING"
        if relation.startswith("<"):
            return "INCOMING"
        return "BIDIRECTIONAL"

    def extract_node_variable(self, part: str) -> Optional[str]:
        """
        Args:
            part: node in string format

        Returns:
            The variable name before the first colon, or None if empty.
        """
        variable = part.lstrip("(").rstrip(")").partition(":")[0]
        return variable or None

    def detect_labels(
        self, str_node: str, node_variable_dict: Dict[str, Any]
    ) -> List[str]:
        """
        Resolve the labels that apply to one node of a path.

        Args:
            str_node: node in string format
            node_variable_dict: dictionary of node variables

        Returns:
            For an anonymous node, the labels written on the node itself;
            for a named node, the labels recorded for its variable; an empty
            list when nothing is known.
        """
        # Drop the property map first so colons inside ``{key: value}`` are
        # not mistaken for label separators.
        parts = [part.strip() for part in self.clean_node(str_node).split(":")]
        variable, own_labels = parts[0], parts[1:]
        if variable == "":
            return own_labels
        if variable in node_variable_dict:
            return node_variable_dict[variable]
        return []

    def verify_schema(
        self,
        from_node_labels: List[str],
        relation_types: List[str],
        to_node_labels: List[str],
    ) -> bool:
        """
        Check whether any schema allows the given hop.

        An empty list acts as a wildcard for that position. Backticks around
        labels and types are ignored.

        Args:
            from_node_labels: labels of the from node
            relation_types: types of the relation
            to_node_labels: labels of the to node

        Returns:
            True if at least one schema matches all non-empty constraints.
        """
        from_labels = {label.strip("`") for label in from_node_labels}
        rel_types = {rel_type.strip("`") for rel_type in relation_types}
        to_labels = {label.strip("`") for label in to_node_labels}
        return any(
            (not from_labels or schema[0] in from_labels)
            and (not rel_types or schema[1] in rel_types)
            and (not to_labels or schema[2] in to_labels)
            for schema in self.schemas
        )

    def detect_relation_types(self, str_relation: str) -> Tuple[str, List[str]]:
        """
        Args:
            str_relation: relation in string format

        Returns:
            Tuple of the direction (see ``judge_direction``) and the list of
            relation types, split on ``|``; the list is empty when no type
            is written.
        """
        relation_direction = self.judge_direction(str_relation)
        found = self.relation_type_pattern.search(str_relation)
        raw_types = None if found is None else found.group("relation_type")
        if raw_types is None:
            return relation_direction, []
        relation_types = [t.strip().strip("!") for t in raw_types.split("|")]
        return relation_direction, relation_types

    def correct_query(self, query: str) -> str:
        """
        Flip wrongly directed relationships so the query fits the schemas.

        Args:
            query: cypher query

        Returns:
            The corrected query, or an empty string if some hop matches no
            schema in either direction.
        """
        node_variable_dict = self.detect_node_variables(query)
        for path in self.extract_paths(query):
            start_idx = 0
            while start_idx < len(path):
                match_res = self.node_relation_node_pattern.match(path, start_idx)
                if match_res is None:
                    break
                left_node = match_res.group("left_node")
                relation = match_res.group("relation")
                right_node = match_res.group("right_node")

                hop_start = start_idx
                # Next iteration starts at the right-hand node (+2 skips the
                # left node's parentheses).
                start_idx = hop_start + len(left_node) + len(relation) + 2

                relation_direction, relation_types = self.detect_relation_types(
                    relation
                )
                # Variable-length relations (``*``) are left untouched.
                if any("*" in rel_type for rel_type in relation_types):
                    continue

                left_labels = self.detect_labels(left_node, node_variable_dict)
                right_labels = self.detect_labels(right_node, node_variable_dict)
                forward_ok = self.verify_schema(
                    left_labels, relation_types, right_labels
                )
                backward_ok = self.verify_schema(
                    right_labels, relation_types, left_labels
                )
                if not (forward_ok or backward_ok):
                    return ""

                if relation_direction == "OUTGOING" and not forward_ok:
                    corrected_relation = "<" + relation[:-1]
                elif relation_direction == "INCOMING" and not backward_ok:
                    corrected_relation = relation[1:] + ">"
                else:
                    continue

                # +4 counts the two pairs of parentheses. The slice end is
                # one past the closing parenthesis, so a hop followed by more
                # path text carries one extra trailing character (kept from
                # the original implementation; it only narrows the match).
                end_idx = (
                    hop_start + 4 + len(left_node) + len(relation) + len(right_node)
                )
                original_partial_path = path[hop_start : end_idx + 1]
                corrected_partial_path = original_partial_path.replace(
                    relation, corrected_relation
                )
                query = query.replace(original_partial_path, corrected_partial_path)
        return query

    def __call__(self, query: str) -> str:
        """Correct the query to make it valid.

        If the query cannot be made valid against the schemas, an empty
        string is returned.

        Args:
            query: cypher query
        """
        return self.correct_query(query)
|