Add support for Falkordb (ex-RedisGraph) (#9821)

Replace this entire comment with:
  - Description: Add support for Falkordb (ex-RedisGraph)
  - Tag maintainer: @hwchase17
  - Twitter handle: @g_korland
This commit is contained in:
Guy Korland 2023-08-30 00:22:33 +03:00 committed by GitHub
parent fbd792ac7c
commit 7cbe872af8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 400 additions and 0 deletions

View File

@ -0,0 +1,154 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# FalkorDBQAChain"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This notebook shows how to use LLMs to provide a natural language interface to a FalkorDB database.\n",
"\n",
"FalkorDB is a low-latency property graph database management system. You can run it locally with Docker:\n",
"\n",
"```bash\n",
"docker run -p 6379:6379 -it --rm falkordb/falkordb:edge\n",
"```\n",
"\n",
"Once launched, you can simply start creating a database on the local machine and connect to it."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from langchain.chat_models import ChatOpenAI\n",
"from langchain.graphs import FalkorDBGraph\n",
"from langchain.chains import FalkorDBQAChain"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"graph = FalkorDBGraph(database=\"movies\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[]"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"graph.query(\n",
" \"\"\"\n",
"MERGE (m:Movie {name:\"Top Gun\"})\n",
"WITH m\n",
"UNWIND [\"Tom Cruise\", \"Val Kilmer\", \"Anthony Edwards\", \"Meg Ryan\"] AS actor\n",
"MERGE (a:Actor {name:actor})\n",
"MERGE (a)-[:ACTED_IN]->(m)\n",
"\"\"\"\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"graph.refresh_schema()\n",
"import os\n",
"os.environ['OPENAI_API_KEY']='API_KEY_HERE'\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"chain = FalkorDBQAChain.from_llm(\n",
" ChatOpenAI(temperature=0), graph=graph, verbose=True\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new FalkorDBQAChain chain...\u001b[0m\n",
"Generated Cypher:\n",
"\u001b[32;1m\u001b[1;3mMATCH (:Movie {title: 'Top Gun'})<-[:ACTED_IN]-(actor:Person)\n",
"RETURN actor.name AS output\u001b[0m\n",
"Full Context:\n",
"\u001b[32;1m\u001b[1;3m[]\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"'The actor who played in Top Gun is Tom Cruise.'"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chain.run(\"Who played in Top Gun?\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@ -36,6 +36,7 @@ from langchain.chains.flare.base import FlareChain
from langchain.chains.graph_qa.arangodb import ArangoGraphQAChain
from langchain.chains.graph_qa.base import GraphQAChain
from langchain.chains.graph_qa.cypher import GraphCypherQAChain
from langchain.chains.graph_qa.falkordb import FalkorDBQAChain
from langchain.chains.graph_qa.hugegraph import HugeGraphQAChain
from langchain.chains.graph_qa.kuzu import KuzuQAChain
from langchain.chains.graph_qa.nebulagraph import NebulaGraphQAChain
@ -85,6 +86,7 @@ __all__ = [
"ConstitutionalChain",
"ConversationChain",
"ConversationalRetrievalChain",
"FalkorDBQAChain",
"FlareChain",
"GraphCypherQAChain",
"GraphQAChain",

View File

@ -0,0 +1,141 @@
"""Question answering over a graph."""
from __future__ import annotations
import re
from typing import Any, Dict, List, Optional
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.chains.graph_qa.prompts import CYPHER_GENERATION_PROMPT, CYPHER_QA_PROMPT
from langchain.chains.llm import LLMChain
from langchain.graphs import FalkorDBGraph
from langchain.pydantic_v1 import Field
from langchain.schema import BasePromptTemplate
# Key under which the chain output carries the recorded intermediate steps
# (only added when ``return_intermediate_steps`` is enabled on the chain).
INTERMEDIATE_STEPS_KEY = "intermediate_steps"
def extract_cypher(text: str) -> str:
    """Extract Cypher code from a text.

    Args:
        text: Text to extract Cypher code from.

    Returns:
        The content of the first triple-backtick fenced block found in
        ``text``. A ``cypher`` language tag standing alone on the opening
        fence line is dropped. When ``text`` contains no fenced block it is
        returned unchanged.
    """
    # Fenced answers frequently open with a language tag ("```cypher\n...").
    # Previously that tag ended up inside the extracted query. It is skipped
    # only when directly followed by a line break, so fenced content that
    # merely begins with the word "cypher" on the same line stays untouched.
    pattern = r"```(?:cypher[ \t]*\r?\n)?(.*?)```"
    # DOTALL lets the body span multiple lines; IGNORECASE only affects the tag.
    match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
    return match.group(1) if match else text
class FalkorDBQAChain(Chain):
    """Answer questions about a FalkorDB graph through LLM-generated Cypher."""

    graph: FalkorDBGraph = Field(exclude=True)
    cypher_generation_chain: LLMChain
    qa_chain: LLMChain
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:
    top_k: int = 10
    """Upper bound on how many query results are kept."""
    return_intermediate_steps: bool = False
    """If set, the output also carries the intermediate steps."""
    return_direct: bool = False
    """If set, the graph query result is returned without the QA step."""

    @property
    def input_keys(self) -> List[str]:
        """Names of the expected inputs.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Names of the produced outputs.

        :meta private:
        """
        return [self.output_key]

    @property
    def _chain_type(self) -> str:
        return "graph_cypher_chain"

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        *,
        qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT,
        cypher_prompt: BasePromptTemplate = CYPHER_GENERATION_PROMPT,
        **kwargs: Any,
    ) -> FalkorDBQAChain:
        """Build the chain from a single LLM used for both sub-chains."""
        return cls(
            qa_chain=LLMChain(llm=llm, prompt=qa_prompt),
            cypher_generation_chain=LLMChain(llm=llm, prompt=cypher_prompt),
            **kwargs,
        )

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Generate a Cypher statement, run it on the graph, then answer."""
        manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        child_callbacks = manager.get_child()
        user_question = inputs[self.input_key]

        # Step 1: ask the LLM for a query, unwrapping it if it is fenced.
        llm_answer = self.cypher_generation_chain.run(
            {"question": user_question, "schema": self.graph.schema},
            callbacks=child_callbacks,
        )
        cypher = extract_cypher(llm_answer)

        manager.on_text("Generated Cypher:", end="\n", verbose=self.verbose)
        manager.on_text(cypher, color="green", end="\n", verbose=self.verbose)
        steps: List = [{"query": cypher}]

        # Step 2: run the query and keep at most ``top_k`` results.
        rows = self.graph.query(cypher)[: self.top_k]

        # Step 3: either hand back the raw rows or let the QA chain phrase
        # an answer from them.
        if not self.return_direct:
            manager.on_text("Full Context:", end="\n", verbose=self.verbose)
            manager.on_text(
                str(rows), color="green", end="\n", verbose=self.verbose
            )
            steps.append({"context": rows})

            qa_output = self.qa_chain(
                {"question": user_question, "context": rows},
                callbacks=child_callbacks,
            )
            answer = qa_output[self.qa_chain.output_key]
        else:
            answer = rows

        output: Dict[str, Any] = {self.output_key: answer}
        if self.return_intermediate_steps:
            output[INTERMEDIATE_STEPS_KEY] = steps
        return output

View File

@ -1,6 +1,7 @@
"""**Graphs** provide a natural language interface to graph databases."""
from langchain.graphs.arangodb_graph import ArangoGraph
from langchain.graphs.falkordb_graph import FalkorDBGraph
from langchain.graphs.hugegraph import HugeGraph
from langchain.graphs.kuzu_graph import KuzuGraph
from langchain.graphs.memgraph_graph import MemgraphGraph
@ -20,4 +21,5 @@ __all__ = [
"HugeGraph",
"RdfGraph",
"ArangoGraph",
"FalkorDBGraph",
]

View File

@ -0,0 +1,67 @@
from typing import Any, Dict, List, Optional
# Cypher: for every node label, collect the distinct property keys found on
# nodes carrying that label.
node_properties_query = """
MATCH (n)
UNWIND labels(n) as l
UNWIND keys(n) as p
RETURN {label:l, properties: collect(distinct p)} AS output
"""
# Cypher: for every relationship type, collect the distinct property keys
# found on relationships of that type.
rel_properties_query = """
MATCH ()-[r]->()
UNWIND keys(r) as p
RETURN {type:type(r), properties: collect(distinct p)} AS output
"""
# Cypher: list the distinct "(:Src)-[:TYPE]->(:Dst)" patterns in the graph,
# using only the first label of the source and destination nodes.
rel_query = """
MATCH (n)-[r]->(m)
WITH labels(n)[0] AS src, labels(m)[0] AS dst, type(r) AS type
RETURN DISTINCT "(:" + src + ")-[:" + type + "]->(:" + dst + ")" AS output
"""
class FalkorDBGraph:
    """FalkorDB wrapper for graph operations."""

    def __init__(
        self, database: str, host: str = "localhost", port: int = 6379
    ) -> None:
        """Create a new FalkorDB graph wrapper instance.

        Args:
            database: Name of the graph, passed to the redis ``Graph`` object.
            host: Host passed to ``redis.Redis``.
            port: Port passed to ``redis.Redis``.

        Raises:
            ImportError: If the ``redis`` package cannot be imported.
            ValueError: If the initial schema refresh fails.
        """
        try:
            import redis
            from redis.commands.graph import Graph
        except ImportError as e:
            raise ImportError(
                "Could not import redis python package. "
                "Please install it with `pip install redis`."
            ) from e

        self._driver = redis.Redis(host=host, port=port)
        self._graph = Graph(self._driver, database)

        try:
            self.refresh_schema()
        except Exception as e:
            raise ValueError(f"Could not refresh schema. Error: {e}") from e

    @property
    def get_schema(self) -> str:
        """Returns the schema of the FalkorDB database"""
        return self.schema

    def refresh_schema(self) -> None:
        """Refreshes the schema of the FalkorDB database.

        Runs the three module-level schema queries against the graph and
        stores their results in ``self.schema``.

        Raises:
            ValueError: If any of the schema queries fails (see ``query``).
        """
        # The schema must describe the data, so the queries are executed and
        # their *results* are embedded. Interpolating the query text itself
        # would hand the LLM no real labels or properties to work with.
        node_properties = self.query(node_properties_query)
        rel_properties = self.query(rel_properties_query)
        relationships = self.query(rel_query)

        self.schema = (
            f"Node properties: {node_properties}\n"
            f"Relationships properties: {rel_properties}\n"
            f"Relationships: {relationships}\n"
        )

    def query(
        self, query: str, params: Optional[dict] = None
    ) -> List[Dict[str, Any]]:
        """Query FalkorDB database.

        Args:
            query: Cypher statement to execute.
            params: Optional query parameters; an empty dict is used when
                omitted.

        Returns:
            The ``result_set`` of the object returned by the underlying graph
            client. NOTE(review): its exact row shape is decided by that
            client; confirm it matches the annotated return type.

        Raises:
            ValueError: If the underlying client raises for any reason.
        """
        # A fresh dict per call avoids sharing one mutable default between
        # invocations while still passing ``{}`` to the client as before.
        query_params = {} if params is None else params
        try:
            data = self._graph.query(query, query_params)
            return data.result_set
        except Exception as e:
            raise ValueError(
                "Generated Cypher Statement is not valid\n" f"{e}"
            ) from e

View File

@ -0,0 +1,34 @@
import unittest
from typing import Any
from unittest.mock import MagicMock, patch
from langchain.graphs import FalkorDBGraph
class TestFalkorDB(unittest.TestCase):
    """Checks for ``FalkorDBGraph`` run against a patched ``redis.Redis``."""

    def setUp(self) -> None:
        # Connection settings shared by all test cases.
        self.host, self.port = "localhost", 6379
        self.graph = "test_falkordb"

    def _build_graph(self) -> FalkorDBGraph:
        """Construct the wrapper under test from the shared settings."""
        return FalkorDBGraph(database=self.graph, host=self.host, port=self.port)

    @patch("redis.Redis")
    def test_init(self, mock_client: Any) -> None:
        """Construction succeeds when the redis client is mocked."""
        mock_client.return_value = MagicMock()
        self._build_graph()

    @patch("redis.Redis")
    def test_execute(self, mock_client: Any) -> None:
        """A query returns whatever the mocked client produces."""
        mock_client.return_value = MagicMock()
        wrapper = self._build_graph()
        outcome = wrapper.query("RETURN 1")
        self.assertIsInstance(outcome, MagicMock)

    @patch("redis.Redis")
    def test_refresh_schema(self, mock_client: Any) -> None:
        """Refreshing leaves a non-empty schema on the wrapper."""
        mock_client.return_value = MagicMock()
        wrapper = self._build_graph()
        wrapper.refresh_schema()
        self.assertNotEqual(wrapper.get_schema, "")