From 9373b9c004c402d96ac43b924994f2f83f91a254 Mon Sep 17 00:00:00 2001 From: Guy Korland Date: Tue, 17 Oct 2023 23:54:05 +0300 Subject: [PATCH] Add Graph interface (#11012) Replace this entire comment with: - **Description:** Add a Graph interface - **Tag maintainer:** @baskaryan @hwchase17 - **Twitter handle:** @g_korland --- .../langchain/chains/graph_qa/cypher.py | 6 +-- .../langchain/graphs/falkordb_graph.py | 8 ++-- .../langchain/langchain/graphs/graph_store.py | 37 +++++++++++++++++++ .../langchain/langchain/graphs/neo4j_graph.py | 13 ++++++- 4 files changed, 56 insertions(+), 8 deletions(-) create mode 100644 libs/langchain/langchain/graphs/graph_store.py diff --git a/libs/langchain/langchain/chains/graph_qa/cypher.py b/libs/langchain/langchain/chains/graph_qa/cypher.py index dd2a4cbb8f..1f59e52272 100644 --- a/libs/langchain/langchain/chains/graph_qa/cypher.py +++ b/libs/langchain/langchain/chains/graph_qa/cypher.py @@ -9,7 +9,7 @@ from langchain.chains.base import Chain from langchain.chains.graph_qa.cypher_utils import CypherQueryCorrector, Schema from langchain.chains.graph_qa.prompts import CYPHER_GENERATION_PROMPT, CYPHER_QA_PROMPT from langchain.chains.llm import LLMChain -from langchain.graphs.neo4j_graph import Neo4jGraph +from langchain.graphs.graph_store import GraphStore from langchain.pydantic_v1 import Field from langchain.schema import BasePromptTemplate from langchain.schema.language_model import BaseLanguageModel @@ -79,7 +79,7 @@ def construct_schema( class GraphCypherQAChain(Chain): """Chain for question-answering against a graph by generating Cypher statements.""" - graph: Neo4jGraph = Field(exclude=True) + graph: GraphStore = Field(exclude=True) cypher_generation_chain: LLMChain qa_chain: LLMChain graph_schema: str @@ -151,7 +151,7 @@ class GraphCypherQAChain(Chain): ) graph_schema = construct_schema( - kwargs["graph"].structured_schema, include_types, exclude_types + kwargs["graph"].get_structured_schema, include_types, exclude_types ) cypher_query_corrector = None diff --git a/libs/langchain/langchain/graphs/falkordb_graph.py b/libs/langchain/langchain/graphs/falkordb_graph.py index db2c0ce231..39aed25385 100644 --- a/libs/langchain/langchain/graphs/falkordb_graph.py +++ b/libs/langchain/langchain/graphs/falkordb_graph.py @@ -1,7 +1,7 @@ from typing import Any, Dict, List from langchain.graphs.graph_document import GraphDocument -from langchain.graphs.neo4j_graph import Neo4jGraph +from langchain.graphs.graph_store import GraphStore node_properties_query = """ MATCH (n) @@ -32,7 +32,7 @@ RETURN DISTINCT {start: src_label, type: rel_type, end: dst_label} AS output """ -class FalkorDBGraph(Neo4jGraph): +class FalkorDBGraph(GraphStore): """FalkorDB wrapper for graph operations. *Security note*: Make sure that the database connection uses credentials @@ -58,8 +58,8 @@ class FalkorDBGraph(Neo4jGraph): "Please install it with `pip install redis`." ) - driver = redis.Redis(host=host, port=port) - self._graph = Graph(driver, database) + self._driver = redis.Redis(host=host, port=port) + self._graph = Graph(self._driver, database) self.schema: str = "" self.structured_schema: Dict[str, Any] = {} diff --git a/libs/langchain/langchain/graphs/graph_store.py b/libs/langchain/langchain/graphs/graph_store.py new file mode 100644 index 0000000000..02dd3252c5 --- /dev/null +++ b/libs/langchain/langchain/graphs/graph_store.py @@ -0,0 +1,37 @@ +from abc import abstractmethod +from typing import Any, Dict, List + +from langchain.graphs.graph_document import GraphDocument + + +class GraphStore: + """An abstract class wrapper for graph operations.""" + + @property + @abstractmethod + def get_schema(self) -> str: + """Returns the schema of the Graph database""" + pass + + @property + @abstractmethod + def get_structured_schema(self) -> Dict[str, Any]: + """Returns the schema of the Graph database""" + pass + + @abstractmethod + def query(self, query: str, params: dict = {}) -> List[Dict[str, Any]]: + """Query the graph.""" + pass + + @abstractmethod + def refresh_schema(self) -> None: + """Refreshes the graph schema information.""" + pass + + @abstractmethod + def add_graph_documents( + self, graph_documents: List[GraphDocument], include_source: bool = False + ) -> None: + """Take GraphDocument as input as uses it to construct a graph.""" + pass diff --git a/libs/langchain/langchain/graphs/neo4j_graph.py b/libs/langchain/langchain/graphs/neo4j_graph.py index ee124f6323..df4230c6c4 100644 --- a/libs/langchain/langchain/graphs/neo4j_graph.py +++ b/libs/langchain/langchain/graphs/neo4j_graph.py @@ -1,6 +1,7 @@ from typing import Any, Dict, List from langchain.graphs.graph_document import GraphDocument +from langchain.graphs.graph_store import GraphStore node_properties_query = """ CALL apoc.meta.data() @@ -28,7 +29,7 @@ RETURN {start: label, type: property, end: toString(other_node)} AS output """ -class Neo4jGraph: +class Neo4jGraph(GraphStore): """Neo4j wrapper for graph operations. *Security note*: Make sure that the database connection uses credentials @@ -80,6 +81,16 @@ class Neo4jGraph: "'apoc.meta.data()' is allowed in Neo4j configuration " ) + @property + def get_schema(self) -> str: + """Returns the schema of the Graph""" + return self.schema + + @property + def get_structured_schema(self) -> Dict[str, Any]: + """Returns the structured schema of the Graph""" + return self.structured_schema + def query(self, query: str, params: dict = {}) -> List[Dict[str, Any]]: """Query Neo4j database.""" from neo4j.exceptions import CypherSyntaxError