Expose configuration options in GraphCypherQAChain (#12159)

Allows passing arguments into the LLM chains used by GraphCypherQAChain. This
addresses a user request to include memory in the Cypher-generating chain. The
prompt parameters are kept as-is to stay backward compatible, but it would be
a good idea to deprecate them and use the **kwargs parameters instead. Added a
test case.
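
For illustration, the new arguments can be used as below. This is a minimal sketch: the ChatOpenAI/Neo4jGraph instances and connection details are illustrative stand-ins, not part of this change.

from langchain.chains.graph_qa.cypher import GraphCypherQAChain
from langchain.chat_models import ChatOpenAI
from langchain.graphs import Neo4jGraph
from langchain.prompts import PromptTemplate

# Illustrative resources; any BaseLanguageModel and GraphStore work here.
llm = ChatOpenAI(temperature=0)
graph = Neo4jGraph(url="bolt://localhost:7687", username="neo4j", password="...")

qa_prompt = PromptTemplate(
    template="Use the context to answer.\nContext: {context}\nQuestion: {question}",
    input_variables=["context", "question"],
)

# Backward-compatible style: pass the prompt directly.
chain = GraphCypherQAChain.from_llm(llm=llm, graph=graph, qa_prompt=qa_prompt)

# New style: forward any LLMChain argument, including the prompt,
# through qa_llm_kwargs / cypher_llm_kwargs.
chain = GraphCypherQAChain.from_llm(
    llm=llm, graph=graph, qa_llm_kwargs={"prompt": qa_prompt}
)

# Passing qa_prompt together with qa_llm_kwargs raises a ValueError.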

In general, I think it would be good for any chain to automatically pass a
read-only memory (of its inputs) to its sub-chains while allowing for an
override. But that would be a different change.
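
To make the motivating use case concrete, here is a sketch of wiring memory into the Cypher-generating sub-chain. The pattern follows the new test case below; the LLM, graph, and connection details are again illustrative stand-ins.

from langchain.chains.graph_qa.cypher import GraphCypherQAChain
from langchain.chat_models import ChatOpenAI
from langchain.graphs import Neo4jGraph
from langchain.memory import ConversationBufferMemory, ReadOnlySharedMemory
from langchain.prompts import PromptTemplate

# A Cypher-generation prompt that can see the conversation so far.
cypher_prompt = PromptTemplate(
    input_variables=["schema", "question", "chat_history"],
    template=(
        "Schema:\n{schema}\n\n"
        "Previous conversation:\n{chat_history}\n\n"
        "Write a Cypher query answering: {question}"
    ),
)

memory = ConversationBufferMemory(memory_key="chat_history")
# Sub-chains get a read-only view, so only the parent chain writes to memory.
readonlymemory = ReadOnlySharedMemory(memory=memory)

chain = GraphCypherQAChain.from_llm(
    llm=ChatOpenAI(temperature=0),
    graph=Neo4jGraph(url="bolt://localhost:7687", username="neo4j", password="..."),
    cypher_llm_kwargs={"prompt": cypher_prompt, "memory": readonlymemory},
    memory=memory,
)
chain.run("Which movies did Tom Hanks act in?")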
sudranga committed 9 months ago (via GitHub)
parent 11f13aed53
commit f09f82541b

@@ -132,13 +132,15 @@ class GraphCypherQAChain(Chain):
         cls,
         llm: Optional[BaseLanguageModel] = None,
         *,
-        qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT,
-        cypher_prompt: BasePromptTemplate = CYPHER_GENERATION_PROMPT,
+        qa_prompt: Optional[BasePromptTemplate] = None,
+        cypher_prompt: Optional[BasePromptTemplate] = None,
         cypher_llm: Optional[BaseLanguageModel] = None,
         qa_llm: Optional[BaseLanguageModel] = None,
         exclude_types: List[str] = [],
         include_types: List[str] = [],
         validate_cypher: bool = False,
+        qa_llm_kwargs: Optional[Dict[str, Any]] = None,
+        cypher_llm_kwargs: Optional[Dict[str, Any]] = None,
         **kwargs: Any,
     ) -> GraphCypherQAChain:
         """Initialize from LLM."""
@@ -152,9 +154,34 @@ class GraphCypherQAChain(Chain):
                 "You can specify up to two of 'cypher_llm', 'qa_llm'"
                 ", and 'llm', but not all three simultaneously."
             )
+        if cypher_prompt and cypher_llm_kwargs:
+            raise ValueError(
+                "Specifying cypher_prompt and cypher_llm_kwargs together is"
+                " not allowed. Please pass prompt via cypher_llm_kwargs."
+            )
+        if qa_prompt and qa_llm_kwargs:
+            raise ValueError(
+                "Specifying qa_prompt and qa_llm_kwargs together is"
+                " not allowed. Please pass prompt via qa_llm_kwargs."
+            )
+        use_qa_llm_kwargs = qa_llm_kwargs if qa_llm_kwargs is not None else {}
+        use_cypher_llm_kwargs = (
+            cypher_llm_kwargs if cypher_llm_kwargs is not None else {}
+        )
+        if "prompt" not in use_qa_llm_kwargs:
+            use_qa_llm_kwargs["prompt"] = (
+                qa_prompt if qa_prompt is not None else CYPHER_QA_PROMPT
+            )
+        if "prompt" not in use_cypher_llm_kwargs:
+            use_cypher_llm_kwargs["prompt"] = (
+                cypher_prompt if cypher_prompt is not None else CYPHER_GENERATION_PROMPT
+            )
-        qa_chain = LLMChain(llm=qa_llm or llm, prompt=qa_prompt)
-        cypher_generation_chain = LLMChain(llm=cypher_llm or llm, prompt=cypher_prompt)
+        qa_chain = LLMChain(llm=qa_llm or llm, **use_qa_llm_kwargs)
+        cypher_generation_chain = LLMChain(
+            llm=cypher_llm or llm, **use_cypher_llm_kwargs
+        )
         if exclude_types and include_types:
             raise ValueError(

@@ -1,9 +1,186 @@
-from typing import List
+from typing import Any, Dict, List
 
 import pandas as pd
 
-from langchain.chains.graph_qa.cypher import construct_schema, extract_cypher
+from langchain.chains.graph_qa.cypher import (
+    GraphCypherQAChain,
+    construct_schema,
+    extract_cypher,
+)
 from langchain.chains.graph_qa.cypher_utils import CypherQueryCorrector, Schema
+from langchain.chains.graph_qa.prompts import CYPHER_GENERATION_PROMPT, CYPHER_QA_PROMPT
+from langchain.graphs.graph_document import GraphDocument
+from langchain.graphs.graph_store import GraphStore
+from langchain.memory import ConversationBufferMemory, ReadOnlySharedMemory
+from langchain.prompts import PromptTemplate
+from tests.unit_tests.llms.fake_llm import FakeLLM
+
+
+class FakeGraphStore(GraphStore):
+    @property
+    def get_schema(self) -> str:
+        """Returns the schema of the Graph database"""
+        return ""
+
+    @property
+    def get_structured_schema(self) -> Dict[str, Any]:
+        """Returns the schema of the Graph database"""
+        return {}
+
+    def query(self, query: str, params: dict = {}) -> List[Dict[str, Any]]:
+        """Query the graph."""
+        return []
+
+    def refresh_schema(self) -> None:
+        """Refreshes the graph schema information."""
+        pass
+
+    def add_graph_documents(
+        self, graph_documents: List[GraphDocument], include_source: bool = False
+    ) -> None:
+        """Takes GraphDocument as input and uses it to construct a graph."""
+        pass
+
+
+def test_graph_cypher_qa_chain_prompt_selection_1() -> None:
+    # Pass prompts directly. No kwargs are specified.
+    qa_prompt_template = "QA Prompt"
+    cypher_prompt_template = "Cypher Prompt"
+    qa_prompt = PromptTemplate(template=qa_prompt_template, input_variables=[])
+    cypher_prompt = PromptTemplate(template=cypher_prompt_template, input_variables=[])
+    chain = GraphCypherQAChain.from_llm(
+        llm=FakeLLM(),
+        graph=FakeGraphStore(),
+        verbose=True,
+        return_intermediate_steps=False,
+        qa_prompt=qa_prompt,
+        cypher_prompt=cypher_prompt,
+    )
+    assert chain.qa_chain.prompt == qa_prompt
+    assert chain.cypher_generation_chain.prompt == cypher_prompt
+
+
+def test_graph_cypher_qa_chain_prompt_selection_2() -> None:
+    # Default case. Pass nothing.
+    chain = GraphCypherQAChain.from_llm(
+        llm=FakeLLM(),
+        graph=FakeGraphStore(),
+        verbose=True,
+        return_intermediate_steps=False,
+    )
+    assert chain.qa_chain.prompt == CYPHER_QA_PROMPT
+    assert chain.cypher_generation_chain.prompt == CYPHER_GENERATION_PROMPT
+
+
+def test_graph_cypher_qa_chain_prompt_selection_3() -> None:
+    # Pass non-prompt args only to sub-chains via kwargs
+    memory = ConversationBufferMemory(memory_key="chat_history")
+    readonlymemory = ReadOnlySharedMemory(memory=memory)
+    chain = GraphCypherQAChain.from_llm(
+        llm=FakeLLM(),
+        graph=FakeGraphStore(),
+        verbose=True,
+        return_intermediate_steps=False,
+        cypher_llm_kwargs={"memory": readonlymemory},
+        qa_llm_kwargs={"memory": readonlymemory},
+    )
+    assert chain.qa_chain.prompt == CYPHER_QA_PROMPT
+    assert chain.cypher_generation_chain.prompt == CYPHER_GENERATION_PROMPT
+
+
+def test_graph_cypher_qa_chain_prompt_selection_4() -> None:
+    # Pass prompt and non-prompt args to sub-chains via kwargs
+    qa_prompt_template = "QA Prompt"
+    cypher_prompt_template = "Cypher Prompt"
+    memory = ConversationBufferMemory(memory_key="chat_history")
+    readonlymemory = ReadOnlySharedMemory(memory=memory)
+    qa_prompt = PromptTemplate(template=qa_prompt_template, input_variables=[])
+    cypher_prompt = PromptTemplate(template=cypher_prompt_template, input_variables=[])
+    chain = GraphCypherQAChain.from_llm(
+        llm=FakeLLM(),
+        graph=FakeGraphStore(),
+        verbose=True,
+        return_intermediate_steps=False,
+        cypher_llm_kwargs={"prompt": cypher_prompt, "memory": readonlymemory},
+        qa_llm_kwargs={"prompt": qa_prompt, "memory": readonlymemory},
+    )
+    assert chain.qa_chain.prompt == qa_prompt
+    assert chain.cypher_generation_chain.prompt == cypher_prompt
+
+
+def test_graph_cypher_qa_chain_prompt_selection_5() -> None:
+    # Can't pass both a prompt and kwargs at the same time
+    qa_prompt_template = "QA Prompt"
+    cypher_prompt_template = "Cypher Prompt"
+    memory = ConversationBufferMemory(memory_key="chat_history")
+    readonlymemory = ReadOnlySharedMemory(memory=memory)
+    qa_prompt = PromptTemplate(template=qa_prompt_template, input_variables=[])
+    cypher_prompt = PromptTemplate(template=cypher_prompt_template, input_variables=[])
+    try:
+        GraphCypherQAChain.from_llm(
+            llm=FakeLLM(),
+            graph=FakeGraphStore(),
+            verbose=True,
+            return_intermediate_steps=False,
+            qa_prompt=qa_prompt,
+            cypher_prompt=cypher_prompt,
+            cypher_llm_kwargs={"memory": readonlymemory},
+            qa_llm_kwargs={"memory": readonlymemory},
+        )
+        assert False
+    except ValueError:
+        assert True
+
+
+def test_graph_cypher_qa_chain() -> None:
+    template = """You are a nice chatbot having a conversation with a human.
+
+    Schema:
+    {schema}
+
+    Previous conversation:
+    {chat_history}
+
+    New human question: {question}
+    Response:"""
+
+    prompt = PromptTemplate(
+        input_variables=["schema", "question", "chat_history"], template=template
+    )
+
+    memory = ConversationBufferMemory(memory_key="chat_history")
+    readonlymemory = ReadOnlySharedMemory(memory=memory)
+
+    prompt1 = (
+        "You are a nice chatbot having a conversation with a human.\n\n    "
+        "Schema:\n    Node properties are the following: \n {}\nRelationships "
+        "properties are the following: \n {}\nRelationships are: \n[]\n\n    "
+        "Previous conversation:\n    \n\n    New human question: "
+        "Test question\n    Response:"
+    )
+    prompt2 = (
+        "You are a nice chatbot having a conversation with a human.\n\n    "
+        "Schema:\n    Node properties are the following: \n {}\nRelationships "
+        "properties are the following: \n {}\nRelationships are: \n[]\n\n    "
+        "Previous conversation:\n    Human: Test question\nAI: foo\n\n    "
+        "New human question: Test new question\n    Response:"
+    )
+    llm = FakeLLM(queries={prompt1: "answer1", prompt2: "answer2"})
+    chain = GraphCypherQAChain.from_llm(
+        cypher_llm=llm,
+        qa_llm=FakeLLM(),
+        graph=FakeGraphStore(),
+        verbose=True,
+        return_intermediate_steps=False,
+        cypher_llm_kwargs={"prompt": prompt, "memory": readonlymemory},
+        memory=memory,
+    )
+    chain.run("Test question")
+    chain.run("Test new question")
+    # If we get here without a KeyError, memory was used properly
+    # to build the prompts.
+    assert True
 
 
 def test_no_backticks() -> None:

@@ -39,9 +39,11 @@ class FakeLLM(LLM):
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> str:
-        print(prompt)
         if self.sequential_responses:
             return self._get_next_response_in_sequence
+        print(repr(prompt))
+        print(self.queries)
         if self.queries is not None:
             return self.queries[prompt]
         if stop is None:
