@ -11,6 +11,7 @@ from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChai
from langchain . chains . combine_documents . map_rerank import MapRerankDocumentsChain
from langchain . chains . combine_documents . refine import RefineDocumentsChain
from langchain . chains . combine_documents . stuff import StuffDocumentsChain
from langchain . chains . graph_qa . cypher import GraphCypherQAChain
from langchain . chains . hyde . base import HypotheticalDocumentEmbedder
from langchain . chains . llm import LLMChain
from langchain . chains . llm_bash . base import LLMBashChain
@ -416,6 +417,30 @@ def _load_vector_db_qa(config: dict, **kwargs: Any) -> VectorDBQA:
)
def _load_graph_cypher_chain ( config : dict , * * kwargs : Any ) - > GraphCypherQAChain :
if " graph " in kwargs :
graph = kwargs . pop ( " graph " )
else :
raise ValueError ( " `graph` must be present. " )
if " cypher_generation_chain " in config :
cypher_generation_chain_config = config . pop ( " cypher_generation_chain " )
cypher_generation_chain = load_chain_from_config ( cypher_generation_chain_config )
else :
raise ValueError ( " `cypher_generation_chain` must be present. " )
if " qa_chain " in config :
qa_chain_config = config . pop ( " qa_chain " )
qa_chain = load_chain_from_config ( qa_chain_config )
else :
raise ValueError ( " `qa_chain` must be present. " )
return GraphCypherQAChain (
graph = graph ,
cypher_generation_chain = cypher_generation_chain ,
qa_chain = qa_chain ,
* * config ,
)
def _load_api_chain ( config : dict , * * kwargs : Any ) - > APIChain :
if " api_request_chain " in config :
api_request_chain_config = config . pop ( " api_request_chain " )
@ -482,6 +507,7 @@ type_to_loader_dict = {
" vector_db_qa_with_sources_chain " : _load_vector_db_qa_with_sources_chain ,
" vector_db_qa " : _load_vector_db_qa ,
" retrieval_qa " : _load_retrieval_qa ,
" graph_cypher_chain " : _load_graph_cypher_chain ,
}