diff --git a/docs/docs/use_cases/graph/graph_sparql_qa.ipynb b/docs/docs/use_cases/graph/graph_sparql_qa.ipynb index b7541f4388..3d41a68a84 100644 --- a/docs/docs/use_cases/graph/graph_sparql_qa.ipynb +++ b/docs/docs/use_cases/graph/graph_sparql_qa.ipynb @@ -21,7 +21,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 3, "id": "62812aad", "metadata": { "pycharm": { @@ -37,7 +37,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 4, "id": "0928915d", "metadata": { "pycharm": { @@ -74,7 +74,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 5, "id": "4e3de44f", "metadata": { "pycharm": { @@ -88,7 +88,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 6, "id": "1fe76ccd", "metadata": {}, "outputs": [ @@ -121,7 +121,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "id": "7476ce98", "metadata": { "pycharm": { @@ -250,7 +250,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 27, "id": "f874171b", "metadata": {}, "outputs": [ @@ -277,13 +277,99 @@ ")\n", "graph.query(query)" ] + }, + { + "cell_type": "markdown", + "id": "eb00a625-a6c9-4766-b3f0-eaed024851c9", + "metadata": {}, + "source": [ + "## Return SQARQL query\n", + "You can return the SPARQL query step from the Sparql QA Chain using the `return_sparql_query` parameter" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "f13e2865-176a-4417-95e6-db818b214d08", + "metadata": {}, + "outputs": [], + "source": [ + "chain = GraphSparqlQAChain.from_llm(\n", + " ChatOpenAI(temperature=0), graph=graph, verbose=True, return_sparql_query=True\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "4f4d47b6-4202-4e74-8c88-aeaac5344c04", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new GraphSparqlQAChain chain...\u001b[0m\n", + "Identified intent:\n", + "\u001b[32;1m\u001b[1;3mSELECT\u001b[0m\n", + "Generated SPARQL:\n", + "\u001b[32;1m\u001b[1;3mPREFIX foaf: \n", + "SELECT ?workHomepage\n", + "WHERE {\n", + " ?person foaf:name \"Tim Berners-Lee\" .\n", + " ?person foaf:workplaceHomepage ?workHomepage .\n", + "}\u001b[0m\n", + "Full Context:\n", + "\u001b[32;1m\u001b[1;3m[]\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n", + "SQARQL query: PREFIX foaf: \n", + "SELECT ?workHomepage\n", + "WHERE {\n", + " ?person foaf:name \"Tim Berners-Lee\" .\n", + " ?person foaf:workplaceHomepage ?workHomepage .\n", + "}\n", + "Final answer: Tim Berners-Lee's work homepage is http://www.w3.org/People/Berners-Lee/.\n" + ] + } + ], + "source": [ + "result = chain(\"What is Tim Berners-Lee's work homepage?\")\n", + "print(f\"SQARQL query: {result['sparql_query']}\")\n", + "print(f\"Final answer: {result['result']}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "id": "be3d9ff7-dc00-47d0-857d-fd40437a3f22", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "PREFIX foaf: \n", + "SELECT ?workHomepage\n", + "WHERE {\n", + " ?person foaf:name \"Tim Berners-Lee\" .\n", + " ?person foaf:workplaceHomepage ?workHomepage .\n", + "}\n" + ] + } + ], + "source": [ + "print(result[\"sparql_query\"])" + ] } ], "metadata": { "kernelspec": { - "display_name": "lc", + "display_name": "Python 3 (ipykernel)", "language": "python", - "name": "lc" + "name": "python3" }, "language_info": { "codemirror_mode": { @@ -295,9 +381,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.4" + "version": "3.10.9" } }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} diff --git a/libs/langchain/langchain/chains/graph_qa/sparql.py b/libs/langchain/langchain/chains/graph_qa/sparql.py index f9d4897090..f1c5d2fc81 100644 --- a/libs/langchain/langchain/chains/graph_qa/sparql.py +++ b/libs/langchain/langchain/chains/graph_qa/sparql.py @@ -41,15 +41,25 @@ class GraphSparqlQAChain(Chain): sparql_generation_update_chain: LLMChain sparql_intent_chain: LLMChain qa_chain: LLMChain + return_sparql_query: bool = False input_key: str = "query" #: :meta private: output_key: str = "result" #: :meta private: + sparql_query_key: str = "sparql_query" #: :meta private: @property def input_keys(self) -> List[str]: + """Return the input keys. + + :meta private: + """ return [self.input_key] @property def output_keys(self) -> List[str]: + """Return the output keys. + + :meta private: + """ _output_keys = [self.output_key] return _output_keys @@ -135,4 +145,8 @@ class GraphSparqlQAChain(Chain): res = "Successfully inserted triples into the graph." else: raise ValueError("Unsupported SPARQL query type.") - return {self.output_key: res} + + chain_result: Dict[str, Any] = {self.output_key: res} + if self.return_sparql_query: + chain_result[self.sparql_query_key] = generated_sparql + return chain_result diff --git a/libs/langchain/tests/integration_tests/chains/test_graph_database_sparql.py b/libs/langchain/tests/integration_tests/chains/test_graph_database_sparql.py index 5a0fbde3cf..844e0ce86f 100644 --- a/libs/langchain/tests/integration_tests/chains/test_graph_database_sparql.py +++ b/libs/langchain/tests/integration_tests/chains/test_graph_database_sparql.py @@ -78,3 +78,29 @@ def test_sparql_insert() -> None: os.remove(_local_copy) except OSError: pass + + +def test_sparql_select_return_query() -> None: + """ + Test for generating and executing simple SPARQL SELECT query + and returning the generated SPARQL query. + """ + berners_lee_card = "http://www.w3.org/People/Berners-Lee/card" + + graph = RdfGraph( + source_file=berners_lee_card, + standard="rdf", + ) + + chain = GraphSparqlQAChain.from_llm( + OpenAI(temperature=0), graph=graph, return_sparql_query=True + ) + output = chain("What is Tim Berners-Lee's work homepage?") + + # Verify the expected answer + expected_output = ( + " The work homepage of Tim Berners-Lee is " + "http://www.w3.org/People/Berners-Lee/." + ) + assert output["result"] == expected_output + assert "sparql_query" in output