diff --git a/libs/community/langchain_community/graphs/neo4j_graph.py b/libs/community/langchain_community/graphs/neo4j_graph.py index af416b29bc..92efd27f08 100644 --- a/libs/community/langchain_community/graphs/neo4j_graph.py +++ b/libs/community/langchain_community/graphs/neo4j_graph.py @@ -31,8 +31,50 @@ RETURN {start: label, type: property, end: toString(other_node)} AS output """ +def value_sanitize(d: Dict[str, Any]) -> Dict[str, Any]: + """ + Sanitizes the input dictionary by removing embedding-like values, + lists with more than 128 elements, that are mostly irrelevant for + generating answers in a LLM context. These properties, if left in + results, can occupy significant context space and detract from + the LLM's performance by introducing unnecessary noise and cost. + """ + LIST_LIMIT = 128 + # Create a new dictionary to avoid changing size during iteration + new_dict = {} + for key, value in d.items(): + if isinstance(value, dict): + # Recurse to handle nested dictionaries + new_dict[key] = value_sanitize(value) + elif isinstance(value, list): + # check if it has less than LIST_LIMIT values + if len(value) < LIST_LIMIT: + # if value is a list, check if it contains dictionaries to clean + cleaned_list = [] + for item in value: + if isinstance(item, dict): + cleaned_list.append(value_sanitize(item)) + else: + cleaned_list.append(item) + new_dict[key] = cleaned_list + else: + new_dict[key] = value + return new_dict + + class Neo4jGraph(GraphStore): - """Neo4j wrapper for graph operations. + """Provides a connection to a Neo4j database for various graph operations. + Parameters: + url (Optional[str]): The URL of the Neo4j database server. + username (Optional[str]): The username for database authentication. + password (Optional[str]): The password for database authentication. + database (str): The name of the database to connect to. Default is 'neo4j'. + timeout (Optional[float]): The timeout for transactions in seconds. + Useful for terminating long-running queries. + By default, there is no timeout set. + sanitize (bool): A flag to indicate whether to remove lists with + more than 128 elements from results. Useful for removing + embedding-like properties from database responses. Default is False. *Security note*: Make sure that the database connection uses credentials that are narrowly-scoped to only include necessary permissions. @@ -52,6 +94,8 @@ class Neo4jGraph(GraphStore): username: Optional[str] = None, password: Optional[str] = None, database: str = "neo4j", + timeout: Optional[float] = None, + sanitize: bool = False, ) -> None: """Create a new Neo4j graph wrapper instance.""" try: @@ -69,6 +113,8 @@ class Neo4jGraph(GraphStore): self._driver = neo4j.GraphDatabase.driver(url, auth=(username, password)) self._database = database + self.timeout = timeout + self.sanitize = sanitize self.schema: str = "" self.structured_schema: Dict[str, Any] = {} # Verify connection @@ -106,12 +152,16 @@ class Neo4jGraph(GraphStore): def query(self, query: str, params: dict = {}) -> List[Dict[str, Any]]: """Query Neo4j database.""" + from neo4j import Query from neo4j.exceptions import CypherSyntaxError with self._driver.session(database=self._database) as session: try: - data = session.run(query, params) - return [r.data() for r in data] + data = session.run(Query(text=query, timeout=self.timeout), params) + json_data = [r.data() for r in data] + if self.sanitize: + json_data = value_sanitize(json_data) + return json_data except CypherSyntaxError as e: raise ValueError(f"Generated Cypher Statement is not valid\n{e}") diff --git a/libs/community/tests/integration_tests/graphs/test_neo4j.py b/libs/community/tests/integration_tests/graphs/test_neo4j.py index fd9e3bb36f..948fe4c719 100644 --- a/libs/community/tests/integration_tests/graphs/test_neo4j.py +++ b/libs/community/tests/integration_tests/graphs/test_neo4j.py @@ -69,3 +69,22 @@ def test_cypher_return_correct_schema() -> None: sorted(relationships, key=lambda x: x["output"]["end"]) == expected_relationships ) + + +def test_neo4j_timeout() -> None: + """Test that neo4j uses the timeout correctly.""" + url = os.environ.get("NEO4J_URI") + username = os.environ.get("NEO4J_USERNAME") + password = os.environ.get("NEO4J_PASSWORD") + assert url is not None + assert username is not None + assert password is not None + + graph = Neo4jGraph(url=url, username=username, password=password, timeout=0.1) + try: + graph.query("UNWIND range(0,100000,1) AS i MERGE (:Foo {id:i})") + except Exception as e: + assert ( + e.code + == "Neo.ClientError.Transaction.TransactionTimedOutClientConfiguration" + ) diff --git a/libs/community/tests/unit_tests/graphs/test_neo4j_graph.py b/libs/community/tests/unit_tests/graphs/test_neo4j_graph.py new file mode 100644 index 0000000000..b352529ba6 --- /dev/null +++ b/libs/community/tests/unit_tests/graphs/test_neo4j_graph.py @@ -0,0 +1,32 @@ +from langchain_community.graphs.neo4j_graph import value_sanitize + + +def test_value_sanitize_with_small_list(): + small_list = list(range(15)) # list size > LIST_LIMIT + input_dict = {"key1": "value1", "small_list": small_list} + expected_output = {"key1": "value1", "small_list": small_list} + assert value_sanitize(input_dict) == expected_output + + +def test_value_sanitize_with_oversized_list(): + oversized_list = list(range(150)) # list size > LIST_LIMIT + input_dict = {"key1": "value1", "oversized_list": oversized_list} + expected_output = { + "key1": "value1" + # oversized_list should not be included + } + assert value_sanitize(input_dict) == expected_output + + +def test_value_sanitize_with_nested_oversized_list(): + oversized_list = list(range(150)) # list size > LIST_LIMIT + input_dict = {"key1": "value1", "oversized_list": {"key": oversized_list}} + expected_output = {"key1": "value1", "oversized_list": {}} + assert value_sanitize(input_dict) == expected_output + + +def test_value_sanitize_with_dict_in_list(): + oversized_list = list(range(150)) # list size > LIST_LIMIT + input_dict = {"key1": "value1", "oversized_list": [1, 2, {"key": oversized_list}]} + expected_output = {"key1": "value1", "oversized_list": [1, 2, {}]} + assert value_sanitize(input_dict) == expected_output