mirror of https://github.com/hwchase17/langchain
community[minor], langchain[minor], docs: Gremlin Graph Store and QA Chain (#17683)
- **Description:** New feature: Gremlin graph-store and QA chain (including docs). Compatible with Azure CosmosDB. - **Dependencies:** no changes (pull/18386/head)
parent
a5ccf5d33c
commit
6c1989d292
@ -0,0 +1,239 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "c94240f5",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Gremlin (with CosmosDB) QA chain\n",
|
||||
"\n",
|
||||
"This notebook shows how to use LLMs to provide a natural language interface to a graph database you can query with the Gremlin query language."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "dbc0ee68",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"You will need to have an Azure CosmosDB Graph database instance. One option is to create a [free CosmosDB Graph database instance in Azure](https://learn.microsoft.com/en-us/azure/cosmos-db/free-tier). \n",
|
||||
"\n",
|
||||
"When you create your Cosmos DB account and Graph, use /type as partition key."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "62812aad",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import nest_asyncio\n",
|
||||
"from langchain.chains.graph_qa import GremlinQAChain\n",
|
||||
"from langchain.schema import Document\n",
|
||||
"from langchain_community.graphs import GremlinGraph\n",
|
||||
"from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship\n",
|
||||
"from langchain_openai import AzureChatOpenAI"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "0928915d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"cosmosdb_name = \"mycosmosdb\"\n",
|
||||
"cosmosdb_db_id = \"graphtesting\"\n",
|
||||
"cosmosdb_db_graph_id = \"mygraph\"\n",
|
||||
"cosmosdb_access_Key = \"longstring==\"\n",
|
||||
"\n",
|
||||
"graph = GremlinGraph(\n",
|
||||
"    url=f\"wss://{cosmosdb_name}.gremlin.cosmos.azure.com:443/\",\n",
|
||||
" username=f\"/dbs/{cosmosdb_db_id}/colls/{cosmosdb_db_graph_id}\",\n",
|
||||
" password=cosmosdb_access_Key,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "995ea9b9",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Seeding the database\n",
|
||||
"\n",
|
||||
"Assuming your database is empty, you can populate it using the GraphDocuments\n",
|
||||
"\n",
|
||||
"For Gremlin, always add a property called 'label' for each Node.\n",
|
||||
"If no label is set, Node.type is used as a label.\n",
|
||||
"For Cosmos, using natural ids makes sense, as they are visible in the graph explorer."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "fedd26b9",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"source_doc = Document(\n",
|
||||
" page_content=\"Matrix is a movie where Keanu Reeves, Laurence Fishburne and Carrie-Anne Moss acted.\"\n",
|
||||
")\n",
|
||||
"movie = Node(id=\"The Matrix\", properties={\"label\": \"movie\", \"title\": \"The Matrix\"})\n",
|
||||
"actor1 = Node(id=\"Keanu Reeves\", properties={\"label\": \"actor\", \"name\": \"Keanu Reeves\"})\n",
|
||||
"actor2 = Node(\n",
|
||||
" id=\"Laurence Fishburne\", properties={\"label\": \"actor\", \"name\": \"Laurence Fishburne\"}\n",
|
||||
")\n",
|
||||
"actor3 = Node(\n",
|
||||
" id=\"Carrie-Anne Moss\", properties={\"label\": \"actor\", \"name\": \"Carrie-Anne Moss\"}\n",
|
||||
")\n",
|
||||
"rel1 = Relationship(\n",
|
||||
" id=5, type=\"ActedIn\", source=actor1, target=movie, properties={\"label\": \"ActedIn\"}\n",
|
||||
")\n",
|
||||
"rel2 = Relationship(\n",
|
||||
" id=6, type=\"ActedIn\", source=actor2, target=movie, properties={\"label\": \"ActedIn\"}\n",
|
||||
")\n",
|
||||
"rel3 = Relationship(\n",
|
||||
" id=7, type=\"ActedIn\", source=actor3, target=movie, properties={\"label\": \"ActedIn\"}\n",
|
||||
")\n",
|
||||
"rel4 = Relationship(\n",
|
||||
" id=8,\n",
|
||||
" type=\"Starring\",\n",
|
||||
" source=movie,\n",
|
||||
" target=actor1,\n",
|
||||
"    properties={\"label\": \"Starring\"},\n",
|
||||
")\n",
|
||||
"rel5 = Relationship(\n",
|
||||
" id=9,\n",
|
||||
" type=\"Starring\",\n",
|
||||
" source=movie,\n",
|
||||
" target=actor2,\n",
|
||||
"    properties={\"label\": \"Starring\"},\n",
|
||||
")\n",
|
||||
"rel6 = Relationship(\n",
|
||||
" id=10,\n",
|
||||
"    type=\"Starring\",\n",
|
||||
" source=movie,\n",
|
||||
" target=actor3,\n",
|
||||
"    properties={\"label\": \"Starring\"},\n",
|
||||
")\n",
|
||||
"graph_doc = GraphDocument(\n",
|
||||
" nodes=[movie, actor1, actor2, actor3],\n",
|
||||
" relationships=[rel1, rel2, rel3, rel4, rel5, rel6],\n",
|
||||
" source=source_doc,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d18f77a3",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# The underlying python-gremlin has a problem when running in notebook\n",
|
||||
"# The following line is a workaround to fix the problem\n",
|
||||
"nest_asyncio.apply()\n",
|
||||
"\n",
|
||||
"# Add the document to the CosmosDB graph.\n",
|
||||
"graph.add_graph_documents([graph_doc])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "58c1a8ea",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Refresh graph schema information\n",
|
||||
"If the schema of database changes (after updates), you can refresh the schema information.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "4e3de44f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"graph.refresh_schema()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "1fe76ccd",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(graph.schema)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "68a3c677",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Querying the graph\n",
|
||||
"\n",
|
||||
"We can now use the gremlin QA chain to ask questions of the graph"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "7476ce98",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"chain = GremlinQAChain.from_llm(\n",
|
||||
" AzureChatOpenAI(\n",
|
||||
" temperature=0,\n",
|
||||
" azure_deployment=\"gpt-4-turbo\",\n",
|
||||
" ),\n",
|
||||
" graph=graph,\n",
|
||||
" verbose=True,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "ef8ee27b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"chain.invoke(\"Who played in The Matrix?\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "47c64027-cf42-493a-9c76-2d10ba753728",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"chain.invoke(\"How many people played in The Matrix?\")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.13"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
@ -0,0 +1,207 @@
|
||||
import hashlib
|
||||
import sys
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from langchain_core.utils import get_from_env
|
||||
|
||||
from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship
|
||||
from langchain_community.graphs.graph_store import GraphStore
|
||||
|
||||
|
||||
class GremlinGraph(GraphStore):
    """Gremlin wrapper for graph operations.

    Parameters:
    url (Optional[str]): The URL of the Gremlin database server or env GREMLIN_URI
    username (Optional[str]): The collection-identifier like '/dbs/database/colls/graph'
        or env GREMLIN_USERNAME if none provided
    password (Optional[str]): The connection-key for database authentication
        or env GREMLIN_PASSWORD if none provided
    traversal_source (str): The traversal source to use for queries. Defaults to 'g'.
    message_serializer (Optional[Any]): The message serializer to use for requests.
        Defaults to serializer.GraphSONSerializersV2d0()

    *Security note*: Make sure that the database connection uses credentials
        that are narrowly-scoped to only include necessary permissions.
        Failure to do so may result in data corruption or loss, since the calling
        code may attempt commands that would result in deletion, mutation
        of data if appropriately prompted or reading sensitive data if such
        data is present in the database.
        The best way to guard against such negative outcomes is to (as appropriate)
        limit the permissions granted to the credentials used with this tool.

        See https://python.langchain.com/docs/security for more information.

    *Implementation details*:
        The Gremlin queries are designed to work with Azure CosmosDB limitations
    """

    @property
    def get_structured_schema(self) -> Dict[str, Any]:
        """Return the structured schema built by :meth:`refresh_schema`.

        Empty until ``refresh_schema`` has been called at least once.
        """
        return self.structured_schema

    def __init__(
        self,
        url: Optional[str] = None,
        username: Optional[str] = None,
        password: Optional[str] = None,
        traversal_source: str = "g",
        message_serializer: Optional[Any] = None,
    ) -> None:
        """Create a new Gremlin graph wrapper instance.

        Raises:
            ValueError: If the ``gremlinpython`` package is not installed.
        """
        try:
            import asyncio

            from gremlin_python.driver import client, serializer

            # gremlin-python drives its websocket via asyncio; the default
            # proactor event loop on Windows is incompatible with it.
            if sys.platform == "win32":
                asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
        except ImportError:
            raise ValueError(
                "Please install gremlin-python first: `pip3 install gremlinpython`"
            )

        # Connection parameters fall back to environment variables when
        # not supplied explicitly (GREMLIN_URI / GREMLIN_USERNAME / GREMLIN_PASSWORD).
        self.client = client.Client(
            url=get_from_env("url", "GREMLIN_URI", url),
            traversal_source=traversal_source,
            username=get_from_env("username", "GREMLIN_USERNAME", username),
            password=get_from_env("password", "GREMLIN_PASSWORD", password),
            message_serializer=message_serializer
            if message_serializer
            else serializer.GraphSONSerializersV2d0(),
        )
        self.schema: str = ""
        # Populated lazily by refresh_schema(); initialized here so
        # get_structured_schema never raises AttributeError before a refresh.
        self.structured_schema: Dict[str, Any] = {}

    @property
    def get_schema(self) -> str:
        """Returns the schema of the Gremlin database, refreshing it if empty."""
        if len(self.schema) == 0:
            self.refresh_schema()
        return self.schema

    def refresh_schema(self) -> None:
        """
        Refreshes the Gremlin graph schema information.

        Queries the database for the distinct vertex labels, edge labels and
        per-label vertex properties, then rebuilds both the structured schema
        dict and the human-readable ``schema`` string used in prompts.
        """
        vertex_schema = self.client.submit("g.V().label().dedup()").all().result()
        edge_schema = self.client.submit("g.E().label().dedup()").all().result()
        vertex_properties = (
            self.client.submit(
                "g.V().group().by(label).by(properties().label().dedup().fold())"
            )
            .all()
            .result()[0]
        )

        # NOTE: the "vertice_props" key is kept as-is for backward
        # compatibility with existing consumers of the structured schema.
        self.structured_schema = {
            "vertex_labels": vertex_schema,
            "edge_labels": edge_schema,
            "vertice_props": vertex_properties,
        }

        self.schema = "\n".join(
            [
                "Vertex labels are the following:",
                ",".join(vertex_schema),
                "Edge labels are the following:",
                ",".join(edge_schema),
                f"Vertices have following properties:\n{vertex_properties}",
            ]
        )

    def query(self, query: str, params: Optional[dict] = None) -> List[Dict[str, Any]]:
        """Submit a raw Gremlin query and return all results.

        Args:
            query: The Gremlin query string to execute.
            params: Accepted for GraphStore interface compatibility; currently
                unused (the query is submitted verbatim).
        """
        q = self.client.submit(query)
        return q.all().result()

    def add_graph_documents(
        self, graph_documents: List[GraphDocument], include_source: bool = False
    ) -> None:
        """
        Take GraphDocument as input and use it to construct a graph.

        Args:
            graph_documents: Documents whose nodes and relationships are
                upserted into the graph.
            include_source: When True, also create a vertex per source document
                (id derived from an MD5 of its page_content) and link it to
                every extracted node in both directions.
        """
        node_cache: Dict[Union[str, int], Node] = {}
        for document in graph_documents:
            if include_source:
                # Create document vertex
                doc_props = {
                    "page_content": document.source.page_content,
                    "metadata": document.source.metadata,
                }
                # MD5 is used only as a stable non-cryptographic identifier.
                doc_id = hashlib.md5(document.source.page_content.encode()).hexdigest()
                doc_node = self.add_node(
                    Node(id=doc_id, type="Document", properties=doc_props), node_cache
                )

            # Import nodes to vertices; sharing node_cache avoids re-upserting
            # a vertex that was already written for this batch.
            for n in document.nodes:
                node = self.add_node(n, node_cache)
                if include_source:
                    # Add Edge to document for each node
                    self.add_edge(
                        Relationship(
                            type="contains information about",
                            source=doc_node,
                            target=node,
                            properties={},
                        )
                    )
                    self.add_edge(
                        Relationship(
                            type="is extracted from",
                            source=node,
                            target=doc_node,
                            properties={},
                        )
                    )

            # Edges
            for el in document.relationships:
                # Find or create the source vertex
                self.add_node(el.source, node_cache)
                # Find or create the target vertex
                self.add_node(el.target, node_cache)
                # Find or create the edge
                self.add_edge(el)

    def build_vertex_query(self, node: Node) -> str:
        """Build an idempotent upsert query for *node* (coalesce on 'id').

        NOTE(review): node ids and property values are interpolated directly
        into the query string; values containing quotes would break or alter
        the query. Callers should only pass trusted data.
        """
        base_query = (
            f"g.V().has('id','{node.id}').fold()"
            + f".coalesce(unfold(),addV('{node.type}')"
            + f".property('id','{node.id}')"
            + f".property('type','{node.type}')"
        )
        for key, value in node.properties.items():
            base_query += f".property('{key}', '{value}')"

        return base_query + ")"

    def build_edge_query(self, relationship: Relationship) -> str:
        """Build an idempotent upsert query for *relationship*.

        The edge is only added if no edge of the same type already exists
        between the two vertices (CosmosDB-compatible choose/coalesce pattern).
        """
        source_query = f".has('id','{relationship.source.id}')"
        target_query = f".has('id','{relationship.target.id}')"

        base_query = f"""g.V(){source_query}.as('a')
        .V(){target_query}.as('b')
        .choose(
            __.inE('{relationship.type}').where(outV().as('a')),
            __.identity(),
            __.addE('{relationship.type}').from('a').to('b')
        )
        """.replace("\n", "").replace("\t", "")
        for key, value in relationship.properties.items():
            base_query += f".property('{key}', '{value}')"

        return base_query

    def add_node(self, node: Node, node_cache: Optional[dict] = None) -> Node:
        """Upsert *node* as a vertex, using *node_cache* to skip duplicates.

        If the node's properties lack a 'label', the node type is used as the
        label (Gremlin requires every vertex to carry one).
        """
        # A None default (instead of a shared mutable {} default) ensures each
        # call without a cache gets a fresh one rather than a process-wide dict.
        node_cache = {} if node_cache is None else node_cache
        # if properties does not have label, add type as label
        if "label" not in node.properties:
            node.properties["label"] = node.type
        if node.id in node_cache:
            return node_cache[node.id]
        else:
            query = self.build_vertex_query(node)
            _ = self.client.submit(query).all().result()[0]
            node_cache[node.id] = node
            return node

    def add_edge(self, relationship: Relationship) -> Any:
        """Upsert *relationship* as an edge and return the query result."""
        query = self.build_edge_query(relationship)
        return self.client.submit(query).all().result()
|
@ -0,0 +1,221 @@
|
||||
"""Question answering over a graph."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_community.graphs import GremlinGraph
|
||||
from langchain_core.callbacks.manager import CallbackManager, CallbackManagerForChainRun
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.prompts import BasePromptTemplate
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.graph_qa.prompts import (
|
||||
CYPHER_QA_PROMPT,
|
||||
GRAPHDB_SPARQL_FIX_TEMPLATE,
|
||||
GREMLIN_GENERATION_PROMPT,
|
||||
)
|
||||
from langchain.chains.llm import LLMChain
|
||||
|
||||
INTERMEDIATE_STEPS_KEY = "intermediate_steps"
|
||||
|
||||
|
||||
def extract_gremlin(text: str) -> str:
    """Extract Gremlin code from a text.

    Args:
        text: Text to extract Gremlin code from.

    Returns:
        Gremlin code extracted from the text.
    """
    # Strip markdown code fences, then a leading "gremlin" language tag,
    # then collapse the query onto a single line.
    cleaned = text.replace("`", "")
    prefix = "gremlin"
    if cleaned.startswith(prefix):
        cleaned = cleaned[len(prefix) :]
    return cleaned.replace("\n", "")
|
||||
|
||||
|
||||
class GremlinQAChain(Chain):
    """Chain for question-answering against a graph by generating gremlin statements.

    *Security note*: Make sure that the database connection uses credentials
        that are narrowly-scoped to only include necessary permissions.
        Failure to do so may result in data corruption or loss, since the calling
        code may attempt commands that would result in deletion, mutation
        of data if appropriately prompted or reading sensitive data if such
        data is present in the database.
        The best way to guard against such negative outcomes is to (as appropriate)
        limit the permissions granted to the credentials used with this tool.

        See https://python.langchain.com/docs/security for more information.
    """

    graph: GremlinGraph = Field(exclude=True)
    gremlin_generation_chain: LLMChain
    qa_chain: LLMChain
    gremlin_fix_chain: LLMChain
    # Maximum number of LLM-assisted repair attempts for an invalid query.
    max_fix_retries: int = 3
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:
    # Maximum number of query results passed on to the QA prompt.
    top_k: int = 100
    # If True, return raw query results instead of an LLM-composed answer.
    return_direct: bool = False
    return_intermediate_steps: bool = False

    @property
    def input_keys(self) -> List[str]:
        """Input keys.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Output keys.

        :meta private:
        """
        _output_keys = [self.output_key]
        return _output_keys

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        *,
        gremlin_fix_prompt: BasePromptTemplate = PromptTemplate(
            input_variables=["error_message", "generated_sparql", "schema"],
            # We reuse the SPARQL fix template, rewording it for Gremlin.
            template=GRAPHDB_SPARQL_FIX_TEMPLATE.replace("SPARQL", "Gremlin").replace(
                "in Turtle format", ""
            ),
        ),
        qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT,
        gremlin_prompt: BasePromptTemplate = GREMLIN_GENERATION_PROMPT,
        **kwargs: Any,
    ) -> GremlinQAChain:
        """Initialize from LLM.

        Args:
            llm: Language model used for generation, fixing, and answering.
            gremlin_fix_prompt: Prompt used to repair an invalid Gremlin query.
            qa_prompt: Prompt used to compose the final natural-language answer.
            gremlin_prompt: Prompt used to translate the question into Gremlin.
            **kwargs: Forwarded to the chain constructor (e.g. ``graph``).
        """
        qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
        gremlin_generation_chain = LLMChain(llm=llm, prompt=gremlin_prompt)
        gremlin_fix_chain = LLMChain(llm=llm, prompt=gremlin_fix_prompt)
        return cls(
            qa_chain=qa_chain,
            gremlin_generation_chain=gremlin_generation_chain,
            gremlin_fix_chain=gremlin_fix_chain,
            **kwargs,
        )

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Generate gremlin statement, use it to look up in db and answer question."""
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        callbacks = _run_manager.get_child()
        question = inputs[self.input_key]

        intermediate_steps: List = []

        # Translate the question into a Gremlin query using the graph schema.
        chain_response = self.gremlin_generation_chain.invoke(
            {"question": question, "schema": self.graph.get_schema}, callbacks=callbacks
        )

        generated_gremlin = extract_gremlin(
            chain_response[self.gremlin_generation_chain.output_key]
        )

        _run_manager.on_text("Generated gremlin:", end="\n", verbose=self.verbose)
        _run_manager.on_text(
            generated_gremlin, color="green", end="\n", verbose=self.verbose
        )

        intermediate_steps.append({"query": generated_gremlin})

        if generated_gremlin:
            # Execute (with LLM-assisted repair on failure) and cap results.
            context = self.execute_with_retry(
                _run_manager, callbacks, generated_gremlin
            )[: self.top_k]
        else:
            context = []

        if self.return_direct:
            final_result = context
        else:
            _run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
            _run_manager.on_text(
                str(context), color="green", end="\n", verbose=self.verbose
            )

            intermediate_steps.append({"context": context})

            result = self.qa_chain.invoke(
                {"question": question, "context": context},
                callbacks=callbacks,
            )
            final_result = result[self.qa_chain.output_key]

        chain_result: Dict[str, Any] = {self.output_key: final_result}
        if self.return_intermediate_steps:
            chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps

        return chain_result

    def execute_query(self, query: str) -> List[Any]:
        """Run *query* against the graph, normalizing driver errors to ValueError."""
        try:
            return self.graph.query(query)
        except Exception as e:
            # Gremlin driver exceptions expose the server response via
            # status_message; surface it when available.
            if hasattr(e, "status_message"):
                raise ValueError(e.status_message) from e
            else:
                raise ValueError(str(e)) from e

    def execute_with_retry(
        self,
        _run_manager: CallbackManagerForChainRun,
        callbacks: CallbackManager,
        generated_gremlin: str,
    ) -> List[Any]:
        """Execute a Gremlin query, asking the LLM to repair it on failure.

        Retries up to ``max_fix_retries`` times, feeding the latest failing
        query and error message back to the fix chain each round.

        Raises:
            ValueError: If no repaired query executes successfully.
        """
        try:
            return self.execute_query(generated_gremlin)
        except Exception as e:
            retries = 0
            error_message = str(e)
            self.log_invalid_query(_run_manager, generated_gremlin, error_message)
            # Track the latest candidate so the except branch below never
            # references an unbound name if the fix chain itself raises.
            fixed_gremlin = generated_gremlin

            while retries < self.max_fix_retries:
                try:
                    fix_chain_result = self.gremlin_fix_chain.invoke(
                        {
                            "error_message": error_message,
                            # we are borrowing template from sparql
                            "generated_sparql": fixed_gremlin,
                            "schema": self.graph.get_schema,
                        },
                        callbacks=callbacks,
                    )
                    # Clean markdown fences from the fix output, mirroring
                    # the generation path.
                    fixed_gremlin = extract_gremlin(
                        fix_chain_result[self.gremlin_fix_chain.output_key]
                    )
                    return self.execute_query(fixed_gremlin)
                except Exception as e:
                    retries += 1
                    error_message = str(e)
                    self.log_invalid_query(_run_manager, fixed_gremlin, error_message)

            raise ValueError("The generated Gremlin query is invalid.")

    def log_invalid_query(
        self,
        _run_manager: CallbackManagerForChainRun,
        generated_query: str,
        error_message: str,
    ) -> None:
        """Report a failing query and its parse error via the run manager."""
        _run_manager.on_text("Invalid Gremlin query: ", end="\n", verbose=self.verbose)
        _run_manager.on_text(
            generated_query, color="red", end="\n", verbose=self.verbose
        )
        _run_manager.on_text(
            "Gremlin Query Parse Error: ", end="\n", verbose=self.verbose
        )
        _run_manager.on_text(
            error_message, color="red", end="\n\n", verbose=self.verbose
        )
|
Loading…
Reference in New Issue