From 6c1989d292fa6652f73d9cdf18e3025596c74576 Mon Sep 17 00:00:00 2001 From: Petteri Johansson <191493+piizei@users.noreply.github.com> Date: Fri, 1 Mar 2024 21:21:14 +0100 Subject: [PATCH] community[minor], langchain[minor], docs: Gremlin Graph Store and QA Chain (#17683) - **Description:** New feature: Gremlin graph-store and QA chain (including docs). Compatible with Azure CosmosDB. - **Dependencies:** no changes --- .../graph/graph_gremlin_cosmosdb_qa.ipynb | 239 ++++++++++++++++++ .../langchain_community/graphs/__init__.py | 2 + .../graphs/gremlin_graph.py | 207 +++++++++++++++ .../tests/unit_tests/graphs/test_imports.py | 1 + .../langchain/chains/graph_qa/gremlin.py | 221 ++++++++++++++++ 5 files changed, 670 insertions(+) create mode 100644 docs/docs/use_cases/graph/graph_gremlin_cosmosdb_qa.ipynb create mode 100644 libs/community/langchain_community/graphs/gremlin_graph.py create mode 100644 libs/langchain/langchain/chains/graph_qa/gremlin.py diff --git a/docs/docs/use_cases/graph/graph_gremlin_cosmosdb_qa.ipynb b/docs/docs/use_cases/graph/graph_gremlin_cosmosdb_qa.ipynb new file mode 100644 index 0000000000..3ae5d2f2d8 --- /dev/null +++ b/docs/docs/use_cases/graph/graph_gremlin_cosmosdb_qa.ipynb @@ -0,0 +1,239 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "c94240f5", + "metadata": {}, + "source": [ + "# Gremlin (with CosmosDB) QA chain\n", + "\n", + "This notebook shows how to use LLMs to provide a natural language interface to a graph database you can query with the Gremlin query language." + ] + }, + { + "cell_type": "markdown", + "id": "dbc0ee68", + "metadata": {}, + "source": [ + "You will need to have a Azure CosmosDB Graph database instance. One option is to create a [free CosmosDB Graph database instance in Azure](https://learn.microsoft.com/en-us/azure/cosmos-db/free-tier). \n", + "\n", + "When you create your Cosmos DB account and Graph, use /type as partition key." 
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "62812aad",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import nest_asyncio\n",
+    "from langchain.chains.graph_qa import GremlinQAChain\n",
+    "from langchain.schema import Document\n",
+    "from langchain_community.graphs import GremlinGraph\n",
+    "from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship\n",
+    "from langchain_openai import AzureChatOpenAI"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "0928915d",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "cosmosdb_name = \"mycosmosdb\"\n",
+    "cosmosdb_db_id = \"graphtesting\"\n",
+    "cosmosdb_db_graph_id = \"mygraph\"\n",
+    "cosmosdb_access_key = \"longstring==\"\n",
+    "\n",
+    "graph = GremlinGraph(\n",
+    "    url=f\"wss://{cosmosdb_name}.gremlin.cosmos.azure.com:443/\",\n",
+    "    username=f\"/dbs/{cosmosdb_db_id}/colls/{cosmosdb_db_graph_id}\",\n",
+    "    password=cosmosdb_access_key,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "995ea9b9",
+   "metadata": {},
+   "source": [
+    "## Seeding the database\n",
+    "\n",
+    "Assuming your database is empty, you can populate it using the GraphDocuments\n",
+    "\n",
+    "For Gremlin, always add a property called 'label' for each Node.\n",
+    "If no label is set, Node.type is used as a label.\n",
+    "For Cosmos, using natural IDs makes sense, as they are visible in the graph explorer."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "fedd26b9",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "source_doc = Document(\n",
+    "    page_content=\"Matrix is a movie where Keanu Reeves, Laurence Fishburne and Carrie-Anne Moss acted.\"\n",
+    ")\n",
+    "movie = Node(id=\"The Matrix\", properties={\"label\": \"movie\", \"title\": \"The Matrix\"})\n",
+    "actor1 = Node(id=\"Keanu Reeves\", properties={\"label\": \"actor\", \"name\": \"Keanu Reeves\"})\n",
+    "actor2 = Node(\n",
+    "    id=\"Laurence Fishburne\", properties={\"label\": \"actor\", \"name\": \"Laurence Fishburne\"}\n",
+    ")\n",
+    "actor3 = Node(\n",
+    "    id=\"Carrie-Anne Moss\", properties={\"label\": \"actor\", \"name\": \"Carrie-Anne Moss\"}\n",
+    ")\n",
+    "rel1 = Relationship(\n",
+    "    id=5, type=\"ActedIn\", source=actor1, target=movie, properties={\"label\": \"ActedIn\"}\n",
+    ")\n",
+    "rel2 = Relationship(\n",
+    "    id=6, type=\"ActedIn\", source=actor2, target=movie, properties={\"label\": \"ActedIn\"}\n",
+    ")\n",
+    "rel3 = Relationship(\n",
+    "    id=7, type=\"ActedIn\", source=actor3, target=movie, properties={\"label\": \"ActedIn\"}\n",
+    ")\n",
+    "rel4 = Relationship(\n",
+    "    id=8,\n",
+    "    type=\"Starring\",\n",
+    "    source=movie,\n",
+    "    target=actor1,\n",
+    "    properties={\"label\": \"Starring\"},\n",
+    ")\n",
+    "rel5 = Relationship(\n",
+    "    id=9,\n",
+    "    type=\"Starring\",\n",
+    "    source=movie,\n",
+    "    target=actor2,\n",
+    "    properties={\"label\": \"Starring\"},\n",
+    ")\n",
+    "rel6 = Relationship(\n",
+    "    id=10,\n",
+    "    type=\"Starring\",\n",
+    "    source=movie,\n",
+    "    target=actor3,\n",
+    "    properties={\"label\": \"Starring\"},\n",
+    ")\n",
+    "graph_doc = GraphDocument(\n",
+    "    nodes=[movie, actor1, actor2, actor3],\n",
+    "    relationships=[rel1, rel2, rel3, rel4, rel5, rel6],\n",
+    "    source=source_doc,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "d18f77a3",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# The underlying 
python-gremlin has a problem when running in notebook\n", + "# The following line is a workaround to fix the problem\n", + "nest_asyncio.apply()\n", + "\n", + "# Add the document to the CosmosDB graph.\n", + "graph.add_graph_documents([graph_doc])" + ] + }, + { + "cell_type": "markdown", + "id": "58c1a8ea", + "metadata": {}, + "source": [ + "## Refresh graph schema information\n", + "If the schema of database changes (after updates), you can refresh the schema information.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4e3de44f", + "metadata": {}, + "outputs": [], + "source": [ + "graph.refresh_schema()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1fe76ccd", + "metadata": {}, + "outputs": [], + "source": [ + "print(graph.schema)" + ] + }, + { + "cell_type": "markdown", + "id": "68a3c677", + "metadata": {}, + "source": [ + "## Querying the graph\n", + "\n", + "We can now use the gremlin QA chain to ask question of the graph" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7476ce98", + "metadata": {}, + "outputs": [], + "source": [ + "chain = GremlinQAChain.from_llm(\n", + " AzureChatOpenAI(\n", + " temperature=0,\n", + " azure_deployment=\"gpt-4-turbo\",\n", + " ),\n", + " graph=graph,\n", + " verbose=True,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ef8ee27b", + "metadata": {}, + "outputs": [], + "source": [ + "chain.invoke(\"Who played in The Matrix?\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "47c64027-cf42-493a-9c76-2d10ba753728", + "metadata": {}, + "outputs": [], + "source": [ + "chain.run(\"How many people played in The Matrix?\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + 
"name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/libs/community/langchain_community/graphs/__init__.py b/libs/community/langchain_community/graphs/__init__.py index c1fc640c4b..47aff97a85 100644 --- a/libs/community/langchain_community/graphs/__init__.py +++ b/libs/community/langchain_community/graphs/__init__.py @@ -2,6 +2,7 @@ from langchain_community.graphs.arangodb_graph import ArangoGraph from langchain_community.graphs.falkordb_graph import FalkorDBGraph +from langchain_community.graphs.gremlin_graph import GremlinGraph from langchain_community.graphs.hugegraph import HugeGraph from langchain_community.graphs.kuzu_graph import KuzuGraph from langchain_community.graphs.memgraph_graph import MemgraphGraph @@ -28,4 +29,5 @@ __all__ = [ "FalkorDBGraph", "TigerGraph", "OntotextGraphDBGraph", + "GremlinGraph", ] diff --git a/libs/community/langchain_community/graphs/gremlin_graph.py b/libs/community/langchain_community/graphs/gremlin_graph.py new file mode 100644 index 0000000000..2711c4f92a --- /dev/null +++ b/libs/community/langchain_community/graphs/gremlin_graph.py @@ -0,0 +1,207 @@ +import hashlib +import sys +from typing import Any, Dict, List, Optional, Union + +from langchain_core.utils import get_from_env + +from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship +from langchain_community.graphs.graph_store import GraphStore + + +class GremlinGraph(GraphStore): + """Gremlin wrapper for graph operations. + Parameters: + url (Optional[str]): The URL of the Gremlin database server or env GREMLIN_URI + username (Optional[str]): The collection-identifier like '/dbs/database/colls/graph' + or env GREMLIN_USERNAME if none provided + password (Optional[str]): The connection-key for database authentication + or env GREMLIN_PASSWORD if none provided + traversal_source (str): The traversal source to use for queries. 
Defaults to 'g'. + message_serializer (Optional[Any]): The message serializer to use for requests. + Defaults to serializer.GraphSONSerializersV2d0() + *Security note*: Make sure that the database connection uses credentials + that are narrowly-scoped to only include necessary permissions. + Failure to do so may result in data corruption or loss, since the calling + code may attempt commands that would result in deletion, mutation + of data if appropriately prompted or reading sensitive data if such + data is present in the database. + The best way to guard against such negative outcomes is to (as appropriate) + limit the permissions granted to the credentials used with this tool. + + See https://python.langchain.com/docs/security for more information. + + *Implementation details*: + The Gremlin queries are designed to work with Azure CosmosDB limitations + """ + + @property + def get_structured_schema(self) -> Dict[str, Any]: + return self.structured_schema + + def __init__( + self, + url: Optional[str] = None, + username: Optional[str] = None, + password: Optional[str] = None, + traversal_source: str = "g", + message_serializer: Optional[Any] = None, + ) -> None: + """Create a new Gremlin graph wrapper instance.""" + try: + import asyncio + + from gremlin_python.driver import client, serializer + + if sys.platform == "win32": + asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) + except ImportError: + raise ValueError( + "Please install gremlin-python first: " "`pip3 install gremlinpython" + ) + + self.client = client.Client( + url=get_from_env("url", "GREMLIN_URI", url), + traversal_source=traversal_source, + username=get_from_env("username", "GREMLIN_USERNAME", username), + password=get_from_env("password", "GREMLIN_PASSWORD", password), + message_serializer=message_serializer + if message_serializer + else serializer.GraphSONSerializersV2d0(), + ) + self.schema: str = "" + + @property + def get_schema(self) -> str: + """Returns the schema 
of the Gremlin database"""
+        if len(self.schema) == 0:
+            self.refresh_schema()
+        return self.schema
+
+    def refresh_schema(self) -> None:
+        """
+        Refreshes the Gremlin graph schema information.
+        """
+        vertex_schema = self.client.submit("g.V().label().dedup()").all().result()
+        edge_schema = self.client.submit("g.E().label().dedup()").all().result()
+        vertex_properties = (
+            self.client.submit(
+                "g.V().group().by(label).by(properties().label().dedup().fold())"
+            )
+            .all()
+            .result()[0]
+        )
+
+        self.structured_schema = {
+            "vertex_labels": vertex_schema,
+            "edge_labels": edge_schema,
+            "vertice_props": vertex_properties,
+        }
+
+        self.schema = "\n".join(
+            [
+                "Vertex labels are the following:",
+                ",".join(vertex_schema),
+                "Edge labels are the following:",
+                ",".join(edge_schema),
+                f"Vertices have the following properties:\n{vertex_properties}",
+            ]
+        )
+
+    def query(self, query: str, params: dict = {}) -> List[Dict[str, Any]]:
+        q = self.client.submit(query)
+        return q.all().result()
+
+    def add_graph_documents(
+        self, graph_documents: List[GraphDocument], include_source: bool = False
+    ) -> None:
+        """
+        Take GraphDocument as input and use it to construct a graph. 
+        """
+        node_cache: Dict[Union[str, int], Node] = {}
+        for document in graph_documents:
+            if include_source:
+                # Create document vertex
+                doc_props = {
+                    "page_content": document.source.page_content,
+                    "metadata": document.source.metadata,
+                }
+                doc_id = hashlib.md5(document.source.page_content.encode()).hexdigest()
+                doc_node = self.add_node(
+                    Node(id=doc_id, type="Document", properties=doc_props), node_cache
+                )
+
+            # Import nodes to vertices
+            for n in document.nodes:
+                node = self.add_node(n, node_cache)
+                if include_source:
+                    # Add Edge to document for each node
+                    self.add_edge(
+                        Relationship(
+                            type="contains information about",
+                            source=doc_node,
+                            target=node,
+                            properties={},
+                        )
+                    )
+                    self.add_edge(
+                        Relationship(
+                            type="is extracted from",
+                            source=node,
+                            target=doc_node,
+                            properties={},
+                        )
+                    )
+
+            # Edges
+            for el in document.relationships:
+                # Find or create the source vertex
+                self.add_node(el.source, node_cache)
+                # Find or create the target vertex
+                self.add_node(el.target, node_cache)
+                # Find or create the edge
+                self.add_edge(el)
+
+    def build_vertex_query(self, node: Node) -> str:
+        base_query = (
+            f"g.V().has('id','{node.id}').fold()"
+            + f".coalesce(unfold(),addV('{node.type}')"
+            + f".property('id','{node.id}')"
+            + f".property('type','{node.type}')"
+        )
+        for key, value in node.properties.items():
+            base_query += f".property('{key}', '{value}')"
+
+        return base_query + ")"
+
+    def build_edge_query(self, relationship: Relationship) -> str:
+        source_query = f".has('id','{relationship.source.id}')"
+        target_query = f".has('id','{relationship.target.id}')"
+
+        base_query = f"""g.V(){source_query}.as('a')
+        .V(){target_query}.as('b')
+        .choose(
+            __.inE('{relationship.type}').where(outV().as('a')),
+            __.identity(),
+            __.addE('{relationship.type}').from('a').to('b')
+        )
+        """.replace("\n", "").replace("\t", "")
+        for key, value in relationship.properties.items():
+            base_query += f".property('{key}', '{value}')"
+
+        return base_query
+
+    def 
add_node(self, node: Node, node_cache: dict = {}) -> Node: + # if properties does not have label, add type as label + if "label" not in node.properties: + node.properties["label"] = node.type + if node.id in node_cache: + return node_cache[node.id] + else: + query = self.build_vertex_query(node) + _ = self.client.submit(query).all().result()[0] + node_cache[node.id] = node + return node + + def add_edge(self, relationship: Relationship) -> Any: + query = self.build_edge_query(relationship) + return self.client.submit(query).all().result() diff --git a/libs/community/tests/unit_tests/graphs/test_imports.py b/libs/community/tests/unit_tests/graphs/test_imports.py index 202ecefa24..272400085f 100644 --- a/libs/community/tests/unit_tests/graphs/test_imports.py +++ b/libs/community/tests/unit_tests/graphs/test_imports.py @@ -14,6 +14,7 @@ EXPECTED_ALL = [ "FalkorDBGraph", "TigerGraph", "OntotextGraphDBGraph", + "GremlinGraph", ] diff --git a/libs/langchain/langchain/chains/graph_qa/gremlin.py b/libs/langchain/langchain/chains/graph_qa/gremlin.py new file mode 100644 index 0000000000..732e06bde8 --- /dev/null +++ b/libs/langchain/langchain/chains/graph_qa/gremlin.py @@ -0,0 +1,221 @@ +"""Question answering over a graph.""" +from __future__ import annotations + +from typing import Any, Dict, List, Optional + +from langchain_community.graphs import GremlinGraph +from langchain_core.callbacks.manager import CallbackManager, CallbackManagerForChainRun +from langchain_core.language_models import BaseLanguageModel +from langchain_core.prompts import BasePromptTemplate +from langchain_core.prompts.prompt import PromptTemplate +from langchain_core.pydantic_v1 import Field + +from langchain.chains.base import Chain +from langchain.chains.graph_qa.prompts import ( + CYPHER_QA_PROMPT, + GRAPHDB_SPARQL_FIX_TEMPLATE, + GREMLIN_GENERATION_PROMPT, +) +from langchain.chains.llm import LLMChain + +INTERMEDIATE_STEPS_KEY = "intermediate_steps" + + +def extract_gremlin(text: str) -> str: + 
"""Extract Gremlin code from a text. + + Args: + text: Text to extract Gremlin code from. + + Returns: + Gremlin code extracted from the text. + """ + text = text.replace("`", "") + if text.startswith("gremlin"): + text = text[len("gremlin") :] + return text.replace("\n", "") + + +class GremlinQAChain(Chain): + """Chain for question-answering against a graph by generating gremlin statements. + + *Security note*: Make sure that the database connection uses credentials + that are narrowly-scoped to only include necessary permissions. + Failure to do so may result in data corruption or loss, since the calling + code may attempt commands that would result in deletion, mutation + of data if appropriately prompted or reading sensitive data if such + data is present in the database. + The best way to guard against such negative outcomes is to (as appropriate) + limit the permissions granted to the credentials used with this tool. + + See https://python.langchain.com/docs/security for more information. + """ + + graph: GremlinGraph = Field(exclude=True) + gremlin_generation_chain: LLMChain + qa_chain: LLMChain + gremlin_fix_chain: LLMChain + max_fix_retries: int = 3 + input_key: str = "query" #: :meta private: + output_key: str = "result" #: :meta private: + top_k: int = 100 + return_direct: bool = False + return_intermediate_steps: bool = False + + @property + def input_keys(self) -> List[str]: + """Input keys. + + :meta private: + """ + return [self.input_key] + + @property + def output_keys(self) -> List[str]: + """Output keys. 
+ + :meta private: + """ + _output_keys = [self.output_key] + return _output_keys + + @classmethod + def from_llm( + cls, + llm: BaseLanguageModel, + *, + gremlin_fix_prompt: BasePromptTemplate = PromptTemplate( + input_variables=["error_message", "generated_sparql", "schema"], + template=GRAPHDB_SPARQL_FIX_TEMPLATE.replace("SPARQL", "Gremlin").replace( + "in Turtle format", "" + ), + ), + qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT, + gremlin_prompt: BasePromptTemplate = GREMLIN_GENERATION_PROMPT, + **kwargs: Any, + ) -> GremlinQAChain: + """Initialize from LLM.""" + qa_chain = LLMChain(llm=llm, prompt=qa_prompt) + gremlin_generation_chain = LLMChain(llm=llm, prompt=gremlin_prompt) + gremlinl_fix_chain = LLMChain(llm=llm, prompt=gremlin_fix_prompt) + return cls( + qa_chain=qa_chain, + gremlin_generation_chain=gremlin_generation_chain, + gremlin_fix_chain=gremlinl_fix_chain, + **kwargs, + ) + + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, str]: + """Generate gremlin statement, use it to look up in db and answer question.""" + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() + callbacks = _run_manager.get_child() + question = inputs[self.input_key] + + intermediate_steps: List = [] + + chain_response = self.gremlin_generation_chain.invoke( + {"question": question, "schema": self.graph.get_schema}, callbacks=callbacks + ) + + generated_gremlin = extract_gremlin( + chain_response[self.gremlin_generation_chain.output_key] + ) + + _run_manager.on_text("Generated gremlin:", end="\n", verbose=self.verbose) + _run_manager.on_text( + generated_gremlin, color="green", end="\n", verbose=self.verbose + ) + + intermediate_steps.append({"query": generated_gremlin}) + + if generated_gremlin: + context = self.execute_with_retry( + _run_manager, callbacks, generated_gremlin + )[: self.top_k] + else: + context = [] + + if self.return_direct: + final_result = context + else: + 
_run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
+            _run_manager.on_text(
+                str(context), color="green", end="\n", verbose=self.verbose
+            )
+
+            intermediate_steps.append({"context": context})
+
+            result = self.qa_chain.invoke(
+                {"question": question, "context": context},
+                callbacks=callbacks,
+            )
+            final_result = result[self.qa_chain.output_key]
+
+        chain_result: Dict[str, Any] = {self.output_key: final_result}
+        if self.return_intermediate_steps:
+            chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps
+
+        return chain_result
+
+    def execute_query(self, query: str) -> List[Any]:
+        try:
+            return self.graph.query(query)
+        except Exception as e:
+            if hasattr(e, "status_message"):
+                raise ValueError(e.status_message)
+            else:
+                raise ValueError(str(e))
+
+    def execute_with_retry(
+        self,
+        _run_manager: CallbackManagerForChainRun,
+        callbacks: CallbackManager,
+        generated_gremlin: str,
+    ) -> List[Any]:
+        try:
+            return self.execute_query(generated_gremlin)
+        except Exception as e:
+            retries = 0
+            error_message = str(e)
+            self.log_invalid_query(_run_manager, generated_gremlin, error_message)
+
+            while retries < self.max_fix_retries:
+                try:
+                    fix_chain_result = self.gremlin_fix_chain.invoke(
+                        {
+                            "error_message": error_message,
+                            # we are borrowing template from sparql
+                            "generated_sparql": generated_gremlin,
+                            "schema": self.graph.get_schema,
+                        },
+                        callbacks=callbacks,
+                    )
+                    fixed_gremlin = fix_chain_result[self.gremlin_fix_chain.output_key]
+                    return self.execute_query(fixed_gremlin)
+                except Exception as e:
+                    retries += 1
+                    parse_exception = str(e)
+                    self.log_invalid_query(_run_manager, fixed_gremlin, parse_exception)
+
+            raise ValueError("The generated Gremlin query is invalid.")
+
+    def log_invalid_query(
+        self,
+        _run_manager: CallbackManagerForChainRun,
+        generated_query: str,
+        error_message: str,
+    ) -> None:
+        _run_manager.on_text("Invalid Gremlin query: ", end="\n", verbose=self.verbose)
+        _run_manager.on_text(
+            generated_query, color="red", 
end="\n", verbose=self.verbose + ) + _run_manager.on_text( + "Gremlin Query Parse Error: ", end="\n", verbose=self.verbose + ) + _run_manager.on_text( + error_message, color="red", end="\n\n", verbose=self.verbose + )