community[minor], langchain[minor], docs: Gremlin Graph Store and QA Chain (#17683)

- **Description:** 
New feature: Gremlin graph-store and QA chain (including docs).
Compatible with Azure CosmosDB.
  - **Dependencies:** 
  no changes
pull/18386/head
Petteri Johansson 4 months ago committed by GitHub
parent a5ccf5d33c
commit 6c1989d292
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -0,0 +1,239 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "c94240f5",
"metadata": {},
"source": [
"# Gremlin (with CosmosDB) QA chain\n",
"\n",
"This notebook shows how to use LLMs to provide a natural language interface to a graph database you can query with the Gremlin query language."
]
},
{
"cell_type": "markdown",
"id": "dbc0ee68",
"metadata": {},
"source": [
"You will need an Azure CosmosDB Graph database instance. One option is to create a [free CosmosDB Graph database instance in Azure](https://learn.microsoft.com/en-us/azure/cosmos-db/free-tier).\n",
"\n",
"When you create your Cosmos DB account and Graph, use `/type` as the partition key."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "62812aad",
"metadata": {},
"outputs": [],
"source": [
"import nest_asyncio\n",
"from langchain.chains.graph_qa import GremlinQAChain\n",
"from langchain.schema import Document\n",
"from langchain_community.graphs import GremlinGraph\n",
"from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship\n",
"from langchain_openai import AzureChatOpenAI"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0928915d",
"metadata": {},
"outputs": [],
"source": [
"cosmosdb_name = \"mycosmosdb\"\n",
"cosmosdb_db_id = \"graphtesting\"\n",
"cosmosdb_db_graph_id = \"mygraph\"\n",
"cosmosdb_access_Key = \"longstring==\"\n",
"\n",
"graph = GremlinGraph(\n",
" url=f\"wss://{cosmosdb_name}.gremlin.cosmos.azure.com:443/\",\n",
" username=f\"/dbs/{cosmosdb_db_id}/colls/{cosmosdb_db_graph_id}\",\n",
" password=cosmosdb_access_Key,\n",
")"
]
},
{
"cell_type": "markdown",
"id": "995ea9b9",
"metadata": {},
"source": [
"## Seeding the database\n",
"\n",
"Assuming your database is empty, you can populate it using `GraphDocument` objects.\n",
"\n",
"For Gremlin, always add a property called 'label' to each Node.\n",
"If no label is set, Node.type is used as the label.\n",
"For Cosmos DB, using natural IDs makes sense, as they are visible in the graph explorer."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fedd26b9",
"metadata": {},
"outputs": [],
"source": [
"source_doc = Document(\n",
" page_content=\"Matrix is a movie where Keanu Reeves, Laurence Fishburne and Carrie-Anne Moss acted.\"\n",
")\n",
"movie = Node(id=\"The Matrix\", properties={\"label\": \"movie\", \"title\": \"The Matrix\"})\n",
"actor1 = Node(id=\"Keanu Reeves\", properties={\"label\": \"actor\", \"name\": \"Keanu Reeves\"})\n",
"actor2 = Node(\n",
" id=\"Laurence Fishburne\", properties={\"label\": \"actor\", \"name\": \"Laurence Fishburne\"}\n",
")\n",
"actor3 = Node(\n",
" id=\"Carrie-Anne Moss\", properties={\"label\": \"actor\", \"name\": \"Carrie-Anne Moss\"}\n",
")\n",
"rel1 = Relationship(\n",
" id=5, type=\"ActedIn\", source=actor1, target=movie, properties={\"label\": \"ActedIn\"}\n",
")\n",
"rel2 = Relationship(\n",
" id=6, type=\"ActedIn\", source=actor2, target=movie, properties={\"label\": \"ActedIn\"}\n",
")\n",
"rel3 = Relationship(\n",
" id=7, type=\"ActedIn\", source=actor3, target=movie, properties={\"label\": \"ActedIn\"}\n",
")\n",
"rel4 = Relationship(\n",
" id=8,\n",
" type=\"Starring\",\n",
" source=movie,\n",
" target=actor1,\n",
" properties={\"label\": \"Starring\"},\n",
")\n",
"rel5 = Relationship(\n",
" id=9,\n",
" type=\"Starring\",\n",
" source=movie,\n",
" target=actor2,\n",
" properties={\"label\": \"Starring\"},\n",
")\n",
"rel6 = Relationship(\n",
" id=10,\n",
" type=\"Starring\",\n",
" source=movie,\n",
" target=actor3,\n",
" properties={\"label\": \"Starring\"},\n",
")\n",
"graph_doc = GraphDocument(\n",
" nodes=[movie, actor1, actor2, actor3],\n",
" relationships=[rel1, rel2, rel3, rel4, rel5, rel6],\n",
" source=source_doc,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d18f77a3",
"metadata": {},
"outputs": [],
"source": [
"# The underlying gremlinpython client has a problem when running in a notebook.\n",
"# The following line is a workaround for it.\n",
"nest_asyncio.apply()\n",
"\n",
"# Add the document to the CosmosDB graph.\n",
"graph.add_graph_documents([graph_doc])"
]
},
{
"cell_type": "markdown",
"id": "58c1a8ea",
"metadata": {},
"source": [
"## Refresh graph schema information\n",
"If the schema of the database changes (after updates), you can refresh the schema information.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4e3de44f",
"metadata": {},
"outputs": [],
"source": [
"graph.refresh_schema()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1fe76ccd",
"metadata": {},
"outputs": [],
"source": [
"print(graph.schema)"
]
},
{
"cell_type": "markdown",
"id": "68a3c677",
"metadata": {},
"source": [
"## Querying the graph\n",
"\n",
"We can now use the Gremlin QA chain to ask questions of the graph."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7476ce98",
"metadata": {},
"outputs": [],
"source": [
"chain = GremlinQAChain.from_llm(\n",
" AzureChatOpenAI(\n",
" temperature=0,\n",
" azure_deployment=\"gpt-4-turbo\",\n",
" ),\n",
" graph=graph,\n",
" verbose=True,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ef8ee27b",
"metadata": {},
"outputs": [],
"source": [
"chain.invoke(\"Who played in The Matrix?\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "47c64027-cf42-493a-9c76-2d10ba753728",
"metadata": {},
"outputs": [],
"source": [
"chain.invoke(\"How many people played in The Matrix?\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

@ -2,6 +2,7 @@
from langchain_community.graphs.arangodb_graph import ArangoGraph
from langchain_community.graphs.falkordb_graph import FalkorDBGraph
from langchain_community.graphs.gremlin_graph import GremlinGraph
from langchain_community.graphs.hugegraph import HugeGraph
from langchain_community.graphs.kuzu_graph import KuzuGraph
from langchain_community.graphs.memgraph_graph import MemgraphGraph
@ -28,4 +29,5 @@ __all__ = [
"FalkorDBGraph",
"TigerGraph",
"OntotextGraphDBGraph",
"GremlinGraph",
]

@ -0,0 +1,207 @@
import hashlib
import sys
from typing import Any, Dict, List, Optional, Union
from langchain_core.utils import get_from_env
from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship
from langchain_community.graphs.graph_store import GraphStore


class GremlinGraph(GraphStore):
    """Gremlin wrapper for graph operations.

    Parameters:
    url (Optional[str]): The URL of the Gremlin database server or env GREMLIN_URI
    username (Optional[str]): The collection-identifier like '/dbs/database/colls/graph'
                               or env GREMLIN_USERNAME if none provided
    password (Optional[str]): The connection-key for database authentication
                              or env GREMLIN_PASSWORD if none provided
    traversal_source (str): The traversal source to use for queries. Defaults to 'g'.
    message_serializer (Optional[Any]): The message serializer to use for requests.
                                        Defaults to serializer.GraphSONSerializersV2d0()

    *Security note*: Make sure that the database connection uses credentials
        that are narrowly-scoped to only include necessary permissions.
        Failure to do so may result in data corruption or loss, since the calling
        code may attempt commands that would result in deletion, mutation
        of data if appropriately prompted or reading sensitive data if such
        data is present in the database.
        The best way to guard against such negative outcomes is to (as appropriate)
        limit the permissions granted to the credentials used with this tool.

        See https://python.langchain.com/docs/security for more information.

    *Implementation details*:
        The Gremlin queries are designed to work with Azure CosmosDB limitations
    """

    @property
    def get_structured_schema(self) -> Dict[str, Any]:
        return self.structured_schema

    def __init__(
        self,
        url: Optional[str] = None,
        username: Optional[str] = None,
        password: Optional[str] = None,
        traversal_source: str = "g",
        message_serializer: Optional[Any] = None,
    ) -> None:
        """Create a new Gremlin graph wrapper instance."""
        try:
            import asyncio

            from gremlin_python.driver import client, serializer

            if sys.platform == "win32":
                asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
        except ImportError:
            raise ImportError(
                "Please install gremlin-python first: `pip3 install gremlinpython`"
            )

        self.client = client.Client(
            url=get_from_env("url", "GREMLIN_URI", url),
            traversal_source=traversal_source,
            username=get_from_env("username", "GREMLIN_USERNAME", username),
            password=get_from_env("password", "GREMLIN_PASSWORD", password),
            message_serializer=message_serializer
            if message_serializer
            else serializer.GraphSONSerializersV2d0(),
        )
        self.schema: str = ""
        # Filled in by refresh_schema(); initialized here so that reading
        # get_structured_schema before a refresh does not raise AttributeError.
        self.structured_schema: Dict[str, Any] = {}

    @property
    def get_schema(self) -> str:
        """Returns the schema of the Gremlin database"""
        if len(self.schema) == 0:
            self.refresh_schema()
        return self.schema

    def refresh_schema(self) -> None:
        """
        Refreshes the Gremlin graph schema information.
        """
        vertex_schema = self.client.submit("g.V().label().dedup()").all().result()
        edge_schema = self.client.submit("g.E().label().dedup()").all().result()
        vertex_properties = (
            self.client.submit(
                "g.V().group().by(label).by(properties().label().dedup().fold())"
            )
            .all()
            .result()[0]
        )

        self.structured_schema = {
            "vertex_labels": vertex_schema,
            "edge_labels": edge_schema,
            "vertice_props": vertex_properties,
        }

        self.schema = "\n".join(
            [
                "Vertex labels are the following:",
                ",".join(vertex_schema),
                "Edge labels are the following:",
                ",".join(edge_schema),
                f"Vertices have the following properties:\n{vertex_properties}",
            ]
        )

    def query(self, query: str, params: dict = {}) -> List[Dict[str, Any]]:
        q = self.client.submit(query)
        return q.all().result()

    def add_graph_documents(
        self, graph_documents: List[GraphDocument], include_source: bool = False
    ) -> None:
        """
        Take GraphDocument as input and use it to construct a graph.
        """
        node_cache: Dict[Union[str, int], Node] = {}
        for document in graph_documents:
            if include_source:
                # Create document vertex
                doc_props = {
                    "page_content": document.source.page_content,
                    "metadata": document.source.metadata,
                }
                doc_id = hashlib.md5(document.source.page_content.encode()).hexdigest()
                doc_node = self.add_node(
                    Node(id=doc_id, type="Document", properties=doc_props), node_cache
                )

            # Import nodes to vertices
            for n in document.nodes:
                # Share the per-call cache instead of add_node's default argument.
                node = self.add_node(n, node_cache)
                if include_source:
                    # Add Edge to document for each node
                    self.add_edge(
                        Relationship(
                            type="contains information about",
                            source=doc_node,
                            target=node,
                            properties={},
                        )
                    )
                    self.add_edge(
                        Relationship(
                            type="is extracted from",
                            source=node,
                            target=doc_node,
                            properties={},
                        )
                    )

            # Edges
            for el in document.relationships:
                # Find or create the source vertex
                self.add_node(el.source, node_cache)
                # Find or create the target vertex
                self.add_node(el.target, node_cache)
                # Find or create the edge
                self.add_edge(el)

    def build_vertex_query(self, node: Node) -> str:
        base_query = (
            f"g.V().has('id','{node.id}').fold()"
            + f".coalesce(unfold(),addV('{node.type}')"
            + f".property('id','{node.id}')"
            + f".property('type','{node.type}')"
        )
        for key, value in node.properties.items():
            base_query += f".property('{key}', '{value}')"

        return base_query + ")"

    def build_edge_query(self, relationship: Relationship) -> str:
        source_query = f".has('id','{relationship.source.id}')"
        target_query = f".has('id','{relationship.target.id}')"

        # The opening delimiter is a plain triple quote; a fourth quote would
        # put a stray '"' at the start of the query text.
        base_query = f"""g.V(){source_query}.as('a')
            .V(){target_query}.as('b')
            .choose(
                __.inE('{relationship.type}').where(outV().as('a')),
                __.identity(),
                __.addE('{relationship.type}').from('a').to('b')
            )
            """.replace("\n", "").replace("\t", "")
        for key, value in relationship.properties.items():
            base_query += f".property('{key}', '{value}')"

        return base_query
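
    def add_node(self, node: Node, node_cache: Optional[dict] = None) -> Node:
        # Use a fresh cache unless the caller supplies one; a mutable default
        # argument would be shared by every call in the process.
        if node_cache is None:
            node_cache = {}
        # If the properties do not include a label, use the type as the label.
        if "label" not in node.properties:
            node.properties["label"] = node.type
        if node.id in node_cache:
            return node_cache[node.id]
        else:
            query = self.build_vertex_query(node)
            _ = self.client.submit(query).all().result()[0]
            node_cache[node.id] = node
            return node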

    def add_edge(self, relationship: Relationship) -> Any:
        query = self.build_edge_query(relationship)
        return self.client.submit(query).all().result()
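
`build_vertex_query` above interpolates ids and property values directly into the query text, so a value containing a single quote (for example `Peter O'Toole`) would produce a malformed query. A minimal sketch of one alternative, passing values as named bindings, is shown here. The function name `build_vertex_query_with_bindings` and the binding names `vid`, `vtype`, and `p0`, `p1`, ... are hypothetical and not part of this PR; whether a given server, Azure CosmosDB included, accepts bindings in every position used here is an assumption that has not been verified.

```python
from typing import Any, Dict, Tuple


def build_vertex_query_with_bindings(
    node_id: str, node_type: str, properties: Dict[str, Any]
) -> Tuple[str, Dict[str, Any]]:
    """Hypothetical variant of build_vertex_query that binds values by name."""
    # Values travel in `bindings`, so quotes inside them never reach the query text.
    bindings: Dict[str, Any] = {"vid": node_id, "vtype": node_type}
    query = (
        "g.V().has('id', vid).fold()"
        ".coalesce(unfold(), addV(vtype)"
        ".property('id', vid)"
        ".property('type', vtype)"
    )
    for index, (key, value) in enumerate(properties.items()):
        name = f"p{index}"
        bindings[name] = value
        # Property keys are still interpolated, so they should come from trusted code.
        query += f".property('{key}', {name})"
    return query + ")", bindings


query, bindings = build_vertex_query_with_bindings(
    "Peter O'Toole", "actor", {"name": "Peter O'Toole"}
)
```

The resulting pair could then be handed to gremlinpython as `client.submit(query, bindings)`, since `Client.submit` takes an optional bindings mapping as its second argument.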

@ -14,6 +14,7 @@ EXPECTED_ALL = [
"FalkorDBGraph",
"TigerGraph",
"OntotextGraphDBGraph",
"GremlinGraph",
]

@ -0,0 +1,221 @@
"""Question answering over a graph."""
from __future__ import annotations
from typing import Any, Dict, List, Optional
from langchain_community.graphs import GremlinGraph
from langchain_core.callbacks.manager import CallbackManager, CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.pydantic_v1 import Field
from langchain.chains.base import Chain
from langchain.chains.graph_qa.prompts import (
CYPHER_QA_PROMPT,
GRAPHDB_SPARQL_FIX_TEMPLATE,
GREMLIN_GENERATION_PROMPT,
)
from langchain.chains.llm import LLMChain
INTERMEDIATE_STEPS_KEY = "intermediate_steps"


def extract_gremlin(text: str) -> str:
    """Extract Gremlin code from a text.

    Args:
        text: Text to extract Gremlin code from.

    Returns:
        Gremlin code extracted from the text.
    """
    text = text.replace("`", "")
    if text.startswith("gremlin"):
        text = text[len("gremlin") :]
    return text.replace("\n", "")


class GremlinQAChain(Chain):
    """Chain for question-answering against a graph by generating gremlin statements.

    *Security note*: Make sure that the database connection uses credentials
        that are narrowly-scoped to only include necessary permissions.
        Failure to do so may result in data corruption or loss, since the calling
        code may attempt commands that would result in deletion, mutation
        of data if appropriately prompted or reading sensitive data if such
        data is present in the database.
        The best way to guard against such negative outcomes is to (as appropriate)
        limit the permissions granted to the credentials used with this tool.

        See https://python.langchain.com/docs/security for more information.
    """

    graph: GremlinGraph = Field(exclude=True)
    gremlin_generation_chain: LLMChain
    qa_chain: LLMChain
    gremlin_fix_chain: LLMChain
    max_fix_retries: int = 3
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:
    top_k: int = 100
    return_direct: bool = False
    return_intermediate_steps: bool = False

    @property
    def input_keys(self) -> List[str]:
        """Input keys.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Output keys.

        :meta private:
        """
        _output_keys = [self.output_key]
        return _output_keys

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        *,
        gremlin_fix_prompt: BasePromptTemplate = PromptTemplate(
            input_variables=["error_message", "generated_sparql", "schema"],
            template=GRAPHDB_SPARQL_FIX_TEMPLATE.replace("SPARQL", "Gremlin").replace(
                "in Turtle format", ""
            ),
        ),
        qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT,
        gremlin_prompt: BasePromptTemplate = GREMLIN_GENERATION_PROMPT,
        **kwargs: Any,
    ) -> GremlinQAChain:
        """Initialize from LLM."""
        qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
        gremlin_generation_chain = LLMChain(llm=llm, prompt=gremlin_prompt)
        gremlin_fix_chain = LLMChain(llm=llm, prompt=gremlin_fix_prompt)
        return cls(
            qa_chain=qa_chain,
            gremlin_generation_chain=gremlin_generation_chain,
            gremlin_fix_chain=gremlin_fix_chain,
            **kwargs,
        )

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Generate gremlin statement, use it to look up in db and answer question."""
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        callbacks = _run_manager.get_child()
        question = inputs[self.input_key]

        intermediate_steps: List = []

        chain_response = self.gremlin_generation_chain.invoke(
            {"question": question, "schema": self.graph.get_schema}, callbacks=callbacks
        )

        generated_gremlin = extract_gremlin(
            chain_response[self.gremlin_generation_chain.output_key]
        )

        _run_manager.on_text("Generated gremlin:", end="\n", verbose=self.verbose)
        _run_manager.on_text(
            generated_gremlin, color="green", end="\n", verbose=self.verbose
        )

        intermediate_steps.append({"query": generated_gremlin})

        if generated_gremlin:
            context = self.execute_with_retry(
                _run_manager, callbacks, generated_gremlin
            )[: self.top_k]
        else:
            context = []

        if self.return_direct:
            final_result = context
        else:
            _run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
            _run_manager.on_text(
                str(context), color="green", end="\n", verbose=self.verbose
            )

            intermediate_steps.append({"context": context})

            result = self.qa_chain.invoke(
                {"question": question, "context": context},
                callbacks=callbacks,
            )
            final_result = result[self.qa_chain.output_key]

        chain_result: Dict[str, Any] = {self.output_key: final_result}
        if self.return_intermediate_steps:
            chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps

        return chain_result

    def execute_query(self, query: str) -> List[Any]:
        try:
            return self.graph.query(query)
        except Exception as e:
            if hasattr(e, "status_message"):
                raise ValueError(e.status_message)
            else:
                raise ValueError(str(e))

    def execute_with_retry(
        self,
        _run_manager: CallbackManagerForChainRun,
        callbacks: CallbackManager,
        generated_gremlin: str,
    ) -> List[Any]:
        try:
            return self.execute_query(generated_gremlin)
        except Exception as e:
            retries = 0
            error_message = str(e)
            self.log_invalid_query(_run_manager, generated_gremlin, error_message)

            while retries < self.max_fix_retries:
                try:
                    fix_chain_result = self.gremlin_fix_chain.invoke(
                        {
                            "error_message": error_message,
                            # we are borrowing template from sparql
                            "generated_sparql": generated_gremlin,
                            # The chain has no `schema` attribute; read it from the graph.
                            "schema": self.graph.get_schema,
                        },
                        callbacks=callbacks,
                    )
                    # Carry the latest attempt forward so that a failed fix is
                    # logged and re-submitted together with its own error.
                    generated_gremlin = fix_chain_result[
                        self.gremlin_fix_chain.output_key
                    ]
                    return self.execute_query(generated_gremlin)
                except Exception as e:
                    retries += 1
                    error_message = str(e)
                    self.log_invalid_query(
                        _run_manager, generated_gremlin, error_message
                    )

        raise ValueError("The generated Gremlin query is invalid.")

    def log_invalid_query(
        self,
        _run_manager: CallbackManagerForChainRun,
        generated_query: str,
        error_message: str,
    ) -> None:
        _run_manager.on_text("Invalid Gremlin query: ", end="\n", verbose=self.verbose)
        _run_manager.on_text(
            generated_query, color="red", end="\n", verbose=self.verbose
        )
        _run_manager.on_text(
            "Gremlin Query Parse Error: ", end="\n", verbose=self.verbose
        )
        _run_manager.on_text(
            error_message, color="red", end="\n\n", verbose=self.verbose
        )
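
To make the behaviour of the `extract_gremlin` helper in this file concrete, here is a small standalone sketch. The function body is copied from the diff so the example runs without LangChain installed; the two LLM replies are hypothetical inputs made up for illustration.

```python
def extract_gremlin(text: str) -> str:
    """Copy of the helper from the diff, repeated so this example is standalone."""
    text = text.replace("`", "")
    if text.startswith("gremlin"):
        text = text[len("gremlin") :]
    return text.replace("\n", "")


fence = "`" * 3  # built programmatically to keep this example easy to embed

# Hypothetical reply consisting only of a fenced block tagged "gremlin".
fenced_reply = f"{fence}gremlin\ng.V().hasLabel('actor').values('name')\n{fence}"
cleaned = extract_gremlin(fenced_reply)

# Hypothetical reply with prose before the fence: the "gremlin" tag is no longer
# at the start of the text, so it stays in the result.
chatty_reply = f"Here is the query:\n{fence}gremlin\ng.V()\n{fence}"
not_cleaned = extract_gremlin(chatty_reply)
```

The first reply reduces to a bare traversal, while the second keeps both the prose and the language tag. This suggests the helper relies on the generation prompt producing a query with no surrounding text.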