From 76a193decc2ba61bc0f96ee30987a28d108b196b Mon Sep 17 00:00:00 2001 From: Tomaz Bratanic Date: Mon, 10 Jun 2024 13:52:17 -0700 Subject: [PATCH] community[patch]: Add function response to graph cypher qa chain (#22690) LLMs struggle with Graph RAG, because it's different from vector RAG in a way that you don't provide the whole context, only the answer and the LLM has to believe. However, that doesn't really work a lot of the time. However, if you wrap the context as function response the accuracy is much better. btw... `union[LLMChain, Runnable]` is linting fun, that's why so many ignores --- .../integrations/graphs/neo4j_cypher.ipynb | 150 +++++++++++++++--- .../chains/graph_qa/cypher.py | 92 +++++++++-- .../tests/unit_tests/chains/test_graph_qa.py | 8 +- 3 files changed, 211 insertions(+), 39 deletions(-) diff --git a/docs/docs/integrations/graphs/neo4j_cypher.ipynb b/docs/docs/integrations/graphs/neo4j_cypher.ipynb index 564c01ec28..3348f0b8d2 100644 --- a/docs/docs/integrations/graphs/neo4j_cypher.ipynb +++ b/docs/docs/integrations/graphs/neo4j_cypher.ipynb @@ -164,10 +164,10 @@ "text": [ "Node properties:\n", "- **Movie**\n", - " - `runtime: INTEGER` Min: 120, Max: 120\n", - " - `name: STRING` Available options: ['Top Gun']\n", + " - `runtime`: INTEGER Min: 120, Max: 120\n", + " - `name`: STRING Available options: ['Top Gun']\n", "- **Actor**\n", - " - `name: STRING` Available options: ['Tom Cruise', 'Val Kilmer', 'Anthony Edwards', 'Meg Ryan']\n", + " - `name`: STRING Available options: ['Tom Cruise', 'Val Kilmer', 'Anthony Edwards', 'Meg Ryan']\n", "Relationship properties:\n", "\n", "The relationships:\n", @@ -225,7 +225,7 @@ "WHERE m.name = 'Top Gun'\n", "RETURN a.name\u001b[0m\n", "Full Context:\n", - "\u001b[32;1m\u001b[1;3m[{'a.name': 'Anthony Edwards'}, {'a.name': 'Meg Ryan'}, {'a.name': 'Val Kilmer'}, {'a.name': 'Tom Cruise'}]\u001b[0m\n", + "\u001b[32;1m\u001b[1;3m[{'a.name': 'Tom Cruise'}, {'a.name': 'Val Kilmer'}, {'a.name': 'Anthony Edwards'}, 
{'a.name': 'Meg Ryan'}]\u001b[0m\n", "\n", "\u001b[1m> Finished chain.\u001b[0m\n" ] @@ -234,7 +234,7 @@ "data": { "text/plain": [ "{'query': 'Who played in Top Gun?',\n", - " 'result': 'Anthony Edwards, Meg Ryan, Val Kilmer, Tom Cruise played in Top Gun.'}" + " 'result': 'Tom Cruise, Val Kilmer, Anthony Edwards, and Meg Ryan played in Top Gun.'}" ] }, "execution_count": 8, @@ -286,7 +286,7 @@ "WHERE m.name = 'Top Gun'\n", "RETURN a.name\u001b[0m\n", "Full Context:\n", - "\u001b[32;1m\u001b[1;3m[{'a.name': 'Anthony Edwards'}, {'a.name': 'Meg Ryan'}]\u001b[0m\n", + "\u001b[32;1m\u001b[1;3m[{'a.name': 'Tom Cruise'}, {'a.name': 'Val Kilmer'}]\u001b[0m\n", "\n", "\u001b[1m> Finished chain.\u001b[0m\n" ] @@ -295,7 +295,7 @@ "data": { "text/plain": [ "{'query': 'Who played in Top Gun?',\n", - " 'result': 'Anthony Edwards, Meg Ryan played in Top Gun.'}" + " 'result': 'Tom Cruise, Val Kilmer played in Top Gun.'}" ] }, "execution_count": 10, @@ -346,11 +346,11 @@ "WHERE m.name = 'Top Gun'\n", "RETURN a.name\u001b[0m\n", "Full Context:\n", - "\u001b[32;1m\u001b[1;3m[{'a.name': 'Anthony Edwards'}, {'a.name': 'Meg Ryan'}, {'a.name': 'Val Kilmer'}, {'a.name': 'Tom Cruise'}]\u001b[0m\n", + "\u001b[32;1m\u001b[1;3m[{'a.name': 'Tom Cruise'}, {'a.name': 'Val Kilmer'}, {'a.name': 'Anthony Edwards'}, {'a.name': 'Meg Ryan'}]\u001b[0m\n", "\n", "\u001b[1m> Finished chain.\u001b[0m\n", - "Intermediate steps: [{'query': \"MATCH (a:Actor)-[:ACTED_IN]->(m:Movie)\\nWHERE m.name = 'Top Gun'\\nRETURN a.name\"}, {'context': [{'a.name': 'Anthony Edwards'}, {'a.name': 'Meg Ryan'}, {'a.name': 'Val Kilmer'}, {'a.name': 'Tom Cruise'}]}]\n", - "Final answer: Anthony Edwards, Meg Ryan, Val Kilmer, Tom Cruise played in Top Gun.\n" + "Intermediate steps: [{'query': \"MATCH (a:Actor)-[:ACTED_IN]->(m:Movie)\\nWHERE m.name = 'Top Gun'\\nRETURN a.name\"}, {'context': [{'a.name': 'Tom Cruise'}, {'a.name': 'Val Kilmer'}, {'a.name': 'Anthony Edwards'}, {'a.name': 'Meg Ryan'}]}]\n", + "Final answer: Tom 
Cruise, Val Kilmer, Anthony Edwards, and Meg Ryan played in Top Gun.\n" ] } ], @@ -406,10 +406,10 @@ "data": { "text/plain": [ "{'query': 'Who played in Top Gun?',\n", - " 'result': [{'a.name': 'Anthony Edwards'},\n", - " {'a.name': 'Meg Ryan'},\n", + " 'result': [{'a.name': 'Tom Cruise'},\n", " {'a.name': 'Val Kilmer'},\n", - " {'a.name': 'Tom Cruise'}]}" + " {'a.name': 'Anthony Edwards'},\n", + " {'a.name': 'Meg Ryan'}]}" ] }, "execution_count": 14, @@ -482,7 +482,7 @@ "\n", "\u001b[1m> Entering new GraphCypherQAChain chain...\u001b[0m\n", "Generated Cypher:\n", - "\u001b[32;1m\u001b[1;3mMATCH (:Movie {name:\"Top Gun\"})<-[:ACTED_IN]-()\n", + "\u001b[32;1m\u001b[1;3mMATCH (m:Movie {name:\"Top Gun\"})<-[:ACTED_IN]-()\n", "RETURN count(*) AS numberOfActors\u001b[0m\n", "Full Context:\n", "\u001b[32;1m\u001b[1;3m[{'numberOfActors': 4}]\u001b[0m\n", @@ -494,7 +494,7 @@ "data": { "text/plain": [ "{'query': 'How many people played in Top Gun?',\n", - " 'result': 'There were 4 actors who played in Top Gun.'}" + " 'result': 'There were 4 actors in Top Gun.'}" ] }, "execution_count": 16, @@ -548,7 +548,7 @@ "WHERE m.name = 'Top Gun'\n", "RETURN a.name\u001b[0m\n", "Full Context:\n", - "\u001b[32;1m\u001b[1;3m[{'a.name': 'Anthony Edwards'}, {'a.name': 'Meg Ryan'}, {'a.name': 'Val Kilmer'}, {'a.name': 'Tom Cruise'}]\u001b[0m\n", + "\u001b[32;1m\u001b[1;3m[{'a.name': 'Tom Cruise'}, {'a.name': 'Val Kilmer'}, {'a.name': 'Anthony Edwards'}, {'a.name': 'Meg Ryan'}]\u001b[0m\n", "\n", "\u001b[1m> Finished chain.\u001b[0m\n" ] @@ -557,7 +557,7 @@ "data": { "text/plain": [ "{'query': 'Who played in Top Gun?',\n", - " 'result': 'Anthony Edwards, Meg Ryan, Val Kilmer, and Tom Cruise played in Top Gun.'}" + " 'result': 'Tom Cruise, Val Kilmer, Anthony Edwards, and Meg Ryan played in Top Gun.'}" ] }, "execution_count": 18, @@ -661,7 +661,7 @@ "WHERE m.name = 'Top Gun'\n", "RETURN a.name\u001b[0m\n", "Full Context:\n", - "\u001b[32;1m\u001b[1;3m[{'a.name': 'Anthony Edwards'}, {'a.name': 
'Meg Ryan'}, {'a.name': 'Val Kilmer'}, {'a.name': 'Tom Cruise'}]\u001b[0m\n", + "\u001b[32;1m\u001b[1;3m[{'a.name': 'Tom Cruise'}, {'a.name': 'Val Kilmer'}, {'a.name': 'Anthony Edwards'}, {'a.name': 'Meg Ryan'}]\u001b[0m\n", "\n", "\u001b[1m> Finished chain.\u001b[0m\n" ] @@ -670,7 +670,7 @@ "data": { "text/plain": [ "{'query': 'Who played in Top Gun?',\n", - " 'result': 'Anthony Edwards, Meg Ryan, Val Kilmer, Tom Cruise played in Top Gun.'}" + " 'result': 'Tom Cruise, Val Kilmer, Anthony Edwards, and Meg Ryan played in Top Gun.'}" ] }, "execution_count": 22, @@ -682,13 +682,117 @@ "chain.invoke({\"query\": \"Who played in Top Gun?\"})" ] }, + { + "cell_type": "markdown", + "id": "81093062-eb7f-4d96-b1fd-c36b8f1b9474", + "metadata": {}, + "source": [ + "## Provide context from database results as tool/function output\n", + "\n", + "You can use the `use_function_response` parameter to pass context from database results to an LLM as a tool/function output. This method improves the response accuracy and relevance of an answer as the LLM follows the provided context more closely.\n", + "_You will need to use an LLM with native function calling support to use this feature_." 
+ ] + }, { "cell_type": "code", - "execution_count": null, - "id": "3fa3f3d5-f7e7-4ca9-8f07-ca22b897f192", + "execution_count": 23, + "id": "2be8f51c-e80a-4a60-ab1c-266450fc17cd", "metadata": {}, - "outputs": [], - "source": [] + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new GraphCypherQAChain chain...\u001b[0m\n", + "Generated Cypher:\n", + "\u001b[32;1m\u001b[1;3mMATCH (a:Actor)-[:ACTED_IN]->(m:Movie)\n", + "WHERE m.name = 'Top Gun'\n", + "RETURN a.name\u001b[0m\n", + "Full Context:\n", + "\u001b[32;1m\u001b[1;3m[{'a.name': 'Tom Cruise'}, {'a.name': 'Val Kilmer'}, {'a.name': 'Anthony Edwards'}, {'a.name': 'Meg Ryan'}]\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "{'query': 'Who played in Top Gun?',\n", + " 'result': 'The main actors in Top Gun are Tom Cruise, Val Kilmer, Anthony Edwards, and Meg Ryan.'}" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "chain = GraphCypherQAChain.from_llm(\n", + " llm=ChatOpenAI(temperature=0, model=\"gpt-3.5-turbo\"),\n", + " graph=graph,\n", + " verbose=True,\n", + " use_function_response=True,\n", + ")\n", + "chain.invoke({\"query\": \"Who played in Top Gun?\"})" + ] + }, + { + "cell_type": "markdown", + "id": "48a75785-5bc9-49a7-a41b-88bf3ef9d312", + "metadata": {}, + "source": [ + "You can provide a custom system message when using the function response feature by passing `function_response_system` to instruct the model on how to generate answers.\n", + "\n", + "_Note that `qa_prompt` will have no effect when using `use_function_response`_" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "ddf0a61e-f104-4dbb-abbf-e65f3f57dd9a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new GraphCypherQAChain chain...\u001b[0m\n", + 
"Generated Cypher:\n", + "\u001b[32;1m\u001b[1;3mMATCH (a:Actor)-[:ACTED_IN]->(m:Movie)\n", + "WHERE m.name = 'Top Gun'\n", + "RETURN a.name\u001b[0m\n", + "Full Context:\n", + "\u001b[32;1m\u001b[1;3m[{'a.name': 'Tom Cruise'}, {'a.name': 'Val Kilmer'}, {'a.name': 'Anthony Edwards'}, {'a.name': 'Meg Ryan'}]\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "{'query': 'Who played in Top Gun?',\n", + " 'result': \"Arrr matey! In the film Top Gun, ye be seein' Tom Cruise, Val Kilmer, Anthony Edwards, and Meg Ryan sailin' the high seas of the sky! Aye, they be a fine crew of actors, they be!\"}" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "chain = GraphCypherQAChain.from_llm(\n", + " llm=ChatOpenAI(temperature=0, model=\"gpt-3.5-turbo\"),\n", + " graph=graph,\n", + " verbose=True,\n", + " use_function_response=True,\n", + " function_response_system=\"Respond as a pirate!\",\n", + ")\n", + "chain.invoke({\"query\": \"Who played in Top Gun?\"})" + ] } ], "metadata": { diff --git a/libs/community/langchain_community/chains/graph_qa/cypher.py b/libs/community/langchain_community/chains/graph_qa/cypher.py index e43f24037b..6df5108182 100644 --- a/libs/community/langchain_community/chains/graph_qa/cypher.py +++ b/libs/community/langchain_community/chains/graph_qa/cypher.py @@ -2,14 +2,27 @@ from __future__ import annotations import re -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain_core.callbacks import CallbackManagerForChainRun from langchain_core.language_models import BaseLanguageModel -from langchain_core.prompts import BasePromptTemplate +from langchain_core.messages import ( + AIMessage, + BaseMessage, + SystemMessage, + ToolMessage, +) +from langchain_core.output_parsers import StrOutputParser +from 
langchain_core.prompts import ( + BasePromptTemplate, + ChatPromptTemplate, + HumanMessagePromptTemplate, + MessagesPlaceholder, +) from langchain_core.pydantic_v1 import Field +from langchain_core.runnables import Runnable from langchain_community.chains.graph_qa.cypher_utils import ( CypherQueryCorrector, @@ -23,6 +36,12 @@ from langchain_community.graphs.graph_store import GraphStore INTERMEDIATE_STEPS_KEY = "intermediate_steps" +FUNCTION_RESPONSE_SYSTEM = """You are an assistant that helps to form nice and human +understandable answers based on the provided information from tools. +Do not add any other information that wasn't present in the tools, and use +very concise style in interpreting results! +""" + def extract_cypher(text: str) -> str: """Extract Cypher code from a text. @@ -104,6 +123,31 @@ def construct_schema( ) +def get_function_response( + question: str, context: List[Dict[str, Any]] +) -> List[BaseMessage]: + TOOL_ID = "call_H7fABDuzEau48T10Qn0Lsh0D" + messages = [ + AIMessage( + content="", + additional_kwargs={ + "tool_calls": [ + { + "id": TOOL_ID, + "function": { + "arguments": '{"question":"' + question + '"}', + "name": "GetInformation", + }, + "type": "function", + } + ] + }, + ), + ToolMessage(content=str(context), tool_call_id=TOOL_ID), + ] + return messages + + class GraphCypherQAChain(Chain): """Chain for question-answering against a graph by generating Cypher statements. 
@@ -121,7 +165,7 @@ class GraphCypherQAChain(Chain): graph: GraphStore = Field(exclude=True) cypher_generation_chain: LLMChain - qa_chain: LLMChain + qa_chain: Union[LLMChain, Runnable] graph_schema: str input_key: str = "query" #: :meta private: output_key: str = "result" #: :meta private: @@ -133,6 +177,8 @@ class GraphCypherQAChain(Chain): """Whether or not to return the result of querying the graph directly.""" cypher_query_corrector: Optional[CypherQueryCorrector] = None """Optional cypher validation tool""" + use_function_response: bool = False + """Whether to wrap the database context as tool/function response""" @property def input_keys(self) -> List[str]: @@ -163,12 +209,14 @@ class GraphCypherQAChain(Chain): qa_prompt: Optional[BasePromptTemplate] = None, cypher_prompt: Optional[BasePromptTemplate] = None, cypher_llm: Optional[BaseLanguageModel] = None, - qa_llm: Optional[BaseLanguageModel] = None, + qa_llm: Optional[Union[BaseLanguageModel, Any]] = None, exclude_types: List[str] = [], include_types: List[str] = [], validate_cypher: bool = False, qa_llm_kwargs: Optional[Dict[str, Any]] = None, cypher_llm_kwargs: Optional[Dict[str, Any]] = None, + use_function_response: bool = False, + function_response_system: str = FUNCTION_RESPONSE_SYSTEM, **kwargs: Any, ) -> GraphCypherQAChain: """Initialize from LLM.""" @@ -205,7 +253,22 @@ class GraphCypherQAChain(Chain): cypher_prompt if cypher_prompt is not None else CYPHER_GENERATION_PROMPT ) - qa_chain = LLMChain(llm=qa_llm or llm, **use_qa_llm_kwargs) # type: ignore[arg-type] + qa_llm = qa_llm or llm + if use_function_response: + try: + qa_llm.bind_tools({}) # type: ignore[union-attr] + response_prompt = ChatPromptTemplate.from_messages( + [ + SystemMessage(content=function_response_system), + HumanMessagePromptTemplate.from_template("{question}"), + MessagesPlaceholder(variable_name="function_response"), + ] + ) + qa_chain = response_prompt | qa_llm | StrOutputParser() # type: ignore + except 
(NotImplementedError, AttributeError): + raise ValueError("Provided LLM does not support native tools/functions") + else: + qa_chain = LLMChain(llm=qa_llm, **use_qa_llm_kwargs) # type: ignore[arg-type] cypher_generation_chain = LLMChain( llm=cypher_llm or llm, # type: ignore[arg-type] @@ -217,7 +280,6 @@ class GraphCypherQAChain(Chain): "Either `exclude_types` or `include_types` " "can be provided, but not both" ) - graph_schema = construct_schema( kwargs["graph"].get_structured_schema, include_types, exclude_types ) @@ -235,6 +297,7 @@ class GraphCypherQAChain(Chain): qa_chain=qa_chain, cypher_generation_chain=cypher_generation_chain, cypher_query_corrector=cypher_query_corrector, + use_function_response=use_function_response, **kwargs, ) @@ -284,12 +347,17 @@ class GraphCypherQAChain(Chain): ) intermediate_steps.append({"context": context}) - - result = self.qa_chain( - {"question": question, "context": context}, - callbacks=callbacks, - ) - final_result = result[self.qa_chain.output_key] + if self.use_function_response: + function_response = get_function_response(question, context) + final_result = self.qa_chain.invoke( # type: ignore + {"question": question, "function_response": function_response}, + ) + else: + result = self.qa_chain.invoke( # type: ignore + {"question": question, "context": context}, + callbacks=callbacks, + ) + final_result = result[self.qa_chain.output_key] # type: ignore chain_result: Dict[str, Any] = {self.output_key: final_result} if self.return_intermediate_steps: diff --git a/libs/community/tests/unit_tests/chains/test_graph_qa.py b/libs/community/tests/unit_tests/chains/test_graph_qa.py index 51c41bffec..d654a96eb7 100644 --- a/libs/community/tests/unit_tests/chains/test_graph_qa.py +++ b/libs/community/tests/unit_tests/chains/test_graph_qa.py @@ -60,7 +60,7 @@ def test_graph_cypher_qa_chain_prompt_selection_1() -> None: qa_prompt=qa_prompt, cypher_prompt=cypher_prompt, ) - assert chain.qa_chain.prompt == qa_prompt + assert 
chain.qa_chain.prompt == qa_prompt # type: ignore[union-attr] assert chain.cypher_generation_chain.prompt == cypher_prompt @@ -72,7 +72,7 @@ def test_graph_cypher_qa_chain_prompt_selection_2() -> None: verbose=True, return_intermediate_steps=False, ) - assert chain.qa_chain.prompt == CYPHER_QA_PROMPT + assert chain.qa_chain.prompt == CYPHER_QA_PROMPT # type: ignore[union-attr] assert chain.cypher_generation_chain.prompt == CYPHER_GENERATION_PROMPT @@ -88,7 +88,7 @@ def test_graph_cypher_qa_chain_prompt_selection_3() -> None: cypher_llm_kwargs={"memory": readonlymemory}, qa_llm_kwargs={"memory": readonlymemory}, ) - assert chain.qa_chain.prompt == CYPHER_QA_PROMPT + assert chain.qa_chain.prompt == CYPHER_QA_PROMPT # type: ignore[union-attr] assert chain.cypher_generation_chain.prompt == CYPHER_GENERATION_PROMPT @@ -108,7 +108,7 @@ def test_graph_cypher_qa_chain_prompt_selection_4() -> None: cypher_llm_kwargs={"prompt": cypher_prompt, "memory": readonlymemory}, qa_llm_kwargs={"prompt": qa_prompt, "memory": readonlymemory}, ) - assert chain.qa_chain.prompt == qa_prompt + assert chain.qa_chain.prompt == qa_prompt # type: ignore[union-attr] assert chain.cypher_generation_chain.prompt == cypher_prompt