[Feature] Add GraphQL Query Tool (#4409)

# Add GraphQL Query Support

This PR introduces a GraphQL API Wrapper tool that allows LLM agents to
query GraphQL databases. The tool utilizes the httpx and gql Python
packages to interact with GraphQL APIs and provides a simple interface
for running queries with LLM agents.

@vowelparrot

---------

Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
This commit is contained in:
Roma 2023-05-15 18:06:12 -03:00 committed by GitHub
parent 49ce5ce1ca
commit cb802edf75
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 347 additions and 8 deletions

View File

@ -0,0 +1,149 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"# GraphQL tool\n",
"This Jupyter Notebook demonstrates how to use the BaseGraphQLTool component with an Agent.\n",
"\n",
"GraphQL is a query language for APIs and a runtime for executing those queries against your data. GraphQL provides a complete and understandable description of the data in your API, gives clients the power to ask for exactly what they need and nothing more, makes it easier to evolve APIs over time, and enables powerful developer tools.\n",
"\n",
"By including a BaseGraphQLTool in the list of tools provided to an Agent, you can grant your Agent the ability to query data from GraphQL APIs for any purposes you need.\n",
"\n",
"In this example, we'll be using the public Star Wars GraphQL API available at the following endpoint: https://swapi-graphql.netlify.app/.netlify/functions/index.\n",
"\n",
"First, you need to install httpx and gql Python packages."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "shellscript"
}
},
"outputs": [],
"source": [
"pip install httpx gql > /dev/null"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, let's create a BaseGraphQLTool instance with the specified Star Wars API endpoint and initialize an Agent with the tool."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from langchain import OpenAI\n",
"from langchain.agents import load_tools, initialize_agent, AgentType\n",
"from langchain.utilities import GraphQLAPIWrapper\n",
"\n",
"llm = OpenAI(temperature=0)\n",
"\n",
"tools = load_tools([\"graphql\"], graphql_endpoint=\"https://swapi-graphql.netlify.app/.netlify/functions/index\", llm=llm)\n",
"\n",
"agent = initialize_agent(tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, we can use the Agent to run queries against the Star Wars GraphQL API. Let's ask the Agent to list all the Star Wars films and their release dates."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m I need to query the graphql database to get the titles of all the star wars films\n",
"Action: query_graphql\n",
"Action Input: query { allFilms { films { title } } }\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3m\"{\\n \\\"allFilms\\\": {\\n \\\"films\\\": [\\n {\\n \\\"title\\\": \\\"A New Hope\\\"\\n },\\n {\\n \\\"title\\\": \\\"The Empire Strikes Back\\\"\\n },\\n {\\n \\\"title\\\": \\\"Return of the Jedi\\\"\\n },\\n {\\n \\\"title\\\": \\\"The Phantom Menace\\\"\\n },\\n {\\n \\\"title\\\": \\\"Attack of the Clones\\\"\\n },\\n {\\n \\\"title\\\": \\\"Revenge of the Sith\\\"\\n }\\n ]\\n }\\n}\"\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I now know the titles of all the star wars films\n",
"Final Answer: The titles of all the star wars films are: A New Hope, The Empire Strikes Back, Return of the Jedi, The Phantom Menace, Attack of the Clones, and Revenge of the Sith.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"'The titles of all the star wars films are: A New Hope, The Empire Strikes Back, Return of the Jedi, The Phantom Menace, Attack of the Clones, and Revenge of the Sith.'"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"graphql_fields = \"\"\"allFilms {\n",
" films {\n",
" title\n",
" director\n",
" releaseDate\n",
" speciesConnection {\n",
" species {\n",
" name\n",
" classification\n",
" homeworld {\n",
" name\n",
" }\n",
" }\n",
" }\n",
" }\n",
" }\n",
"\n",
"\"\"\"\n",
"\n",
"suffix = \"Search for the titles of all the stawars films stored in the graphql database that has this schema \"\n",
"\n",
"\n",
"agent.run(suffix + graphql_fields)"
]
}
],
"metadata": {
"interpreter": {
"hash": "f85209c3c4c190dca7367d6a1e623da50a9a4392fd53313a7cf9d4bda9c4b85b"
},
"kernelspec": {
"display_name": "Python 3.9.16 ('langchain')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@ -20,6 +20,7 @@ from langchain.tools.ddg_search.tool import DuckDuckGoSearchRun
from langchain.tools.google_search.tool import GoogleSearchResults, GoogleSearchRun
from langchain.tools.metaphor_search.tool import MetaphorSearchResults
from langchain.tools.google_serper.tool import GoogleSerperResults, GoogleSerperRun
from langchain.tools.graphql.tool import BaseGraphQLTool
from langchain.tools.human.tool import HumanInputRun
from langchain.tools.python.tool import PythonREPLTool
from langchain.tools.requests.tool import (
@ -42,6 +43,7 @@ from langchain.utilities.google_search import GoogleSearchAPIWrapper
from langchain.utilities.google_serper import GoogleSerperAPIWrapper
from langchain.utilities.metaphor_search import MetaphorSearchAPIWrapper
from langchain.utilities.awslambda import LambdaWrapper
from langchain.utilities.graphql import GraphQLAPIWrapper
from langchain.utilities.searx_search import SearxSearchWrapper
from langchain.utilities.serpapi import SerpAPIWrapper
from langchain.utilities.wikipedia import WikipediaAPIWrapper
@ -245,6 +247,12 @@ def _get_scenexplain(**kwargs: Any) -> BaseTool:
return SceneXplainTool(**kwargs)
def _get_graphql_tool(**kwargs: Any) -> BaseTool:
graphql_endpoint = kwargs["graphql_endpoint"]
wrapper = GraphQLAPIWrapper(graphql_endpoint=graphql_endpoint)
return BaseGraphQLTool(graphql_wrapper=wrapper)
def _get_openweathermap(**kwargs: Any) -> BaseTool:
return OpenWeatherMapQueryRun(api_wrapper=OpenWeatherMapAPIWrapper(**kwargs))
@ -290,6 +298,7 @@ _EXTRA_OPTIONAL_TOOLS: Dict[str, Tuple[Callable[[KwArg(Any)], BaseTool], List[st
["awslambda_tool_name", "awslambda_tool_description", "function_name"],
),
"sceneXplain": (_get_scenexplain, []),
"graphql": (_get_graphql_tool, ["graphql_endpoint"]),
"openweathermap-api": (_get_openweathermap, ["openweathermap_api_key"]),
}

View File

@ -0,0 +1 @@
"""Tools for interacting with a GraphQL API"""

View File

@ -0,0 +1,46 @@
import json
from typing import Optional
from langchain.callbacks.manager import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from langchain.tools.base import BaseTool
from langchain.utilities.graphql import GraphQLAPIWrapper
class BaseGraphQLTool(BaseTool):
"""Base tool for querying a GraphQL API."""
graphql_wrapper: GraphQLAPIWrapper
name = "query_graphql"
description = """\
Input to this tool is a detailed and correct GraphQL query, output is a result from the API.
If the query is not correct, an error message will be returned.
If an error is returned with 'Bad request' in it, rewrite the query and try again.
If an error is returned with 'Unauthorized' in it, do not try again, but tell the user to change their authentication.
Example Input: query {{ allUsers {{ id, name, email }} }}\
""" # noqa: E501
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
def _run(
self,
tool_input: str,
run_manager: Optional[CallbackManagerForToolRun] = None,
) -> str:
result = self.graphql_wrapper.run(tool_input)
return json.dumps(result, indent=2)
async def _arun(
self,
tool_input: str,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> str:
"""Use the Graphql tool asynchronously."""
raise NotImplementedError("GraphQLAPIWrapper does not support async")

View File

@ -9,6 +9,7 @@ from langchain.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper
from langchain.utilities.google_places_api import GooglePlacesAPIWrapper
from langchain.utilities.google_search import GoogleSearchAPIWrapper
from langchain.utilities.google_serper import GoogleSerperAPIWrapper
from langchain.utilities.graphql import GraphQLAPIWrapper
from langchain.utilities.metaphor_search import MetaphorSearchAPIWrapper
from langchain.utilities.openweathermap import OpenWeatherMapAPIWrapper
from langchain.utilities.powerbi import PowerBIDataset
@ -27,6 +28,7 @@ __all__ = [
"GoogleSearchAPIWrapper",
"GoogleSerperAPIWrapper",
"GooglePlacesAPIWrapper",
"GraphQLAPIWrapper",
"WolframAlphaAPIWrapper",
"SerpAPIWrapper",
"SearxSearchWrapper",

View File

@ -0,0 +1,61 @@
import json
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
from pydantic import BaseModel, Extra, root_validator
if TYPE_CHECKING:
from gql import Client
class GraphQLAPIWrapper(BaseModel):
"""Wrapper around GraphQL API.
To use, you should have the ``gql`` python package installed.
This wrapper will use the GraphQL API to conduct queries.
"""
custom_headers: Optional[Dict[str, str]] = None
graphql_endpoint: str
gql_client: "Client" #: :meta private:
gql_function: Callable[[str], Any] #: :meta private:
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that the python package exists in the environment."""
headers = values.get("custom_headers", {})
try:
from gql import Client, gql
from gql.transport.requests import RequestsHTTPTransport
transport = RequestsHTTPTransport(
url=values["graphql_endpoint"],
headers=headers or None,
)
client = Client(transport=transport, fetch_schema_from_transport=True)
values["gql_client"] = client
values["gql_function"] = gql
except ImportError:
raise ValueError(
"Could not import gql python package. "
"Please install it with `pip install gql`."
)
return values
def run(self, query: str) -> str:
"""Run a GraphQL query and get the results."""
result = self._execute_query(query)
return json.dumps(result, indent=2)
def _execute_query(self, query: str) -> Dict[str, Any]:
"""Execute a GraphQL query and return the results."""
document_node = self.gql_function(query)
result = self.gql_client.execute(document_node)
return result

52
poetry.lock generated
View File

@ -2288,6 +2288,45 @@ cachetools = "*"
numpy = "*"
requests = "*"
[[package]]
name = "gql"
version = "3.4.1"
description = "GraphQL client for Python"
category = "main"
optional = true
python-versions = "*"
files = [
{file = "gql-3.4.1-py2.py3-none-any.whl", hash = "sha256:315624ca0f4d571ef149d455033ebd35e45c1a13f18a059596aeddcea99135cf"},
{file = "gql-3.4.1.tar.gz", hash = "sha256:11dc5d8715a827f2c2899593439a4f36449db4f0eafa5b1ea63948f8a2f8c545"},
]
[package.dependencies]
backoff = ">=1.11.1,<3.0"
graphql-core = ">=3.2,<3.3"
yarl = ">=1.6,<2.0"
[package.extras]
aiohttp = ["aiohttp (>=3.7.1,<3.9.0)"]
all = ["aiohttp (>=3.7.1,<3.9.0)", "botocore (>=1.21,<2)", "requests (>=2.26,<3)", "requests-toolbelt (>=0.9.1,<1)", "urllib3 (>=1.26,<2)", "websockets (>=10,<11)", "websockets (>=9,<10)"]
botocore = ["botocore (>=1.21,<2)"]
dev = ["aiofiles", "aiohttp (>=3.7.1,<3.9.0)", "black (==22.3.0)", "botocore (>=1.21,<2)", "check-manifest (>=0.42,<1)", "flake8 (==3.8.1)", "isort (==4.3.21)", "mock (==4.0.2)", "mypy (==0.910)", "parse (==1.15.0)", "pytest (==6.2.5)", "pytest-asyncio (==0.16.0)", "pytest-console-scripts (==1.3.1)", "pytest-cov (==3.0.0)", "requests (>=2.26,<3)", "requests-toolbelt (>=0.9.1,<1)", "sphinx (>=3.0.0,<4)", "sphinx-argparse (==0.2.5)", "sphinx-rtd-theme (>=0.4,<1)", "types-aiofiles", "types-mock", "types-requests", "urllib3 (>=1.26,<2)", "vcrpy (==4.0.2)", "websockets (>=10,<11)", "websockets (>=9,<10)"]
requests = ["requests (>=2.26,<3)", "requests-toolbelt (>=0.9.1,<1)", "urllib3 (>=1.26,<2)"]
test = ["aiofiles", "aiohttp (>=3.7.1,<3.9.0)", "botocore (>=1.21,<2)", "mock (==4.0.2)", "parse (==1.15.0)", "pytest (==6.2.5)", "pytest-asyncio (==0.16.0)", "pytest-console-scripts (==1.3.1)", "pytest-cov (==3.0.0)", "requests (>=2.26,<3)", "requests-toolbelt (>=0.9.1,<1)", "urllib3 (>=1.26,<2)", "vcrpy (==4.0.2)", "websockets (>=10,<11)", "websockets (>=9,<10)"]
test-no-transport = ["aiofiles", "mock (==4.0.2)", "parse (==1.15.0)", "pytest (==6.2.5)", "pytest-asyncio (==0.16.0)", "pytest-console-scripts (==1.3.1)", "pytest-cov (==3.0.0)", "vcrpy (==4.0.2)"]
websockets = ["websockets (>=10,<11)", "websockets (>=9,<10)"]
[[package]]
name = "graphql-core"
version = "3.2.3"
description = "GraphQL implementation for Python, a port of GraphQL.js, the JavaScript reference implementation for GraphQL."
category = "main"
optional = true
python-versions = ">=3.6,<4"
files = [
{file = "graphql-core-3.2.3.tar.gz", hash = "sha256:06d2aad0ac723e35b1cb47885d3e5c45e956a53bc1b209a9fc5369007fe46676"},
{file = "graphql_core-3.2.3-py3-none-any.whl", hash = "sha256:5766780452bd5ec8ba133f8bf287dc92713e3868ddd83aee4faab9fc3e303dc3"},
]
[[package]]
name = "greenlet"
version = "2.0.1"
@ -6478,16 +6517,15 @@ tests = ["coverage[toml] (==5.0.4)", "pytest (>=6.0.0,<7.0.0)"]
[[package]]
name = "pylance"
version = "0.4.11"
version = "0.4.12"
description = "python wrapper for lance-rs"
category = "main"
optional = true
python-versions = ">=3.8"
files = [
{file = "pylance-0.4.11-cp38-abi3-macosx_10_15_x86_64.whl", hash = "sha256:667b0bc0e52bfdb8638f206c63df4f8e7bd695b25fbe2c5effa430190eae0da1"},
{file = "pylance-0.4.11-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:086a1b3ede234514deab382cca723ec0870f98771418595f52d5f063eab56629"},
{file = "pylance-0.4.11-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c02dd9ea64149781ac1c96413e23589daaa1106732ad99ce3077a24d8c076a2a"},
{file = "pylance-0.4.11-cp38-abi3-win_amd64.whl", hash = "sha256:d9703771940610e37c01ce0f2c0e95acdbb646526a43ec987bdb837377096e4a"},
{file = "pylance-0.4.12-cp38-abi3-macosx_10_15_x86_64.whl", hash = "sha256:2b86fb8dccc03094c0db37bef0d91bda60e8eb0d1eddf245c6971450c8d8a53f"},
{file = "pylance-0.4.12-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:0bc82914b13204187d673b5f3d45f93219c38a0e9d0542ba251074f639669789"},
{file = "pylance-0.4.12-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5a4bcce77f99ecd4cbebbadb01e58d5d8138d40eb56bdcdbc3b20b0475e7a472"},
]
[package.dependencies]
@ -10135,7 +10173,7 @@ cffi = {version = ">=1.11", markers = "platform_python_implementation == \"PyPy\
cffi = ["cffi (>=1.11)"]
[extras]
all = ["O365", "aleph-alpha-client", "anthropic", "arxiv", "atlassian-python-api", "azure-cosmos", "azure-identity", "beautifulsoup4", "clickhouse-connect", "cohere", "deeplake", "docarray", "duckduckgo-search", "elasticsearch", "faiss-cpu", "google-api-python-client", "google-search-results", "gptcache", "hnswlib", "html2text", "huggingface_hub", "jina", "jinja2", "jq", "lancedb", "lark", "manifest-ml", "networkx", "nlpcloud", "nltk", "nomic", "openai", "opensearch-py", "pdfminer-six", "pexpect", "pgvector", "pinecone-client", "pinecone-text", "protobuf", "psycopg2-binary", "pyowm", "pypdf", "pytesseract", "pyvespa", "qdrant-client", "redis", "sentence-transformers", "spacy", "steamship", "tensorflow-text", "tiktoken", "torch", "transformers", "weaviate-client", "wikipedia", "wolframalpha"]
all = ["O365", "aleph-alpha-client", "anthropic", "arxiv", "atlassian-python-api", "azure-cosmos", "azure-identity", "beautifulsoup4", "clickhouse-connect", "cohere", "deeplake", "docarray", "duckduckgo-search", "elasticsearch", "faiss-cpu", "google-api-python-client", "google-search-results", "gptcache", "gql", "hnswlib", "html2text", "huggingface_hub", "jina", "jinja2", "jq", "lancedb", "lark", "manifest-ml", "networkx", "nlpcloud", "nltk", "nomic", "openai", "opensearch-py", "pdfminer-six", "pexpect", "pgvector", "pinecone-client", "pinecone-text", "protobuf", "psycopg2-binary", "pyowm", "pypdf", "pytesseract", "pyvespa", "qdrant-client", "redis", "sentence-transformers", "spacy", "steamship", "tensorflow-text", "tiktoken", "torch", "transformers", "weaviate-client", "wikipedia", "wolframalpha"]
azure = ["azure-core", "azure-cosmos", "azure-identity", "openai"]
cohere = ["cohere"]
embeddings = ["sentence-transformers"]
@ -10149,4 +10187,4 @@ qdrant = ["qdrant-client"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
content-hash = "8c327ab0b84d7d2c61d429154a604db0775ae2d40b66efbdf9dbee9bdc8a245c"
content-hash = "209e2cb579599084e5bcbc30c597f6e0f2825fc33c2c06d17edcdaa3c9902f30"

View File

@ -85,6 +85,7 @@ hnswlib = {version="^0.7.0", optional=true}
lxml = {version = "^4.9.2", optional = true}
pymupdf = {version = "^1.22.3", optional = true}
pypdfium2 = {version = "^4.10.0", optional = true}
gql = {version = "^3.4.1", optional = true}
[tool.poetry.group.docs.dependencies]
@ -171,7 +172,7 @@ in_memory_store = ["docarray"]
hnswlib = ["docarray", "protobuf", "hnswlib"]
embeddings = ["sentence-transformers"]
azure = ["azure-identity", "azure-cosmos", "openai", "azure-core"]
all = ["anthropic", "cohere", "openai", "nlpcloud", "huggingface_hub", "jina", "manifest-ml", "elasticsearch", "opensearch-py", "google-search-results", "faiss-cpu", "sentence-transformers", "transformers", "spacy", "nltk", "wikipedia", "beautifulsoup4", "tiktoken", "torch", "jinja2", "pinecone-client", "pinecone-text", "weaviate-client", "redis", "google-api-python-client", "wolframalpha", "qdrant-client", "tensorflow-text", "pypdf", "networkx", "nomic", "aleph-alpha-client", "deeplake", "pgvector", "psycopg2-binary", "boto3", "pyowm", "pytesseract", "html2text", "atlassian-python-api", "gptcache", "duckduckgo-search", "arxiv", "azure-identity", "clickhouse-connect", "azure-cosmos", "lancedb", "lark", "pexpect", "pyvespa", "O365", "jq", "docarray", "protobuf", "hnswlib", "steamship", "pdfminer-six"]
all = ["anthropic", "cohere", "openai", "nlpcloud", "huggingface_hub", "jina", "manifest-ml", "elasticsearch", "opensearch-py", "google-search-results", "faiss-cpu", "sentence-transformers", "transformers", "spacy", "nltk", "wikipedia", "beautifulsoup4", "tiktoken", "torch", "jinja2", "pinecone-client", "pinecone-text", "weaviate-client", "redis", "google-api-python-client", "wolframalpha", "qdrant-client", "tensorflow-text", "pypdf", "networkx", "nomic", "aleph-alpha-client", "deeplake", "pgvector", "psycopg2-binary", "boto3", "pyowm", "pytesseract", "html2text", "atlassian-python-api", "gptcache", "duckduckgo-search", "arxiv", "azure-identity", "clickhouse-connect", "azure-cosmos", "lancedb", "lark", "pexpect", "pyvespa", "O365", "jq", "docarray", "protobuf", "hnswlib", "steamship", "pdfminer-six", "gql"]
# An extra used to be able to add extended testing.
# Please use new-line on formatting to make it easier to add new packages without
# merge-conflicts

View File

@ -0,0 +1,32 @@
import json
import pytest
import responses
from langchain.utilities.graphql import GraphQLAPIWrapper
TEST_ENDPOINT = "http://testserver/graphql"
# Mock GraphQL response for testing
MOCK_RESPONSE = {
"data": {"allUsers": [{"id": 1, "name": "Alice", "email": "alice@example.com"}]}
}
@pytest.fixture
def graphql_wrapper() -> GraphQLAPIWrapper:
return GraphQLAPIWrapper(
graphql_endpoint=TEST_ENDPOINT,
custom_headers={"Authorization": "Bearer testtoken"},
)
@responses.activate
def test_run(graphql_wrapper: GraphQLAPIWrapper) -> None:
responses.add(responses.POST, TEST_ENDPOINT, json=MOCK_RESPONSE, status=200)
query = "query { allUsers { id, name, email } }"
result = graphql_wrapper.run(query)
expected_result = json.dumps(MOCK_RESPONSE, indent=2)
assert result == expected_result