langchain[patch]: return formatted SPARQL query on demand (#11263)

- **Description:** Added the `return_sparql_query` feature to the
`GraphSparqlQAChain` class, allowing users to get the formatted SPARQL
query along with the chain's result.
  - **Issue:** NA
  - **Dependencies:** None

Note: I've ensured that the PR passes linting and testing by running
make format, make lint, and make test locally.

I have added a test for the integration (which relies on network access)
and I have added an example to the notebook showing its use.
This commit is contained in:
Reid Falconer 2024-02-23 02:03:26 +01:00 committed by GitHub
parent b15fccbb99
commit 0534ba5a7d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 137 additions and 11 deletions

View File

@ -21,7 +21,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 1, "execution_count": 3,
"id": "62812aad", "id": "62812aad",
"metadata": { "metadata": {
"pycharm": { "pycharm": {
@ -37,7 +37,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 8, "execution_count": 4,
"id": "0928915d", "id": "0928915d",
"metadata": { "metadata": {
"pycharm": { "pycharm": {
@ -74,7 +74,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 9, "execution_count": 5,
"id": "4e3de44f", "id": "4e3de44f",
"metadata": { "metadata": {
"pycharm": { "pycharm": {
@ -88,7 +88,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 10, "execution_count": 6,
"id": "1fe76ccd", "id": "1fe76ccd",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -121,7 +121,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 11, "execution_count": null,
"id": "7476ce98", "id": "7476ce98",
"metadata": { "metadata": {
"pycharm": { "pycharm": {
@ -250,7 +250,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 15, "execution_count": 27,
"id": "f874171b", "id": "f874171b",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -277,13 +277,99 @@
")\n", ")\n",
"graph.query(query)" "graph.query(query)"
] ]
},
{
"cell_type": "markdown",
"id": "eb00a625-a6c9-4766-b3f0-eaed024851c9",
"metadata": {},
"source": [
"## Return SPARQL query\n",
"You can return the generated SPARQL query from the SPARQL QA chain by setting the `return_sparql_query` parameter."
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "f13e2865-176a-4417-95e6-db818b214d08",
"metadata": {},
"outputs": [],
"source": [
"chain = GraphSparqlQAChain.from_llm(\n",
" ChatOpenAI(temperature=0), graph=graph, verbose=True, return_sparql_query=True\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "4f4d47b6-4202-4e74-8c88-aeaac5344c04",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new GraphSparqlQAChain chain...\u001b[0m\n",
"Identified intent:\n",
"\u001b[32;1m\u001b[1;3mSELECT\u001b[0m\n",
"Generated SPARQL:\n",
"\u001b[32;1m\u001b[1;3mPREFIX foaf: <http://xmlns.com/foaf/0.1/>\n",
"SELECT ?workHomepage\n",
"WHERE {\n",
" ?person foaf:name \"Tim Berners-Lee\" .\n",
" ?person foaf:workplaceHomepage ?workHomepage .\n",
"}\u001b[0m\n",
"Full Context:\n",
"\u001b[32;1m\u001b[1;3m[]\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n",
"SPARQL query: PREFIX foaf: <http://xmlns.com/foaf/0.1/>\n",
"SELECT ?workHomepage\n",
"WHERE {\n",
" ?person foaf:name \"Tim Berners-Lee\" .\n",
" ?person foaf:workplaceHomepage ?workHomepage .\n",
"}\n",
"Final answer: Tim Berners-Lee's work homepage is http://www.w3.org/People/Berners-Lee/.\n"
]
}
],
"source": [
"result = chain(\"What is Tim Berners-Lee's work homepage?\")\n",
"print(f\"SPARQL query: {result['sparql_query']}\")\n",
"print(f\"Final answer: {result['result']}\")"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "be3d9ff7-dc00-47d0-857d-fd40437a3f22",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"PREFIX foaf: <http://xmlns.com/foaf/0.1/>\n",
"SELECT ?workHomepage\n",
"WHERE {\n",
" ?person foaf:name \"Tim Berners-Lee\" .\n",
" ?person foaf:workplaceHomepage ?workHomepage .\n",
"}\n"
]
}
],
"source": [
"print(result[\"sparql_query\"])"
]
} }
], ],
"metadata": { "metadata": {
"kernelspec": { "kernelspec": {
"display_name": "lc", "display_name": "Python 3 (ipykernel)",
"language": "python", "language": "python",
"name": "lc" "name": "python3"
}, },
"language_info": { "language_info": {
"codemirror_mode": { "codemirror_mode": {
@ -295,9 +381,9 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.11.4" "version": "3.10.9"
} }
}, },
"nbformat": 4, "nbformat": 4,
"nbformat_minor": 5 "nbformat_minor": 5
} }

View File

@ -41,15 +41,25 @@ class GraphSparqlQAChain(Chain):
sparql_generation_update_chain: LLMChain sparql_generation_update_chain: LLMChain
sparql_intent_chain: LLMChain sparql_intent_chain: LLMChain
qa_chain: LLMChain qa_chain: LLMChain
return_sparql_query: bool = False
input_key: str = "query" #: :meta private: input_key: str = "query" #: :meta private:
output_key: str = "result" #: :meta private: output_key: str = "result" #: :meta private:
sparql_query_key: str = "sparql_query" #: :meta private:
@property @property
def input_keys(self) -> List[str]: def input_keys(self) -> List[str]:
"""Return the input keys.
:meta private:
"""
return [self.input_key] return [self.input_key]
@property @property
def output_keys(self) -> List[str]: def output_keys(self) -> List[str]:
"""Return the output keys.
:meta private:
"""
_output_keys = [self.output_key] _output_keys = [self.output_key]
return _output_keys return _output_keys
@ -135,4 +145,8 @@ class GraphSparqlQAChain(Chain):
res = "Successfully inserted triples into the graph." res = "Successfully inserted triples into the graph."
else: else:
raise ValueError("Unsupported SPARQL query type.") raise ValueError("Unsupported SPARQL query type.")
return {self.output_key: res}
chain_result: Dict[str, Any] = {self.output_key: res}
if self.return_sparql_query:
chain_result[self.sparql_query_key] = generated_sparql
return chain_result

View File

@ -78,3 +78,29 @@ def test_sparql_insert() -> None:
os.remove(_local_copy) os.remove(_local_copy)
except OSError: except OSError:
pass pass
def test_sparql_select_return_query() -> None:
    """Check that a SPARQL SELECT run also returns the generated query.

    The chain is built with ``return_sparql_query=True``, so its output is
    expected to carry a ``"sparql_query"`` entry alongside the answer.
    """
    card_url = "http://www.w3.org/People/Berners-Lee/card"
    rdf_graph = RdfGraph(source_file=card_url, standard="rdf")
    llm = OpenAI(temperature=0)
    qa_chain = GraphSparqlQAChain.from_llm(
        llm, graph=rdf_graph, return_sparql_query=True
    )
    response = qa_chain("What is Tim Berners-Lee's work homepage?")

    # The answer text is compared exactly, including its leading space.
    assert response["result"] == (
        " The work homepage of Tim Berners-Lee is "
        "http://www.w3.org/People/Berners-Lee/."
    )
    # The generated query must be present under its dedicated key.
    assert "sparql_query" in response