langchain[patch], community[patch]: Fixes in the Ontotext GraphDB Graph and QA Chain (#17239)

- **Description:** Fixes in the Ontotext GraphDB Graph and QA Chain
related to the error handling in case of invalid SPARQL queries, for
which `prepareQuery` doesn't throw an exception, but the server returns
400 and the query is indeed invalid
  - **Issue:** N/A
  - **Dependencies:** N/A
  - **Twitter handle:** @OntotextGraphDB
pull/17252/head^2
Neli Hateva 4 months ago committed by GitHub
parent b88329e9a5
commit 9bb5157a3d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -69,7 +69,7 @@
"pip install openai==1.6.1\n",
"pip install rdflib==7.0.0\n",
"pip install langchain-openai==0.0.2\n",
"pip install langchain\n",
"pip install langchain>=0.1.5\n",
"```\n",
"\n",
"Run Jupyter with\n",

@ -204,11 +204,7 @@ class OntotextGraphDBGraph:
"""
Query the graph.
"""
from rdflib.exceptions import ParserError
from rdflib.query import ResultRow
try:
res = self.graph.query(query)
except ParserError as e:
raise ValueError(f"Generated SPARQL statement is invalid\n{e}")
res = self.graph.query(query)
return [r for r in res if isinstance(r, ResultRow)]

@ -10,7 +10,7 @@ cd libs/community/tests/integration_tests/graphs/docker-compose-ontotext-graphdb
"""
def test_query() -> None:
def test_query_method_with_valid_query() -> None:
graph = OntotextGraphDBGraph(
query_endpoint="http://localhost:7200/repositories/langchain",
query_ontology="CONSTRUCT {?s ?p ?o}"
@ -31,6 +31,36 @@ def test_query() -> None:
assert str(query_results[0][0]) == "yellow"
def test_query_method_with_invalid_query() -> None:
graph = OntotextGraphDBGraph(
query_endpoint="http://localhost:7200/repositories/langchain",
query_ontology="CONSTRUCT {?s ?p ?o}"
"FROM <https://swapi.co/ontology/> WHERE {?s ?p ?o}",
)
with pytest.raises(ValueError) as e:
graph.query(
"PREFIX : <https://swapi.co/vocabulary/> "
"PREFIX owl: <http://www.w3.org/2002/07/owl#> "
"PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#> "
"PREFIX xsd: <http://www.w3.org/2001/XMLSchema#> "
"SELECT ?character (MAX(?lifespan) AS ?maxLifespan) "
"WHERE {"
" ?species a :Species ;"
" :character ?character ;"
" :averageLifespan ?lifespan ."
" FILTER(xsd:integer(?lifespan))"
"} "
"ORDER BY DESC(?maxLifespan) "
"LIMIT 1"
)
assert (
str(e.value)
== "You did something wrong formulating either the URI or your SPARQL query"
)
def test_get_schema_with_query() -> None:
graph = OntotextGraphDBGraph(
query_endpoint="http://localhost:7200/repositories/langchain",

@ -1,7 +1,10 @@
"""Question answering over a graph."""
from __future__ import annotations
from typing import Any, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional
if TYPE_CHECKING:
import rdflib
from langchain_community.graphs import OntotextGraphDBGraph
from langchain_core.callbacks.manager import CallbackManager
@ -97,10 +100,10 @@ class OntotextGraphDBQAChain(Chain):
self.sparql_generation_chain.output_key
]
generated_sparql = self._get_valid_sparql_query(
generated_sparql = self._get_prepared_sparql_query(
_run_manager, callbacks, generated_sparql, ontology_schema
)
query_results = self.graph.query(generated_sparql)
query_results = self._execute_query(generated_sparql)
qa_chain_result = self.qa_chain.invoke(
{"prompt": prompt, "context": query_results}, callbacks=callbacks
@ -108,7 +111,7 @@ class OntotextGraphDBQAChain(Chain):
result = qa_chain_result[self.qa_chain.output_key]
return {self.output_key: result}
def _get_valid_sparql_query(
def _get_prepared_sparql_query(
self,
_run_manager: CallbackManagerForChainRun,
callbacks: CallbackManager,
@ -153,10 +156,10 @@ class OntotextGraphDBQAChain(Chain):
from rdflib.plugins.sparql import prepareQuery
prepareQuery(generated_sparql)
self._log_valid_sparql_query(_run_manager, generated_sparql)
self._log_prepared_sparql_query(_run_manager, generated_sparql)
return generated_sparql
def _log_valid_sparql_query(
def _log_prepared_sparql_query(
self, _run_manager: CallbackManagerForChainRun, generated_query: str
) -> None:
_run_manager.on_text("Generated SPARQL:", end="\n", verbose=self.verbose)
@ -180,3 +183,9 @@ class OntotextGraphDBQAChain(Chain):
_run_manager.on_text(
error_message, color="red", end="\n\n", verbose=self.verbose
)
def _execute_query(self, query: str) -> List[rdflib.query.ResultRow]:
try:
return self.graph.query(query)
except Exception:
raise ValueError("Failed to execute the generated SPARQL query.")

@ -165,6 +165,65 @@ def test_valid_sparql_after_first_retry(max_fix_retries: int) -> None:
assert result == {chain.output_key: answer, chain.input_key: question}
@pytest.mark.requires("langchain_openai", "rdflib")
@pytest.mark.parametrize("max_fix_retries", [1, 2, 3])
def test_invalid_sparql_server_response_400(max_fix_retries: int) -> None:
from langchain_openai import ChatOpenAI
question = "Who is the oldest character?"
generated_invalid_sparql = (
"PREFIX : <https://swapi.co/vocabulary/> "
"PREFIX owl: <http://www.w3.org/2002/07/owl#> "
"PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#> "
"PREFIX xsd: <http://www.w3.org/2001/XMLSchema#> "
"SELECT ?character (MAX(?lifespan) AS ?maxLifespan) "
"WHERE {"
" ?species a :Species ;"
" :character ?character ;"
" :averageLifespan ?lifespan ."
" FILTER(xsd:integer(?lifespan))"
"} "
"ORDER BY DESC(?maxLifespan) "
"LIMIT 1"
)
graph = OntotextGraphDBGraph(
query_endpoint="http://localhost:7200/repositories/starwars",
query_ontology="CONSTRUCT {?s ?p ?o} "
"FROM <https://swapi.co/ontology/> WHERE {?s ?p ?o}",
)
chain = OntotextGraphDBQAChain.from_llm(
Mock(ChatOpenAI),
graph=graph,
max_fix_retries=max_fix_retries,
)
chain.sparql_generation_chain = Mock(LLMChain)
chain.sparql_fix_chain = Mock(LLMChain)
chain.qa_chain = Mock(LLMChain)
chain.sparql_generation_chain.output_key = "text"
chain.sparql_generation_chain.invoke = MagicMock(
return_value={
"text": generated_invalid_sparql,
"prompt": question,
"schema": "",
}
)
chain.sparql_fix_chain.output_key = "text"
chain.sparql_fix_chain.invoke = MagicMock()
chain.qa_chain.output_key = "text"
chain.qa_chain.invoke = MagicMock()
with pytest.raises(ValueError) as e:
chain.invoke({chain.input_key: question})
assert str(e.value) == "Failed to execute the generated SPARQL query."
assert chain.sparql_generation_chain.invoke.call_count == 1
assert chain.sparql_fix_chain.invoke.call_count == 0
assert chain.qa_chain.invoke.call_count == 0
@pytest.mark.requires("langchain_openai", "rdflib")
@pytest.mark.parametrize("max_fix_retries", [1, 2, 3])
def test_invalid_sparql_after_all_retries(max_fix_retries: int) -> None:

Loading…
Cancel
Save