diff --git a/docs/docs/use_cases/graph/graph_sparql_qa.ipynb b/docs/docs/use_cases/graph/graph_sparql_qa.ipynb
index b7541f4388..3d41a68a84 100644
--- a/docs/docs/use_cases/graph/graph_sparql_qa.ipynb
+++ b/docs/docs/use_cases/graph/graph_sparql_qa.ipynb
@@ -21,7 +21,7 @@
},
{
"cell_type": "code",
- "execution_count": 1,
+ "execution_count": 3,
"id": "62812aad",
"metadata": {
"pycharm": {
@@ -37,7 +37,7 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": 4,
"id": "0928915d",
"metadata": {
"pycharm": {
@@ -74,7 +74,7 @@
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": 5,
"id": "4e3de44f",
"metadata": {
"pycharm": {
@@ -88,7 +88,7 @@
},
{
"cell_type": "code",
- "execution_count": 10,
+ "execution_count": 6,
"id": "1fe76ccd",
"metadata": {},
"outputs": [
@@ -121,7 +121,7 @@
},
{
"cell_type": "code",
- "execution_count": 11,
+ "execution_count": null,
"id": "7476ce98",
"metadata": {
"pycharm": {
@@ -250,7 +250,7 @@
},
{
"cell_type": "code",
- "execution_count": 15,
+ "execution_count": 27,
"id": "f874171b",
"metadata": {},
"outputs": [
@@ -277,13 +277,99 @@
")\n",
"graph.query(query)"
]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "eb00a625-a6c9-4766-b3f0-eaed024851c9",
+ "metadata": {},
+ "source": [
+ "## Return SQARQL query\n",
+ "You can return the SPARQL query step from the Sparql QA Chain using the `return_sparql_query` parameter"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "id": "f13e2865-176a-4417-95e6-db818b214d08",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "chain = GraphSparqlQAChain.from_llm(\n",
+ " ChatOpenAI(temperature=0), graph=graph, verbose=True, return_sparql_query=True\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 29,
+ "id": "4f4d47b6-4202-4e74-8c88-aeaac5344c04",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "\n",
+ "\u001b[1m> Entering new GraphSparqlQAChain chain...\u001b[0m\n",
+ "Identified intent:\n",
+ "\u001b[32;1m\u001b[1;3mSELECT\u001b[0m\n",
+ "Generated SPARQL:\n",
+ "\u001b[32;1m\u001b[1;3mPREFIX foaf: \n",
+ "SELECT ?workHomepage\n",
+ "WHERE {\n",
+ " ?person foaf:name \"Tim Berners-Lee\" .\n",
+ " ?person foaf:workplaceHomepage ?workHomepage .\n",
+ "}\u001b[0m\n",
+ "Full Context:\n",
+ "\u001b[32;1m\u001b[1;3m[]\u001b[0m\n",
+ "\n",
+ "\u001b[1m> Finished chain.\u001b[0m\n",
+ "SQARQL query: PREFIX foaf: \n",
+ "SELECT ?workHomepage\n",
+ "WHERE {\n",
+ " ?person foaf:name \"Tim Berners-Lee\" .\n",
+ " ?person foaf:workplaceHomepage ?workHomepage .\n",
+ "}\n",
+ "Final answer: Tim Berners-Lee's work homepage is http://www.w3.org/People/Berners-Lee/.\n"
+ ]
+ }
+ ],
+ "source": [
+ "result = chain(\"What is Tim Berners-Lee's work homepage?\")\n",
+ "print(f\"SQARQL query: {result['sparql_query']}\")\n",
+ "print(f\"Final answer: {result['result']}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 30,
+ "id": "be3d9ff7-dc00-47d0-857d-fd40437a3f22",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "PREFIX foaf: \n",
+ "SELECT ?workHomepage\n",
+ "WHERE {\n",
+ " ?person foaf:name \"Tim Berners-Lee\" .\n",
+ " ?person foaf:workplaceHomepage ?workHomepage .\n",
+ "}\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(result[\"sparql_query\"])"
+ ]
}
],
"metadata": {
"kernelspec": {
- "display_name": "lc",
+ "display_name": "Python 3 (ipykernel)",
"language": "python",
- "name": "lc"
+ "name": "python3"
},
"language_info": {
"codemirror_mode": {
@@ -295,9 +381,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.11.4"
+ "version": "3.10.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
-}
\ No newline at end of file
+}
diff --git a/libs/langchain/langchain/chains/graph_qa/sparql.py b/libs/langchain/langchain/chains/graph_qa/sparql.py
index f9d4897090..f1c5d2fc81 100644
--- a/libs/langchain/langchain/chains/graph_qa/sparql.py
+++ b/libs/langchain/langchain/chains/graph_qa/sparql.py
@@ -41,15 +41,25 @@ class GraphSparqlQAChain(Chain):
sparql_generation_update_chain: LLMChain
sparql_intent_chain: LLMChain
qa_chain: LLMChain
+ return_sparql_query: bool = False
input_key: str = "query" #: :meta private:
output_key: str = "result" #: :meta private:
+ sparql_query_key: str = "sparql_query" #: :meta private:
@property
def input_keys(self) -> List[str]:
+ """Return the input keys.
+
+ :meta private:
+ """
return [self.input_key]
@property
def output_keys(self) -> List[str]:
+ """Return the output keys.
+
+ :meta private:
+ """
_output_keys = [self.output_key]
return _output_keys
@@ -135,4 +145,8 @@ class GraphSparqlQAChain(Chain):
res = "Successfully inserted triples into the graph."
else:
raise ValueError("Unsupported SPARQL query type.")
- return {self.output_key: res}
+
+ chain_result: Dict[str, Any] = {self.output_key: res}
+ if self.return_sparql_query:
+ chain_result[self.sparql_query_key] = generated_sparql
+ return chain_result
diff --git a/libs/langchain/tests/integration_tests/chains/test_graph_database_sparql.py b/libs/langchain/tests/integration_tests/chains/test_graph_database_sparql.py
index 5a0fbde3cf..844e0ce86f 100644
--- a/libs/langchain/tests/integration_tests/chains/test_graph_database_sparql.py
+++ b/libs/langchain/tests/integration_tests/chains/test_graph_database_sparql.py
@@ -78,3 +78,29 @@ def test_sparql_insert() -> None:
os.remove(_local_copy)
except OSError:
pass
+
+
+def test_sparql_select_return_query() -> None:
+ """
+ Test for generating and executing simple SPARQL SELECT query
+ and returning the generated SPARQL query.
+ """
+ berners_lee_card = "http://www.w3.org/People/Berners-Lee/card"
+
+ graph = RdfGraph(
+ source_file=berners_lee_card,
+ standard="rdf",
+ )
+
+ chain = GraphSparqlQAChain.from_llm(
+ OpenAI(temperature=0), graph=graph, return_sparql_query=True
+ )
+ output = chain("What is Tim Berners-Lee's work homepage?")
+
+ # Verify the expected answer
+ expected_output = (
+ " The work homepage of Tim Berners-Lee is "
+ "http://www.w3.org/People/Berners-Lee/."
+ )
+ assert output["result"] == expected_output
+ assert "sparql_query" in output