langchain/libs/community/langchain_community/utilities/graphql.py
Zachary Toliver 29ee0496b6
community[patch]: Allow override of 'fetch_schema_from_transport' in the GraphQL tool (#17649)
- **Description:** In order to override the bool value of
"fetch_schema_from_transport" in the GraphQLAPIWrapper, a
"fetch_schema_from_transport" value needed to be added to the
"_EXTRA_OPTIONAL_TOOLS" dictionary in load_tools in the "graphql" key.
The parameter "fetch_schema_from_transport" must also be passed in to
the GraphQLAPIWrapper to allow reading of the value when creating the
client. Passing as an optional parameter is probably best to avoid
breaking changes. This change is necessary to support GraphQL instances
that do not support fetching schema, such as TigerGraph. More info here:
[TigerGraph GraphQL Schema
Docs](https://docs.tigergraph.com/graphql/current/schema)
  - **Threads handle:** @zacharytoliver

---------

Co-authored-by: Zachary Toliver <zt10191991@hotmail.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
2024-02-21 15:32:43 -08:00

59 lines
2.0 KiB
Python

import json
from typing import Any, Callable, Dict, Optional
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
class GraphQLAPIWrapper(BaseModel):
    """Wrapper around GraphQL API.

    To use, you should have the ``gql`` python package installed.
    This wrapper will use the GraphQL API to conduct queries.

    The ``gql`` client and the ``gql`` query-building function are created
    by ``validate_environment`` and stored on the instance as ``gql_client``
    and ``gql_function``.
    """

    # Optional HTTP headers passed to the transport as ``headers``.
    custom_headers: Optional[Dict[str, str]] = None
    # Forwarded to the gql ``Client``. ``None`` (the default) means
    # "not specified" and is resolved to ``True`` when the client is built.
    fetch_schema_from_transport: Optional[bool] = None
    # URL passed to the transport as ``url``.
    graphql_endpoint: str
    gql_client: Any  #: :meta private:
    gql_function: Callable[[str], Any]  #: :meta private:

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator(pre=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the python package exists in the environment.

        Builds a ``RequestsHTTPTransport`` from ``graphql_endpoint`` and
        ``custom_headers``, wraps it in a gql ``Client``, and stores the
        client and the ``gql`` function in ``values``.

        Args:
            values: Raw field values supplied to the model.

        Returns:
            The same ``values`` dict with ``gql_client`` and ``gql_function``
            set.

        Raises:
            ImportError: If the ``gql`` package cannot be imported.
            KeyError: If ``graphql_endpoint`` is absent from ``values``.
        """
        try:
            from gql import Client, gql
            from gql.transport.requests import RequestsHTTPTransport
        except ImportError as e:
            # Chain the original error so the root cause stays visible.
            raise ImportError(
                "Could not import gql python package. "
                f"Try installing it with `pip install gql`. Received error: {e}"
            ) from e

        transport = RequestsHTTPTransport(
            url=values["graphql_endpoint"],
            headers=values.get("custom_headers"),
        )

        # An explicit ``None`` is the field's declared default, so treat it
        # exactly like an omitted value instead of handing ``None`` to Client.
        fetch_schema_from_transport = values.get("fetch_schema_from_transport")
        if fetch_schema_from_transport is None:
            fetch_schema_from_transport = True

        values["gql_client"] = Client(
            transport=transport,
            fetch_schema_from_transport=fetch_schema_from_transport,
        )
        values["gql_function"] = gql
        return values

    def run(self, query: str) -> str:
        """Run a GraphQL query and get the results.

        Args:
            query: The GraphQL query text.

        Returns:
            The query result serialized as JSON with two-space indentation.
        """
        result = self._execute_query(query)
        return json.dumps(result, indent=2)

    def _execute_query(self, query: str) -> Dict[str, Any]:
        """Execute a GraphQL query and return the results.

        The query string is converted with ``gql_function`` and the resulting
        document is passed to ``gql_client.execute``.
        """
        document_node = self.gql_function(query)
        result = self.gql_client.execute(document_node)
        return result