mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Add option to save/load graph cypher QA (#6219)
Similar as https://github.com/hwchase17/langchain/pull/5818 Added the functionality to save/load Graph Cypher QA Chain due to a user reporting the following error > raise NotImplementedError("Saving not supported for this chain type.")\nNotImplementedError: Saving not supported for this chain type.\n'
This commit is contained in:
parent
495128ba95
commit
b3bccabc66
@ -59,6 +59,10 @@ class GraphCypherQAChain(Chain):
|
||||
_output_keys = [self.output_key]
|
||||
return _output_keys
|
||||
|
||||
@property
|
||||
def _chain_type(self) -> str:
|
||||
return "graph_cypher_chain"
|
||||
|
||||
@classmethod
|
||||
def from_llm(
|
||||
cls,
|
||||
|
@ -11,6 +11,7 @@ from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChai
|
||||
from langchain.chains.combine_documents.map_rerank import MapRerankDocumentsChain
|
||||
from langchain.chains.combine_documents.refine import RefineDocumentsChain
|
||||
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
|
||||
from langchain.chains.graph_qa.cypher import GraphCypherQAChain
|
||||
from langchain.chains.hyde.base import HypotheticalDocumentEmbedder
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.llm_bash.base import LLMBashChain
|
||||
@ -416,6 +417,30 @@ def _load_vector_db_qa(config: dict, **kwargs: Any) -> VectorDBQA:
|
||||
)
|
||||
|
||||
|
||||
def _load_graph_cypher_chain(config: dict, **kwargs: Any) -> GraphCypherQAChain:
|
||||
if "graph" in kwargs:
|
||||
graph = kwargs.pop("graph")
|
||||
else:
|
||||
raise ValueError("`graph` must be present.")
|
||||
if "cypher_generation_chain" in config:
|
||||
cypher_generation_chain_config = config.pop("cypher_generation_chain")
|
||||
cypher_generation_chain = load_chain_from_config(cypher_generation_chain_config)
|
||||
else:
|
||||
raise ValueError("`cypher_generation_chain` must be present.")
|
||||
if "qa_chain" in config:
|
||||
qa_chain_config = config.pop("qa_chain")
|
||||
qa_chain = load_chain_from_config(qa_chain_config)
|
||||
else:
|
||||
raise ValueError("`qa_chain` must be present.")
|
||||
|
||||
return GraphCypherQAChain(
|
||||
graph=graph,
|
||||
cypher_generation_chain=cypher_generation_chain,
|
||||
qa_chain=qa_chain,
|
||||
**config,
|
||||
)
|
||||
|
||||
|
||||
def _load_api_chain(config: dict, **kwargs: Any) -> APIChain:
|
||||
if "api_request_chain" in config:
|
||||
api_request_chain_config = config.pop("api_request_chain")
|
||||
@ -482,6 +507,7 @@ type_to_loader_dict = {
|
||||
"vector_db_qa_with_sources_chain": _load_vector_db_qa_with_sources_chain,
|
||||
"vector_db_qa": _load_vector_db_qa,
|
||||
"retrieval_qa": _load_retrieval_qa,
|
||||
"graph_cypher_chain": _load_graph_cypher_chain,
|
||||
}
|
||||
|
||||
|
||||
|
@ -2,6 +2,7 @@
|
||||
import os
|
||||
|
||||
from langchain.chains.graph_qa.cypher import GraphCypherQAChain
|
||||
from langchain.chains.loading import load_chain
|
||||
from langchain.graphs import Neo4jGraph
|
||||
from langchain.llms.openai import OpenAI
|
||||
|
||||
@ -168,3 +169,31 @@ def test_cypher_return_direct() -> None:
|
||||
output = chain.run("Who played in Pulp Fiction?")
|
||||
expected_output = [{"a.name": "Bruce Willis"}]
|
||||
assert output == expected_output
|
||||
|
||||
|
||||
def test_cypher_save_load() -> None:
|
||||
"""Test saving and loading."""
|
||||
|
||||
FILE_PATH = "cypher.yaml"
|
||||
|
||||
url = os.environ.get("NEO4J_URL")
|
||||
username = os.environ.get("NEO4J_USERNAME")
|
||||
password = os.environ.get("NEO4J_PASSWORD")
|
||||
assert url is not None
|
||||
assert username is not None
|
||||
assert password is not None
|
||||
|
||||
graph = Neo4jGraph(
|
||||
url=url,
|
||||
username=username,
|
||||
password=password,
|
||||
)
|
||||
|
||||
chain = GraphCypherQAChain.from_llm(
|
||||
OpenAI(temperature=0), graph=graph, return_direct=True
|
||||
)
|
||||
|
||||
chain.save(file_path=FILE_PATH)
|
||||
qa_loaded = load_chain(FILE_PATH, graph=graph)
|
||||
|
||||
assert qa_loaded == chain
|
||||
|
Loading…
Reference in New Issue
Block a user