Add option to save/load graph cypher QA (#6219)

Similar as https://github.com/hwchase17/langchain/pull/5818

Added the functionality to save/load Graph Cypher QA Chain due to a user
reporting the following error

> raise NotImplementedError("Saving not supported for this chain
type.")\nNotImplementedError: Saving not supported for this chain
type.\n'
master
Tomaz Bratanic 11 months ago committed by GitHub
parent 495128ba95
commit b3bccabc66
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -59,6 +59,10 @@ class GraphCypherQAChain(Chain):
_output_keys = [self.output_key]
return _output_keys
@property
def _chain_type(self) -> str:
return "graph_cypher_chain"
@classmethod
def from_llm(
cls,

@ -11,6 +11,7 @@ from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChai
from langchain.chains.combine_documents.map_rerank import MapRerankDocumentsChain
from langchain.chains.combine_documents.refine import RefineDocumentsChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.graph_qa.cypher import GraphCypherQAChain
from langchain.chains.hyde.base import HypotheticalDocumentEmbedder
from langchain.chains.llm import LLMChain
from langchain.chains.llm_bash.base import LLMBashChain
@ -416,6 +417,30 @@ def _load_vector_db_qa(config: dict, **kwargs: Any) -> VectorDBQA:
)
def _load_graph_cypher_chain(config: dict, **kwargs: Any) -> GraphCypherQAChain:
if "graph" in kwargs:
graph = kwargs.pop("graph")
else:
raise ValueError("`graph` must be present.")
if "cypher_generation_chain" in config:
cypher_generation_chain_config = config.pop("cypher_generation_chain")
cypher_generation_chain = load_chain_from_config(cypher_generation_chain_config)
else:
raise ValueError("`cypher_generation_chain` must be present.")
if "qa_chain" in config:
qa_chain_config = config.pop("qa_chain")
qa_chain = load_chain_from_config(qa_chain_config)
else:
raise ValueError("`qa_chain` must be present.")
return GraphCypherQAChain(
graph=graph,
cypher_generation_chain=cypher_generation_chain,
qa_chain=qa_chain,
**config,
)
def _load_api_chain(config: dict, **kwargs: Any) -> APIChain:
if "api_request_chain" in config:
api_request_chain_config = config.pop("api_request_chain")
@ -482,6 +507,7 @@ type_to_loader_dict = {
"vector_db_qa_with_sources_chain": _load_vector_db_qa_with_sources_chain,
"vector_db_qa": _load_vector_db_qa,
"retrieval_qa": _load_retrieval_qa,
"graph_cypher_chain": _load_graph_cypher_chain,
}

@ -2,6 +2,7 @@
import os
from langchain.chains.graph_qa.cypher import GraphCypherQAChain
from langchain.chains.loading import load_chain
from langchain.graphs import Neo4jGraph
from langchain.llms.openai import OpenAI
@ -168,3 +169,31 @@ def test_cypher_return_direct() -> None:
output = chain.run("Who played in Pulp Fiction?")
expected_output = [{"a.name": "Bruce Willis"}]
assert output == expected_output
def test_cypher_save_load() -> None:
"""Test saving and loading."""
FILE_PATH = "cypher.yaml"
url = os.environ.get("NEO4J_URL")
username = os.environ.get("NEO4J_USERNAME")
password = os.environ.get("NEO4J_PASSWORD")
assert url is not None
assert username is not None
assert password is not None
graph = Neo4jGraph(
url=url,
username=username,
password=password,
)
chain = GraphCypherQAChain.from_llm(
OpenAI(temperature=0), graph=graph, return_direct=True
)
chain.save(file_path=FILE_PATH)
qa_loaded = load_chain(FILE_PATH, graph=graph)
assert qa_loaded == chain

Loading…
Cancel
Save