From f09f82541b92fdc3693ca80b3ca514775fcd8759 Mon Sep 17 00:00:00 2001
From: sudranga
Date: Tue, 24 Oct 2023 09:52:55 -0700
Subject: [PATCH] Expose configuration options in GraphCypherQAChain (#12159)

Allows arguments to be passed into the LLM chains used by GraphCypherQAChain.
This addresses a user request to include memory in the Cypher-generating
chain. The prompt parameters are kept as-is to stay backward compatible, but
it would be a good idea to deprecate them and rely on the **kwargs parameters
instead. Added a test case.

In general, I think it would be good for any chain to automatically pass a
read-only memory (of its inputs) to its sub-chains while allowing for an
override. But that would be a different change.
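A rough usage sketch (illustration only, not part of the diff): the memory
wiring below mirrors the new unit tests, and llm / graph are placeholders for
a real language model and graph store.

    from langchain.chains.graph_qa.cypher import GraphCypherQAChain
    from langchain.memory import ConversationBufferMemory, ReadOnlySharedMemory

    memory = ConversationBufferMemory(memory_key="chat_history")
    readonlymemory = ReadOnlySharedMemory(memory=memory)

    # Non-prompt arguments such as memory can now be forwarded to the internal
    # LLMChains; the default prompts still apply unless a "prompt" key is
    # included in the kwargs dicts.
    chain = GraphCypherQAChain.from_llm(
        llm=llm,      # placeholder LLM
        graph=graph,  # placeholder graph store
        cypher_llm_kwargs={"memory": readonlymemory},
        qa_llm_kwargs={"memory": readonlymemory},
        memory=memory,
    )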
---
 .../langchain/chains/graph_qa/cypher.py      |  35 +++-
 .../tests/unit_tests/chains/test_graph_qa.py | 181 +++++++++++++++++-
 .../tests/unit_tests/llms/fake_llm.py        |   4 +-
 3 files changed, 213 insertions(+), 7 deletions(-)

diff --git a/libs/langchain/langchain/chains/graph_qa/cypher.py b/libs/langchain/langchain/chains/graph_qa/cypher.py
index abbc5a85c3..3531d517f9 100644
--- a/libs/langchain/langchain/chains/graph_qa/cypher.py
+++ b/libs/langchain/langchain/chains/graph_qa/cypher.py
@@ -132,13 +132,15 @@ class GraphCypherQAChain(Chain):
         cls,
         llm: Optional[BaseLanguageModel] = None,
         *,
-        qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT,
-        cypher_prompt: BasePromptTemplate = CYPHER_GENERATION_PROMPT,
+        qa_prompt: Optional[BasePromptTemplate] = None,
+        cypher_prompt: Optional[BasePromptTemplate] = None,
         cypher_llm: Optional[BaseLanguageModel] = None,
         qa_llm: Optional[BaseLanguageModel] = None,
         exclude_types: List[str] = [],
         include_types: List[str] = [],
         validate_cypher: bool = False,
+        qa_llm_kwargs: Optional[Dict[str, Any]] = None,
+        cypher_llm_kwargs: Optional[Dict[str, Any]] = None,
         **kwargs: Any,
     ) -> GraphCypherQAChain:
         """Initialize from LLM."""
@@ -152,9 +154,34 @@ class GraphCypherQAChain(Chain):
                 "You can specify up to two of 'cypher_llm', 'qa_llm'"
                 ", and 'llm', but not all three simultaneously."
             )
+        if cypher_prompt and cypher_llm_kwargs:
+            raise ValueError(
+                "Specifying cypher_prompt and cypher_llm_kwargs together is"
+                " not allowed. Please pass prompt via cypher_llm_kwargs."
+            )
+        if qa_prompt and qa_llm_kwargs:
+            raise ValueError(
+                "Specifying qa_prompt and qa_llm_kwargs together is"
+                " not allowed. Please pass prompt via qa_llm_kwargs."
+            )
+        use_qa_llm_kwargs = qa_llm_kwargs if qa_llm_kwargs is not None else {}
+        use_cypher_llm_kwargs = (
+            cypher_llm_kwargs if cypher_llm_kwargs is not None else {}
+        )
+        if "prompt" not in use_qa_llm_kwargs:
+            use_qa_llm_kwargs["prompt"] = (
+                qa_prompt if qa_prompt is not None else CYPHER_QA_PROMPT
+            )
+        if "prompt" not in use_cypher_llm_kwargs:
+            use_cypher_llm_kwargs["prompt"] = (
+                cypher_prompt if cypher_prompt is not None else CYPHER_GENERATION_PROMPT
+            )
 
-        qa_chain = LLMChain(llm=qa_llm or llm, prompt=qa_prompt)
-        cypher_generation_chain = LLMChain(llm=cypher_llm or llm, prompt=cypher_prompt)
+        qa_chain = LLMChain(llm=qa_llm or llm, **use_qa_llm_kwargs)
+
+        cypher_generation_chain = LLMChain(
+            llm=cypher_llm or llm, **use_cypher_llm_kwargs
+        )
 
         if exclude_types and include_types:
             raise ValueError(
diff --git a/libs/langchain/tests/unit_tests/chains/test_graph_qa.py b/libs/langchain/tests/unit_tests/chains/test_graph_qa.py
index ed4fe9feb6..51c968328b 100644
--- a/libs/langchain/tests/unit_tests/chains/test_graph_qa.py
+++ b/libs/langchain/tests/unit_tests/chains/test_graph_qa.py
@@ -1,9 +1,186 @@
-from typing import List
+from typing import Any, Dict, List
 
 import pandas as pd
 
-from langchain.chains.graph_qa.cypher import construct_schema, extract_cypher
+from langchain.chains.graph_qa.cypher import (
+    GraphCypherQAChain,
+    construct_schema,
+    extract_cypher,
+)
 from langchain.chains.graph_qa.cypher_utils import CypherQueryCorrector, Schema
+from langchain.chains.graph_qa.prompts import CYPHER_GENERATION_PROMPT, CYPHER_QA_PROMPT
+from langchain.graphs.graph_document import GraphDocument
+from langchain.graphs.graph_store import GraphStore
+from langchain.memory import ConversationBufferMemory, ReadOnlySharedMemory
+from langchain.prompts import PromptTemplate
+from tests.unit_tests.llms.fake_llm import FakeLLM
+
+
+class FakeGraphStore(GraphStore):
+    @property
+    def get_schema(self) -> str:
+        """Returns the schema of the Graph database"""
+        return ""
+
+    @property
+    def get_structured_schema(self) -> Dict[str, Any]:
+        """Returns the schema of the Graph database"""
+        return {}
+
+    def query(self, query: str, params: dict = {}) -> List[Dict[str, Any]]:
+        """Query the graph."""
+        return []
+
+    def refresh_schema(self) -> None:
+        """Refreshes the graph schema information."""
+        pass
+
+    def add_graph_documents(
+        self, graph_documents: List[GraphDocument], include_source: bool = False
+    ) -> None:
+        """Takes GraphDocument as input and uses it to construct a graph."""
+        pass
+
+
+def test_graph_cypher_qa_chain_prompt_selection_1() -> None:
+    # Pass prompts directly. No kwargs are specified.
+    qa_prompt_template = "QA Prompt"
+    cypher_prompt_template = "Cypher Prompt"
+    qa_prompt = PromptTemplate(template=qa_prompt_template, input_variables=[])
+    cypher_prompt = PromptTemplate(template=cypher_prompt_template, input_variables=[])
+    chain = GraphCypherQAChain.from_llm(
+        llm=FakeLLM(),
+        graph=FakeGraphStore(),
+        verbose=True,
+        return_intermediate_steps=False,
+        qa_prompt=qa_prompt,
+        cypher_prompt=cypher_prompt,
+    )
+    assert chain.qa_chain.prompt == qa_prompt
+    assert chain.cypher_generation_chain.prompt == cypher_prompt
+
+
+def test_graph_cypher_qa_chain_prompt_selection_2() -> None:
+    # Default case. Pass nothing
+    chain = GraphCypherQAChain.from_llm(
+        llm=FakeLLM(),
+        graph=FakeGraphStore(),
+        verbose=True,
+        return_intermediate_steps=False,
+    )
+    assert chain.qa_chain.prompt == CYPHER_QA_PROMPT
+    assert chain.cypher_generation_chain.prompt == CYPHER_GENERATION_PROMPT
+
+
+def test_graph_cypher_qa_chain_prompt_selection_3() -> None:
+    # Pass non-prompt args only to sub-chains via kwargs
+    memory = ConversationBufferMemory(memory_key="chat_history")
+    readonlymemory = ReadOnlySharedMemory(memory=memory)
+    chain = GraphCypherQAChain.from_llm(
+        llm=FakeLLM(),
+        graph=FakeGraphStore(),
+        verbose=True,
+        return_intermediate_steps=False,
+        cypher_llm_kwargs={"memory": readonlymemory},
+        qa_llm_kwargs={"memory": readonlymemory},
+    )
+    assert chain.qa_chain.prompt == CYPHER_QA_PROMPT
+    assert chain.cypher_generation_chain.prompt == CYPHER_GENERATION_PROMPT
+
+
+def test_graph_cypher_qa_chain_prompt_selection_4() -> None:
+    # Pass prompt, non-prompt args to subchains via kwargs
+    qa_prompt_template = "QA Prompt"
+    cypher_prompt_template = "Cypher Prompt"
+    memory = ConversationBufferMemory(memory_key="chat_history")
+    readonlymemory = ReadOnlySharedMemory(memory=memory)
+    qa_prompt = PromptTemplate(template=qa_prompt_template, input_variables=[])
+    cypher_prompt = PromptTemplate(template=cypher_prompt_template, input_variables=[])
+    chain = GraphCypherQAChain.from_llm(
+        llm=FakeLLM(),
+        graph=FakeGraphStore(),
+        verbose=True,
+        return_intermediate_steps=False,
+        cypher_llm_kwargs={"prompt": cypher_prompt, "memory": readonlymemory},
+        qa_llm_kwargs={"prompt": qa_prompt, "memory": readonlymemory},
+    )
+    assert chain.qa_chain.prompt == qa_prompt
+    assert chain.cypher_generation_chain.prompt == cypher_prompt
+
+
+def test_graph_cypher_qa_chain_prompt_selection_5() -> None:
+    # Can't pass both prompt and kwargs at the same time
+    qa_prompt_template = "QA Prompt"
+    cypher_prompt_template = "Cypher Prompt"
+    memory = ConversationBufferMemory(memory_key="chat_history")
+    readonlymemory = ReadOnlySharedMemory(memory=memory)
+    qa_prompt = PromptTemplate(template=qa_prompt_template, input_variables=[])
+    cypher_prompt = PromptTemplate(template=cypher_prompt_template, input_variables=[])
+    try:
+        GraphCypherQAChain.from_llm(
+            llm=FakeLLM(),
+            graph=FakeGraphStore(),
+            verbose=True,
+            return_intermediate_steps=False,
+            qa_prompt=qa_prompt,
+            cypher_prompt=cypher_prompt,
+            cypher_llm_kwargs={"memory": readonlymemory},
+            qa_llm_kwargs={"memory": readonlymemory},
+        )
+        assert False
+    except ValueError:
+        assert True
+
+
+def test_graph_cypher_qa_chain() -> None:
+    template = """You are a nice chatbot having a conversation with a human.
+
+    Schema:
+    {schema}
+
+    Previous conversation:
+    {chat_history}
+
+    New human question: {question}
+    Response:"""
+
+    prompt = PromptTemplate(
+        input_variables=["schema", "question", "chat_history"], template=template
+    )
+
+    memory = ConversationBufferMemory(memory_key="chat_history")
+    readonlymemory = ReadOnlySharedMemory(memory=memory)
+    prompt1 = (
+        "You are a nice chatbot having a conversation with a human.\n\n    "
+        "Schema:\n    Node properties are the following: \n {}\nRelationships "
+        "properties are the following: \n {}\nRelationships are: \n[]\n\n    "
+        "Previous conversation:\n    \n\n    New human question: "
+        "Test question\n    Response:"
+    )
+
+    prompt2 = (
+        "You are a nice chatbot having a conversation with a human.\n\n    "
+        "Schema:\n    Node properties are the following: \n {}\nRelationships "
+        "properties are the following: \n {}\nRelationships are: \n[]\n\n    "
+        "Previous conversation:\n    Human: Test question\nAI: foo\n\n    "
+        "New human question: Test new question\n    Response:"
+    )
+
+    llm = FakeLLM(queries={prompt1: "answer1", prompt2: "answer2"})
+    chain = GraphCypherQAChain.from_llm(
+        cypher_llm=llm,
+        qa_llm=FakeLLM(),
+        graph=FakeGraphStore(),
+        verbose=True,
+        return_intermediate_steps=False,
+        cypher_llm_kwargs={"prompt": prompt, "memory": readonlymemory},
+        memory=memory,
+    )
+    chain.run("Test question")
+    chain.run("Test new question")
+    # If we get here without a key error, that means memory
+    # was used properly to create prompts.
+    assert True
 
 
 def test_no_backticks() -> None:
diff --git a/libs/langchain/tests/unit_tests/llms/fake_llm.py b/libs/langchain/tests/unit_tests/llms/fake_llm.py
index c92df7bf1c..199c0f3d57 100644
--- a/libs/langchain/tests/unit_tests/llms/fake_llm.py
+++ b/libs/langchain/tests/unit_tests/llms/fake_llm.py
@@ -39,9 +39,11 @@ class FakeLLM(LLM):
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> str:
+        print(prompt)
         if self.sequential_responses:
             return self._get_next_response_in_sequence
-
+        print(repr(prompt))
+        print(self.queries)
         if self.queries is not None:
             return self.queries[prompt]
         if stop is None: