mirror of https://github.com/hwchase17/langchain
community[minor], langchain[minor], docs: Gremlin Graph Store and QA Chain (#17683)
- **Description:** New feature: Gremlin graph-store and QA chain (including docs). Compatible with Azure CosmosDB. - **Dependencies:** no changes
parent
a5ccf5d33c
commit
6c1989d292
@ -0,0 +1,239 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "c94240f5",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# Gremlin (with CosmosDB) QA chain\n",
|
||||||
|
"\n",
|
||||||
|
"This notebook shows how to use LLMs to provide a natural language interface to a graph database you can query with the Gremlin query language."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "dbc0ee68",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"You will need to have a Azure CosmosDB Graph database instance. One option is to create a [free CosmosDB Graph database instance in Azure](https://learn.microsoft.com/en-us/azure/cosmos-db/free-tier). \n",
|
||||||
|
"\n",
|
||||||
|
"When you create your Cosmos DB account and Graph, use /type as partition key."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "62812aad",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import nest_asyncio\n",
|
||||||
|
"from langchain.chains.graph_qa import GremlinQAChain\n",
|
||||||
|
"from langchain.schema import Document\n",
|
||||||
|
"from langchain_community.graphs import GremlinGraph\n",
|
||||||
|
"from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship\n",
|
||||||
|
"from langchain_openai import AzureChatOpenAI"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "0928915d",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"cosmosdb_name = \"mycosmosdb\"\n",
|
||||||
|
"cosmosdb_db_id = \"graphtesting\"\n",
|
||||||
|
"cosmosdb_db_graph_id = \"mygraph\"\n",
|
||||||
|
"cosmosdb_access_Key = \"longstring==\"\n",
|
||||||
|
"\n",
|
||||||
|
"graph = GremlinGraph(\n",
|
||||||
|
" url=f\"=wss://{cosmosdb_name}.gremlin.cosmos.azure.com:443/\",\n",
|
||||||
|
" username=f\"/dbs/{cosmosdb_db_id}/colls/{cosmosdb_db_graph_id}\",\n",
|
||||||
|
" password=cosmosdb_access_Key,\n",
|
||||||
|
")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "995ea9b9",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Seeding the database\n",
|
||||||
|
"\n",
|
||||||
|
"Assuming your database is empty, you can populate it using the GraphDocuments\n",
|
||||||
|
"\n",
|
||||||
|
"For Gremlin, always add property called 'label' for each Node.\n",
|
||||||
|
"If no label is set, Node.type is used as a label.\n",
|
||||||
|
"For cosmos using natural id's make sense, as they are visible in the graph explorer."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "fedd26b9",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"source_doc = Document(\n",
|
||||||
|
" page_content=\"Matrix is a movie where Keanu Reeves, Laurence Fishburne and Carrie-Anne Moss acted.\"\n",
|
||||||
|
")\n",
|
||||||
|
"movie = Node(id=\"The Matrix\", properties={\"label\": \"movie\", \"title\": \"The Matrix\"})\n",
|
||||||
|
"actor1 = Node(id=\"Keanu Reeves\", properties={\"label\": \"actor\", \"name\": \"Keanu Reeves\"})\n",
|
||||||
|
"actor2 = Node(\n",
|
||||||
|
" id=\"Laurence Fishburne\", properties={\"label\": \"actor\", \"name\": \"Laurence Fishburne\"}\n",
|
||||||
|
")\n",
|
||||||
|
"actor3 = Node(\n",
|
||||||
|
" id=\"Carrie-Anne Moss\", properties={\"label\": \"actor\", \"name\": \"Carrie-Anne Moss\"}\n",
|
||||||
|
")\n",
|
||||||
|
"rel1 = Relationship(\n",
|
||||||
|
" id=5, type=\"ActedIn\", source=actor1, target=movie, properties={\"label\": \"ActedIn\"}\n",
|
||||||
|
")\n",
|
||||||
|
"rel2 = Relationship(\n",
|
||||||
|
" id=6, type=\"ActedIn\", source=actor2, target=movie, properties={\"label\": \"ActedIn\"}\n",
|
||||||
|
")\n",
|
||||||
|
"rel3 = Relationship(\n",
|
||||||
|
" id=7, type=\"ActedIn\", source=actor3, target=movie, properties={\"label\": \"ActedIn\"}\n",
|
||||||
|
")\n",
|
||||||
|
"rel4 = Relationship(\n",
|
||||||
|
" id=8,\n",
|
||||||
|
" type=\"Starring\",\n",
|
||||||
|
" source=movie,\n",
|
||||||
|
" target=actor1,\n",
|
||||||
|
" properties={\"label\": \"Strarring\"},\n",
|
||||||
|
")\n",
|
||||||
|
"rel5 = Relationship(\n",
|
||||||
|
" id=9,\n",
|
||||||
|
" type=\"Starring\",\n",
|
||||||
|
" source=movie,\n",
|
||||||
|
" target=actor2,\n",
|
||||||
|
" properties={\"label\": \"Strarring\"},\n",
|
||||||
|
")\n",
|
||||||
|
"rel6 = Relationship(\n",
|
||||||
|
" id=10,\n",
|
||||||
|
" type=\"Straring\",\n",
|
||||||
|
" source=movie,\n",
|
||||||
|
" target=actor3,\n",
|
||||||
|
" properties={\"label\": \"Strarring\"},\n",
|
||||||
|
")\n",
|
||||||
|
"graph_doc = GraphDocument(\n",
|
||||||
|
" nodes=[movie, actor1, actor2, actor3],\n",
|
||||||
|
" relationships=[rel1, rel2, rel3, rel4, rel5, rel6],\n",
|
||||||
|
" source=source_doc,\n",
|
||||||
|
")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "d18f77a3",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# The underlying python-gremlin has a problem when running in notebook\n",
|
||||||
|
"# The following line is a workaround to fix the problem\n",
|
||||||
|
"nest_asyncio.apply()\n",
|
||||||
|
"\n",
|
||||||
|
"# Add the document to the CosmosDB graph.\n",
|
||||||
|
"graph.add_graph_documents([graph_doc])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "58c1a8ea",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Refresh graph schema information\n",
|
||||||
|
"If the schema of database changes (after updates), you can refresh the schema information.\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "4e3de44f",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"graph.refresh_schema()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "1fe76ccd",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"print(graph.schema)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "68a3c677",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Querying the graph\n",
|
||||||
|
"\n",
|
||||||
|
"We can now use the gremlin QA chain to ask question of the graph"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "7476ce98",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"chain = GremlinQAChain.from_llm(\n",
|
||||||
|
" AzureChatOpenAI(\n",
|
||||||
|
" temperature=0,\n",
|
||||||
|
" azure_deployment=\"gpt-4-turbo\",\n",
|
||||||
|
" ),\n",
|
||||||
|
" graph=graph,\n",
|
||||||
|
" verbose=True,\n",
|
||||||
|
")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "ef8ee27b",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"chain.invoke(\"Who played in The Matrix?\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "47c64027-cf42-493a-9c76-2d10ba753728",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"chain.run(\"How many people played in The Matrix?\")"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3 (ipykernel)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.9.13"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
@ -0,0 +1,207 @@
|
|||||||
|
import hashlib
|
||||||
|
import sys
|
||||||
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
|
from langchain_core.utils import get_from_env
|
||||||
|
|
||||||
|
from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship
|
||||||
|
from langchain_community.graphs.graph_store import GraphStore
|
||||||
|
|
||||||
|
|
||||||
|
class GremlinGraph(GraphStore):
    """Gremlin wrapper for graph operations.

    Parameters:
    url (Optional[str]): The URL of the Gremlin database server or env GREMLIN_URI
    username (Optional[str]): The collection-identifier like '/dbs/database/colls/graph'
        or env GREMLIN_USERNAME if none provided
    password (Optional[str]): The connection-key for database authentication
        or env GREMLIN_PASSWORD if none provided
    traversal_source (str): The traversal source to use for queries. Defaults to 'g'.
    message_serializer (Optional[Any]): The message serializer to use for requests.
        Defaults to serializer.GraphSONSerializersV2d0()

    *Security note*: Make sure that the database connection uses credentials
        that are narrowly-scoped to only include necessary permissions.
        Failure to do so may result in data corruption or loss, since the calling
        code may attempt commands that would result in deletion, mutation
        of data if appropriately prompted or reading sensitive data if such
        data is present in the database.
        The best way to guard against such negative outcomes is to (as appropriate)
        limit the permissions granted to the credentials used with this tool.

        See https://python.langchain.com/docs/security for more information.

    *Implementation details*:
        The Gremlin queries are designed to work with Azure CosmosDB limitations
    """

    def __init__(
        self,
        url: Optional[str] = None,
        username: Optional[str] = None,
        password: Optional[str] = None,
        traversal_source: str = "g",
        message_serializer: Optional[Any] = None,
    ) -> None:
        """Create a new Gremlin graph wrapper instance."""
        try:
            import asyncio

            from gremlin_python.driver import client, serializer

            if sys.platform == "win32":
                # The tornado transport used by gremlin-python does not work
                # with the default proactor event loop on Windows.
                asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
        except ImportError:
            raise ValueError(
                "Please install gremlin-python first: `pip3 install gremlinpython`"
            )

        self.client = client.Client(
            url=get_from_env("url", "GREMLIN_URI", url),
            traversal_source=traversal_source,
            username=get_from_env("username", "GREMLIN_USERNAME", username),
            password=get_from_env("password", "GREMLIN_PASSWORD", password),
            message_serializer=message_serializer
            if message_serializer
            else serializer.GraphSONSerializersV2d0(),
        )
        self.schema: str = ""
        # Initialize eagerly so `get_structured_schema` does not raise
        # AttributeError when accessed before the first refresh_schema().
        self.structured_schema: Dict[str, Any] = {}

    @property
    def get_structured_schema(self) -> Dict[str, Any]:
        """Return the cached structured schema (empty before the first refresh)."""
        return self.structured_schema

    @property
    def get_schema(self) -> str:
        """Returns the schema of the Gremlin database, refreshing it lazily."""
        if len(self.schema) == 0:
            self.refresh_schema()
        return self.schema

    def refresh_schema(self) -> None:
        """
        Refreshes the Gremlin graph schema information.
        """
        vertex_schema = self.client.submit("g.V().label().dedup()").all().result()
        edge_schema = self.client.submit("g.E().label().dedup()").all().result()
        vertex_properties = (
            self.client.submit(
                "g.V().group().by(label).by(properties().label().dedup().fold())"
            )
            .all()
            .result()[0]
        )

        self.structured_schema = {
            "vertex_labels": vertex_schema,
            "edge_labels": edge_schema,
            "vertice_props": vertex_properties,
        }

        self.schema = "\n".join(
            [
                "Vertex labels are the following:",
                ",".join(vertex_schema),
                # fixed typo: was "Edge labes"
                "Edge labels are the following:",
                ",".join(edge_schema),
                f"Vertices have following properties:\n{vertex_properties}",
            ]
        )

    def query(self, query: str, params: dict = {}) -> List[Dict[str, Any]]:
        """Execute a Gremlin query and return all results.

        NOTE(review): `params` is accepted for GraphStore interface
        compatibility but is currently ignored by this implementation.
        """
        q = self.client.submit(query)
        return q.all().result()

    def add_graph_documents(
        self, graph_documents: List[GraphDocument], include_source: bool = False
    ) -> None:
        """
        Take GraphDocument as input and use it to construct a graph.

        Args:
            graph_documents: Documents whose nodes and relationships are
                upserted into the graph.
            include_source: When True, also create a vertex per source
                document and link it to every node extracted from it.
        """
        node_cache: Dict[Union[str, int], Node] = {}
        for document in graph_documents:
            if include_source:
                # Create the document vertex, keyed by a stable content hash.
                doc_props = {
                    "page_content": document.source.page_content,
                    "metadata": document.source.metadata,
                }
                doc_id = hashlib.md5(document.source.page_content.encode()).hexdigest()
                doc_node = self.add_node(
                    Node(id=doc_id, type="Document", properties=doc_props), node_cache
                )

            # Import nodes to vertices. Share the per-call cache so vertices
            # are not re-upserted (the original called add_node(n) without it).
            for n in document.nodes:
                node = self.add_node(n, node_cache)
                if include_source:
                    # Link document and node in both directions.
                    self.add_edge(
                        Relationship(
                            type="contains information about",
                            source=doc_node,
                            target=node,
                            properties={},
                        )
                    )
                    self.add_edge(
                        Relationship(
                            type="is extracted from",
                            source=node,
                            target=doc_node,
                            properties={},
                        )
                    )

            # Edges
            for el in document.relationships:
                # Find or create the source vertex
                self.add_node(el.source, node_cache)
                # Find or create the target vertex
                self.add_node(el.target, node_cache)
                # Find or create the edge
                self.add_edge(el)

    def build_vertex_query(self, node: Node) -> str:
        """Build an idempotent get-or-create (upsert) query for a vertex.

        NOTE(review): values are interpolated directly into the query string;
        ids or properties containing single quotes would break the query or
        inject Gremlin.
        """
        base_query = (
            f"g.V().has('id','{node.id}').fold()"
            + f".coalesce(unfold(),addV('{node.type}')"
            + f".property('id','{node.id}')"
            + f".property('type','{node.type}')"
        )
        for key, value in node.properties.items():
            base_query += f".property('{key}', '{value}')"

        return base_query + ")"

    def build_edge_query(self, relationship: Relationship) -> str:
        """Build an idempotent get-or-create (upsert) query for an edge."""
        source_query = f".has('id','{relationship.source.id}')"
        target_query = f".has('id','{relationship.target.id}')"

        # Fix: the original used f""""g.V()... — the fourth quote put a
        # literal '"' at the start of the query, producing invalid Gremlin.
        base_query = f"""g.V(){source_query}.as('a')
        .V(){target_query}.as('b')
        .choose(
            __.inE('{relationship.type}').where(outV().as('a')),
            __.identity(),
            __.addE('{relationship.type}').from('a').to('b')
        )
        """.replace("\n", "").replace("\t", "")
        for key, value in relationship.properties.items():
            base_query += f".property('{key}', '{value}')"

        return base_query

    def add_node(self, node: Node, node_cache: Optional[dict] = None) -> Node:
        """Upsert a single node, using `node_cache` to skip repeat work.

        Fix: the original used a mutable default (`node_cache: dict = {}`),
        silently sharing one cache across all calls and instances.
        """
        cache = {} if node_cache is None else node_cache
        # If the properties do not carry a label, fall back to the node type.
        if "label" not in node.properties:
            node.properties["label"] = node.type
        if node.id in cache:
            return cache[node.id]
        query = self.build_vertex_query(node)
        _ = self.client.submit(query).all().result()[0]
        cache[node.id] = node
        return node

    def add_edge(self, relationship: Relationship) -> Any:
        """Upsert a single edge and return the driver result."""
        query = self.build_edge_query(relationship)
        return self.client.submit(query).all().result()
|
@ -0,0 +1,221 @@
|
|||||||
|
"""Question answering over a graph."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from langchain_community.graphs import GremlinGraph
|
||||||
|
from langchain_core.callbacks.manager import CallbackManager, CallbackManagerForChainRun
|
||||||
|
from langchain_core.language_models import BaseLanguageModel
|
||||||
|
from langchain_core.prompts import BasePromptTemplate
|
||||||
|
from langchain_core.prompts.prompt import PromptTemplate
|
||||||
|
from langchain_core.pydantic_v1 import Field
|
||||||
|
|
||||||
|
from langchain.chains.base import Chain
|
||||||
|
from langchain.chains.graph_qa.prompts import (
|
||||||
|
CYPHER_QA_PROMPT,
|
||||||
|
GRAPHDB_SPARQL_FIX_TEMPLATE,
|
||||||
|
GREMLIN_GENERATION_PROMPT,
|
||||||
|
)
|
||||||
|
from langchain.chains.llm import LLMChain
|
||||||
|
|
||||||
|
INTERMEDIATE_STEPS_KEY = "intermediate_steps"
|
||||||
|
|
||||||
|
|
||||||
|
def extract_gremlin(text: str) -> str:
    """Extract Gremlin code from a text.

    Args:
        text: Text to extract Gremlin code from.

    Returns:
        Gremlin code extracted from the text.
    """
    # Strip markdown code fences, then drop a leading "gremlin" language tag.
    cleaned = text.replace("`", "")
    prefix = "gremlin"
    if cleaned.startswith(prefix):
        cleaned = cleaned[len(prefix) :]
    # Collapse the query onto a single line.
    return cleaned.replace("\n", "")
|
||||||
|
|
||||||
|
|
||||||
|
class GremlinQAChain(Chain):
    """Chain for question-answering against a graph by generating gremlin statements.

    *Security note*: Make sure that the database connection uses credentials
        that are narrowly-scoped to only include necessary permissions.
        Failure to do so may result in data corruption or loss, since the calling
        code may attempt commands that would result in deletion, mutation
        of data if appropriately prompted or reading sensitive data if such
        data is present in the database.
        The best way to guard against such negative outcomes is to (as appropriate)
        limit the permissions granted to the credentials used with this tool.

        See https://python.langchain.com/docs/security for more information.
    """

    graph: GremlinGraph = Field(exclude=True)
    gremlin_generation_chain: LLMChain
    qa_chain: LLMChain
    gremlin_fix_chain: LLMChain
    # How many times to ask the LLM to repair an invalid query before failing.
    max_fix_retries: int = 3
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:
    # Maximum number of query results passed to the QA chain as context.
    top_k: int = 100
    return_direct: bool = False
    return_intermediate_steps: bool = False

    @property
    def input_keys(self) -> List[str]:
        """Input keys.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Output keys.

        :meta private:
        """
        _output_keys = [self.output_key]
        return _output_keys

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        *,
        gremlin_fix_prompt: BasePromptTemplate = PromptTemplate(
            input_variables=["error_message", "generated_sparql", "schema"],
            # Reuse the SPARQL fix template, reworded for Gremlin.
            template=GRAPHDB_SPARQL_FIX_TEMPLATE.replace("SPARQL", "Gremlin").replace(
                "in Turtle format", ""
            ),
        ),
        qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT,
        gremlin_prompt: BasePromptTemplate = GREMLIN_GENERATION_PROMPT,
        **kwargs: Any,
    ) -> GremlinQAChain:
        """Initialize from LLM."""
        qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
        gremlin_generation_chain = LLMChain(llm=llm, prompt=gremlin_prompt)
        gremlin_fix_chain = LLMChain(llm=llm, prompt=gremlin_fix_prompt)
        return cls(
            qa_chain=qa_chain,
            gremlin_generation_chain=gremlin_generation_chain,
            gremlin_fix_chain=gremlin_fix_chain,
            **kwargs,
        )

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Generate gremlin statement, use it to look up in db and answer question."""
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        callbacks = _run_manager.get_child()
        question = inputs[self.input_key]

        intermediate_steps: List = []

        # 1. Ask the LLM to translate the question into Gremlin.
        chain_response = self.gremlin_generation_chain.invoke(
            {"question": question, "schema": self.graph.get_schema}, callbacks=callbacks
        )

        generated_gremlin = extract_gremlin(
            chain_response[self.gremlin_generation_chain.output_key]
        )

        _run_manager.on_text("Generated gremlin:", end="\n", verbose=self.verbose)
        _run_manager.on_text(
            generated_gremlin, color="green", end="\n", verbose=self.verbose
        )

        intermediate_steps.append({"query": generated_gremlin})

        # 2. Run the query (with LLM-assisted retry on failure) and cap the
        #    amount of context handed to the QA step.
        if generated_gremlin:
            context = self.execute_with_retry(
                _run_manager, callbacks, generated_gremlin
            )[: self.top_k]
        else:
            context = []

        # 3. Either return the raw results or answer in natural language.
        if self.return_direct:
            final_result = context
        else:
            _run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
            _run_manager.on_text(
                str(context), color="green", end="\n", verbose=self.verbose
            )

            intermediate_steps.append({"context": context})

            result = self.qa_chain.invoke(
                {"question": question, "context": context},
                callbacks=callbacks,
            )
            final_result = result[self.qa_chain.output_key]

        chain_result: Dict[str, Any] = {self.output_key: final_result}
        if self.return_intermediate_steps:
            chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps

        return chain_result

    def execute_query(self, query: str) -> List[Any]:
        """Run a raw Gremlin query, normalizing driver errors to ValueError."""
        try:
            return self.graph.query(query)
        except Exception as e:
            # Gremlin driver exceptions carry the server message in
            # `status_message`; surface that when available.
            message = getattr(e, "status_message", None) or str(e)
            raise ValueError(message) from e

    def execute_with_retry(
        self,
        _run_manager: CallbackManagerForChainRun,
        callbacks: CallbackManager,
        generated_gremlin: str,
    ) -> List[Any]:
        """Execute the query; on failure, ask the LLM to repair it up to
        ``max_fix_retries`` times before raising ValueError.
        """
        try:
            return self.execute_query(generated_gremlin)
        except Exception as e:
            retries = 0
            error_message = str(e)
            self.log_invalid_query(_run_manager, generated_gremlin, error_message)
            # Start from the failing query so the except-branch below always
            # has a bound name to log, even if the fix chain itself raises.
            fixed_gremlin = generated_gremlin

            while retries < self.max_fix_retries:
                try:
                    fix_chain_result = self.gremlin_fix_chain.invoke(
                        {
                            # Feed the LATEST attempt and error back in, so
                            # each retry actually builds on the previous fix
                            # (the original resubmitted the first query/error
                            # unchanged every iteration).
                            "error_message": error_message,
                            # we are borrowing the template from sparql
                            "generated_sparql": fixed_gremlin,
                            # Fix: was `self.schema`, which does not exist on
                            # this chain (it resolved to pydantic's .schema
                            # classmethod, not the graph schema string).
                            "schema": self.graph.get_schema,
                        },
                        callbacks=callbacks,
                    )
                    fixed_gremlin = fix_chain_result[self.gremlin_fix_chain.output_key]
                    return self.execute_query(fixed_gremlin)
                except Exception as e:
                    retries += 1
                    error_message = str(e)
                    self.log_invalid_query(_run_manager, fixed_gremlin, error_message)

        raise ValueError("The generated Gremlin query is invalid.")

    def log_invalid_query(
        self,
        _run_manager: CallbackManagerForChainRun,
        generated_query: str,
        error_message: str,
    ) -> None:
        """Log a failed query and its parse error through the run manager."""
        _run_manager.on_text("Invalid Gremlin query: ", end="\n", verbose=self.verbose)
        _run_manager.on_text(
            generated_query, color="red", end="\n", verbose=self.verbose
        )
        _run_manager.on_text(
            "Gremlin Query Parse Error: ", end="\n", verbose=self.verbose
        )
        _run_manager.on_text(
            error_message, color="red", end="\n\n", verbose=self.verbose
        )
|
Loading…
Reference in New Issue