Add option to save/load graph cypher QA (#6219)

Similar as https://github.com/hwchase17/langchain/pull/5818

Added the functionality to save/load Graph Cypher QA Chain due to a user
reporting the following error

> raise NotImplementedError("Saving not supported for this chain
type.")\nNotImplementedError: Saving not supported for this chain
type.\n'
This commit is contained in:
Tomaz Bratanic 2023-06-19 02:00:27 +02:00 committed by GitHub
parent 495128ba95
commit b3bccabc66
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 59 additions and 0 deletions

View File

@ -59,6 +59,10 @@ class GraphCypherQAChain(Chain):
_output_keys = [self.output_key]
return _output_keys
@property
def _chain_type(self) -> str:
return "graph_cypher_chain"
@classmethod
def from_llm(
cls,

View File

@ -11,6 +11,7 @@ from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChai
from langchain.chains.combine_documents.map_rerank import MapRerankDocumentsChain
from langchain.chains.combine_documents.refine import RefineDocumentsChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.graph_qa.cypher import GraphCypherQAChain
from langchain.chains.hyde.base import HypotheticalDocumentEmbedder
from langchain.chains.llm import LLMChain
from langchain.chains.llm_bash.base import LLMBashChain
@ -416,6 +417,30 @@ def _load_vector_db_qa(config: dict, **kwargs: Any) -> VectorDBQA:
)
def _load_graph_cypher_chain(config: dict, **kwargs: Any) -> GraphCypherQAChain:
if "graph" in kwargs:
graph = kwargs.pop("graph")
else:
raise ValueError("`graph` must be present.")
if "cypher_generation_chain" in config:
cypher_generation_chain_config = config.pop("cypher_generation_chain")
cypher_generation_chain = load_chain_from_config(cypher_generation_chain_config)
else:
raise ValueError("`cypher_generation_chain` must be present.")
if "qa_chain" in config:
qa_chain_config = config.pop("qa_chain")
qa_chain = load_chain_from_config(qa_chain_config)
else:
raise ValueError("`qa_chain` must be present.")
return GraphCypherQAChain(
graph=graph,
cypher_generation_chain=cypher_generation_chain,
qa_chain=qa_chain,
**config,
)
def _load_api_chain(config: dict, **kwargs: Any) -> APIChain:
if "api_request_chain" in config:
api_request_chain_config = config.pop("api_request_chain")
@ -482,6 +507,7 @@ type_to_loader_dict = {
"vector_db_qa_with_sources_chain": _load_vector_db_qa_with_sources_chain,
"vector_db_qa": _load_vector_db_qa,
"retrieval_qa": _load_retrieval_qa,
"graph_cypher_chain": _load_graph_cypher_chain,
}

View File

@ -2,6 +2,7 @@
import os
from langchain.chains.graph_qa.cypher import GraphCypherQAChain
from langchain.chains.loading import load_chain
from langchain.graphs import Neo4jGraph
from langchain.llms.openai import OpenAI
@ -168,3 +169,31 @@ def test_cypher_return_direct() -> None:
output = chain.run("Who played in Pulp Fiction?")
expected_output = [{"a.name": "Bruce Willis"}]
assert output == expected_output
def test_cypher_save_load() -> None:
"""Test saving and loading."""
FILE_PATH = "cypher.yaml"
url = os.environ.get("NEO4J_URL")
username = os.environ.get("NEO4J_USERNAME")
password = os.environ.get("NEO4J_PASSWORD")
assert url is not None
assert username is not None
assert password is not None
graph = Neo4jGraph(
url=url,
username=username,
password=password,
)
chain = GraphCypherQAChain.from_llm(
OpenAI(temperature=0), graph=graph, return_direct=True
)
chain.save(file_path=FILE_PATH)
qa_loaded = load_chain(FILE_PATH, graph=graph)
assert qa_loaded == chain