From b3bccabc66a8bf2521c16fe1f4d4739c71e2d468 Mon Sep 17 00:00:00 2001 From: Tomaz Bratanic Date: Mon, 19 Jun 2023 02:00:27 +0200 Subject: [PATCH] Add option to save/load graph cypher QA (#6219) Similar as https://github.com/hwchase17/langchain/pull/5818 Added the functionality to save/load Graph Cypher QA Chain due to a user reporting the following error > raise NotImplementedError("Saving not supported for this chain type.")\nNotImplementedError: Saving not supported for this chain type.\n' --- langchain/chains/graph_qa/cypher.py | 4 +++ langchain/chains/loading.py | 26 +++++++++++++++++ .../chains/test_graph_database.py | 29 +++++++++++++++++++ 3 files changed, 59 insertions(+) diff --git a/langchain/chains/graph_qa/cypher.py b/langchain/chains/graph_qa/cypher.py index 39d124c758..456bbad97a 100644 --- a/langchain/chains/graph_qa/cypher.py +++ b/langchain/chains/graph_qa/cypher.py @@ -59,6 +59,10 @@ class GraphCypherQAChain(Chain): _output_keys = [self.output_key] return _output_keys + @property + def _chain_type(self) -> str: + return "graph_cypher_chain" + @classmethod def from_llm( cls, diff --git a/langchain/chains/loading.py b/langchain/chains/loading.py index cc62866705..f115b90543 100644 --- a/langchain/chains/loading.py +++ b/langchain/chains/loading.py @@ -11,6 +11,7 @@ from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChai from langchain.chains.combine_documents.map_rerank import MapRerankDocumentsChain from langchain.chains.combine_documents.refine import RefineDocumentsChain from langchain.chains.combine_documents.stuff import StuffDocumentsChain +from langchain.chains.graph_qa.cypher import GraphCypherQAChain from langchain.chains.hyde.base import HypotheticalDocumentEmbedder from langchain.chains.llm import LLMChain from langchain.chains.llm_bash.base import LLMBashChain @@ -416,6 +417,30 @@ def _load_vector_db_qa(config: dict, **kwargs: Any) -> VectorDBQA: ) +def _load_graph_cypher_chain(config: dict, **kwargs: Any) -> GraphCypherQAChain: + if "graph" in kwargs: + graph = kwargs.pop("graph") + else: + raise ValueError("`graph` must be present.") + if "cypher_generation_chain" in config: + cypher_generation_chain_config = config.pop("cypher_generation_chain") + cypher_generation_chain = load_chain_from_config(cypher_generation_chain_config) + else: + raise ValueError("`cypher_generation_chain` must be present.") + if "qa_chain" in config: + qa_chain_config = config.pop("qa_chain") + qa_chain = load_chain_from_config(qa_chain_config) + else: + raise ValueError("`qa_chain` must be present.") + + return GraphCypherQAChain( + graph=graph, + cypher_generation_chain=cypher_generation_chain, + qa_chain=qa_chain, + **config, + ) + + def _load_api_chain(config: dict, **kwargs: Any) -> APIChain: if "api_request_chain" in config: api_request_chain_config = config.pop("api_request_chain") @@ -482,6 +507,7 @@ type_to_loader_dict = { "vector_db_qa_with_sources_chain": _load_vector_db_qa_with_sources_chain, "vector_db_qa": _load_vector_db_qa, "retrieval_qa": _load_retrieval_qa, + "graph_cypher_chain": _load_graph_cypher_chain, } diff --git a/tests/integration_tests/chains/test_graph_database.py b/tests/integration_tests/chains/test_graph_database.py index 9b515f9009..2c6133be1a 100644 --- a/tests/integration_tests/chains/test_graph_database.py +++ b/tests/integration_tests/chains/test_graph_database.py @@ -2,6 +2,7 @@ import os from langchain.chains.graph_qa.cypher import GraphCypherQAChain +from langchain.chains.loading import load_chain from langchain.graphs import Neo4jGraph from langchain.llms.openai import OpenAI @@ -168,3 +169,31 @@ def test_cypher_return_direct() -> None: output = chain.run("Who played in Pulp Fiction?") expected_output = [{"a.name": "Bruce Willis"}] assert output == expected_output + + +def test_cypher_save_load() -> None: + """Test saving and loading.""" + + FILE_PATH = "cypher.yaml" + + url = os.environ.get("NEO4J_URL") + username = os.environ.get("NEO4J_USERNAME") + password = os.environ.get("NEO4J_PASSWORD") + assert url is not None + assert username is not None + assert password is not None + + graph = Neo4jGraph( + url=url, + username=username, + password=password, + ) + + chain = GraphCypherQAChain.from_llm( + OpenAI(temperature=0), graph=graph, return_direct=True + ) + + chain.save(file_path=FILE_PATH) + qa_loaded = load_chain(FILE_PATH, graph=graph) + + assert qa_loaded == chain