mirror of https://github.com/hwchase17/langchain
[Feature] Add GraphQL Query Tool (#4409)
# Add GraphQL Query Support This PR introduces a GraphQL API Wrapper tool that allows LLM agents to query GraphQL databases. The tool utilizes the httpx and gql Python packages to interact with GraphQL APIs and provides a simple interface for running queries with LLM agents. @vowelparrot --------- Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>pull/4534/head
parent
49ce5ce1ca
commit
cb802edf75
@ -0,0 +1,149 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"\n",
|
||||
"# GraphQL tool\n",
|
||||
"This Jupyter Notebook demonstrates how to use the BaseGraphQLTool component with an Agent.\n",
|
||||
"\n",
|
||||
"GraphQL is a query language for APIs and a runtime for executing those queries against your data. GraphQL provides a complete and understandable description of the data in your API, gives clients the power to ask for exactly what they need and nothing more, makes it easier to evolve APIs over time, and enables powerful developer tools.\n",
|
||||
"\n",
|
||||
"By including a BaseGraphQLTool in the list of tools provided to an Agent, you can grant your Agent the ability to query data from GraphQL APIs for any purposes you need.\n",
|
||||
"\n",
|
||||
"In this example, we'll be using the public Star Wars GraphQL API available at the following endpoint: https://swapi-graphql.netlify.app/.netlify/functions/index.\n",
|
||||
"\n",
|
||||
"First, you need to install httpx and gql Python packages."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"vscode": {
|
||||
"languageId": "shellscript"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"pip install httpx gql > /dev/null"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Now, let's create a BaseGraphQLTool instance with the specified Star Wars API endpoint and initialize an Agent with the tool."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain import OpenAI\n",
|
||||
"from langchain.agents import load_tools, initialize_agent, AgentType\n",
|
||||
"from langchain.utilities import GraphQLAPIWrapper\n",
|
||||
"\n",
|
||||
"llm = OpenAI(temperature=0)\n",
|
||||
"\n",
|
||||
"tools = load_tools([\"graphql\"], graphql_endpoint=\"https://swapi-graphql.netlify.app/.netlify/functions/index\", llm=llm)\n",
|
||||
"\n",
|
||||
"agent = initialize_agent(tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Now, we can use the Agent to run queries against the Star Wars GraphQL API. Let's ask the Agent to list all the Star Wars films and their release dates."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
||||
"\u001b[32;1m\u001b[1;3m I need to query the graphql database to get the titles of all the star wars films\n",
|
||||
"Action: query_graphql\n",
|
||||
"Action Input: query { allFilms { films { title } } }\u001b[0m\n",
|
||||
"Observation: \u001b[36;1m\u001b[1;3m\"{\\n \\\"allFilms\\\": {\\n \\\"films\\\": [\\n {\\n \\\"title\\\": \\\"A New Hope\\\"\\n },\\n {\\n \\\"title\\\": \\\"The Empire Strikes Back\\\"\\n },\\n {\\n \\\"title\\\": \\\"Return of the Jedi\\\"\\n },\\n {\\n \\\"title\\\": \\\"The Phantom Menace\\\"\\n },\\n {\\n \\\"title\\\": \\\"Attack of the Clones\\\"\\n },\\n {\\n \\\"title\\\": \\\"Revenge of the Sith\\\"\\n }\\n ]\\n }\\n}\"\u001b[0m\n",
|
||||
"Thought:\u001b[32;1m\u001b[1;3m I now know the titles of all the star wars films\n",
|
||||
"Final Answer: The titles of all the star wars films are: A New Hope, The Empire Strikes Back, Return of the Jedi, The Phantom Menace, Attack of the Clones, and Revenge of the Sith.\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'The titles of all the star wars films are: A New Hope, The Empire Strikes Back, Return of the Jedi, The Phantom Menace, Attack of the Clones, and Revenge of the Sith.'"
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"graphql_fields = \"\"\"allFilms {\n",
|
||||
" films {\n",
|
||||
" title\n",
|
||||
" director\n",
|
||||
" releaseDate\n",
|
||||
" speciesConnection {\n",
|
||||
" species {\n",
|
||||
" name\n",
|
||||
" classification\n",
|
||||
" homeworld {\n",
|
||||
" name\n",
|
||||
" }\n",
|
||||
" }\n",
|
||||
" }\n",
|
||||
" }\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
"\"\"\"\n",
|
||||
"\n",
|
||||
"suffix = \"Search for the titles of all the stawars films stored in the graphql database that has this schema \"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"agent.run(suffix + graphql_fields)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"interpreter": {
|
||||
"hash": "f85209c3c4c190dca7367d6a1e623da50a9a4392fd53313a7cf9d4bda9c4b85b"
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3.9.16 ('langchain')",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.16"
|
||||
},
|
||||
"orig_nbformat": 4
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
@ -0,0 +1 @@
|
||||
"""Tools for interacting with a GraphQL API"""
|
@ -0,0 +1,46 @@
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForToolRun,
|
||||
CallbackManagerForToolRun,
|
||||
)
|
||||
from langchain.tools.base import BaseTool
|
||||
from langchain.utilities.graphql import GraphQLAPIWrapper
|
||||
|
||||
|
||||
class BaseGraphQLTool(BaseTool):
|
||||
"""Base tool for querying a GraphQL API."""
|
||||
|
||||
graphql_wrapper: GraphQLAPIWrapper
|
||||
|
||||
name = "query_graphql"
|
||||
description = """\
|
||||
Input to this tool is a detailed and correct GraphQL query, output is a result from the API.
|
||||
If the query is not correct, an error message will be returned.
|
||||
If an error is returned with 'Bad request' in it, rewrite the query and try again.
|
||||
If an error is returned with 'Unauthorized' in it, do not try again, but tell the user to change their authentication.
|
||||
|
||||
Example Input: query {{ allUsers {{ id, name, email }} }}\
|
||||
""" # noqa: E501
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def _run(
|
||||
self,
|
||||
tool_input: str,
|
||||
run_manager: Optional[CallbackManagerForToolRun] = None,
|
||||
) -> str:
|
||||
result = self.graphql_wrapper.run(tool_input)
|
||||
return json.dumps(result, indent=2)
|
||||
|
||||
async def _arun(
|
||||
self,
|
||||
tool_input: str,
|
||||
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
|
||||
) -> str:
|
||||
"""Use the Graphql tool asynchronously."""
|
||||
raise NotImplementedError("GraphQLAPIWrapper does not support async")
|
@ -0,0 +1,61 @@
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, Extra, root_validator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gql import Client
|
||||
|
||||
|
||||
class GraphQLAPIWrapper(BaseModel):
|
||||
"""Wrapper around GraphQL API.
|
||||
|
||||
To use, you should have the ``gql`` python package installed.
|
||||
This wrapper will use the GraphQL API to conduct queries.
|
||||
"""
|
||||
|
||||
custom_headers: Optional[Dict[str, str]] = None
|
||||
graphql_endpoint: str
|
||||
gql_client: "Client" #: :meta private:
|
||||
gql_function: Callable[[str], Any] #: :meta private:
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that the python package exists in the environment."""
|
||||
|
||||
headers = values.get("custom_headers", {})
|
||||
|
||||
try:
|
||||
from gql import Client, gql
|
||||
from gql.transport.requests import RequestsHTTPTransport
|
||||
|
||||
transport = RequestsHTTPTransport(
|
||||
url=values["graphql_endpoint"],
|
||||
headers=headers or None,
|
||||
)
|
||||
|
||||
client = Client(transport=transport, fetch_schema_from_transport=True)
|
||||
values["gql_client"] = client
|
||||
values["gql_function"] = gql
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import gql python package. "
|
||||
"Please install it with `pip install gql`."
|
||||
)
|
||||
return values
|
||||
|
||||
def run(self, query: str) -> str:
|
||||
"""Run a GraphQL query and get the results."""
|
||||
result = self._execute_query(query)
|
||||
return json.dumps(result, indent=2)
|
||||
|
||||
def _execute_query(self, query: str) -> Dict[str, Any]:
|
||||
"""Execute a GraphQL query and return the results."""
|
||||
document_node = self.gql_function(query)
|
||||
result = self.gql_client.execute(document_node)
|
||||
return result
|
@ -0,0 +1,32 @@
|
||||
import json
|
||||
|
||||
import pytest
|
||||
import responses
|
||||
|
||||
from langchain.utilities.graphql import GraphQLAPIWrapper
|
||||
|
||||
TEST_ENDPOINT = "http://testserver/graphql"
|
||||
|
||||
# Mock GraphQL response for testing
|
||||
MOCK_RESPONSE = {
|
||||
"data": {"allUsers": [{"id": 1, "name": "Alice", "email": "alice@example.com"}]}
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def graphql_wrapper() -> GraphQLAPIWrapper:
|
||||
return GraphQLAPIWrapper(
|
||||
graphql_endpoint=TEST_ENDPOINT,
|
||||
custom_headers={"Authorization": "Bearer testtoken"},
|
||||
)
|
||||
|
||||
|
||||
@responses.activate
|
||||
def test_run(graphql_wrapper: GraphQLAPIWrapper) -> None:
|
||||
responses.add(responses.POST, TEST_ENDPOINT, json=MOCK_RESPONSE, status=200)
|
||||
|
||||
query = "query { allUsers { id, name, email } }"
|
||||
result = graphql_wrapper.run(query)
|
||||
|
||||
expected_result = json.dumps(MOCK_RESPONSE, indent=2)
|
||||
assert result == expected_result
|
Loading…
Reference in New Issue