From 080eb1b3fce515e715a078593cce315dfb880ea5 Mon Sep 17 00:00:00 2001
From: Davis Chase <130488702+dev2049@users.noreply.github.com>
Date: Fri, 19 May 2023 15:27:50 -0700
Subject: [PATCH] Fix graphql tool (#4984)

Fix construction and add unit test.
---
 langchain/utilities/graphql.py             | 35 +++----
 poetry.lock                                | 22 ++++-
 pyproject.toml                             |  5 +-
 .../utilities/test_graphql.py              | 32 -------
 tests/unit_tests/utilities/test_graphql.py | 92 +++++++++++++++++++
 5 files changed, 128 insertions(+), 58 deletions(-)
 delete mode 100644 tests/integration_tests/utilities/test_graphql.py
 create mode 100644 tests/unit_tests/utilities/test_graphql.py

diff --git a/langchain/utilities/graphql.py b/langchain/utilities/graphql.py
index d041920b..1e8a7b20 100644
--- a/langchain/utilities/graphql.py
+++ b/langchain/utilities/graphql.py
@@ -1,11 +1,8 @@
 import json
-from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
+from typing import Any, Callable, Dict, Optional
 
 from pydantic import BaseModel, Extra, root_validator
 
-if TYPE_CHECKING:
-    from gql import Client
-
 
 class GraphQLAPIWrapper(BaseModel):
     """Wrapper around GraphQL API.
@@ -16,7 +13,7 @@ class GraphQLAPIWrapper(BaseModel):
 
     custom_headers: Optional[Dict[str, str]] = None
     graphql_endpoint: str
-    gql_client: "Client"  #: :meta private:
+    gql_client: Any  #: :meta private:
     gql_function: Callable[[str], Any]  #: :meta private:
 
     class Config:
@@ -24,29 +21,25 @@ class GraphQLAPIWrapper(BaseModel):
 
         extra = Extra.forbid
 
-    @root_validator()
+    @root_validator(pre=True)
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that the python package exists in the environment."""
-
-        headers = values.get("custom_headers", {})
-
         try:
             from gql import Client, gql
             from gql.transport.requests import RequestsHTTPTransport
-
-            transport = RequestsHTTPTransport(
-                url=values["graphql_endpoint"],
-                headers=headers or None,
-            )
-
-            client = Client(transport=transport, fetch_schema_from_transport=True)
-            values["gql_client"] = client
-            values["gql_function"] = gql
-        except ImportError:
-            raise ValueError(
+        except ImportError as e:
+            raise ImportError(
                 "Could not import gql python package. "
-                "Please install it with `pip install gql`."
+                f"Try installing it with `pip install gql`. Received error: {e}"
             )
+        headers = values.get("custom_headers")
+        transport = RequestsHTTPTransport(
+            url=values["graphql_endpoint"],
+            headers=headers,
+        )
+        client = Client(transport=transport, fetch_schema_from_transport=True)
+        values["gql_client"] = client
+        values["gql_function"] = gql
         return values
 
     def run(self, query: str) -> str:
diff --git a/poetry.lock b/poetry.lock
index 98a6c21f..7611c0ed 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -6620,7 +6620,6 @@ files = [
     {file = "pylance-0.4.12-cp38-abi3-macosx_10_15_x86_64.whl", hash = "sha256:2b86fb8dccc03094c0db37bef0d91bda60e8eb0d1eddf245c6971450c8d8a53f"},
     {file = "pylance-0.4.12-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:0bc82914b13204187d673b5f3d45f93219c38a0e9d0542ba251074f639669789"},
     {file = "pylance-0.4.12-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5a4bcce77f99ecd4cbebbadb01e58d5d8138d40eb56bdcdbc3b20b0475e7a472"},
-    {file = "pylance-0.4.12-cp38-abi3-win_amd64.whl", hash = "sha256:9616931c5300030adb9626d22515710a127d1e46a46737a7a0f980b52f13627c"},
 ]
 
 [package.dependencies]
@@ -7590,6 +7589,21 @@ requests = ">=2.0.0"
 [package.extras]
 rsa = ["oauthlib[signedtoken] (>=3.0.0)"]
 
+[[package]]
+name = "requests-toolbelt"
+version = "1.0.0"
+description = "A utility belt for advanced users of python-requests"
+category = "main"
+optional = true
+python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
+files = [
+    {file = "requests-toolbelt-1.0.0.tar.gz", hash = "sha256:7681a0a3d047012b5bdc0ee37d7f8f07ebe76ab08caeccfc3921ce23c88d5bc6"},
+    {file = "requests_toolbelt-1.0.0-py2.py3-none-any.whl", hash = "sha256:cccfdd665f0a24fcf4726e690f65639d272bb0637b9b92dfd91a5568ccf6bd06"},
+]
+
+[package.dependencies]
+requests = ">=2.0.1,<3.0.0"
+
 [[package]]
 name = "responses"
 version = "0.22.0"
@@ -10330,11 +10344,11 @@ cffi = {version = ">=1.11", markers = "platform_python_implementation == \"PyPy\""}
 cffi = ["cffi (>=1.11)"]
 
 [extras]
-all = ["O365", "aleph-alpha-client", "anthropic", "arxiv", "atlassian-python-api", "azure-cosmos", "azure-identity", "beautifulsoup4", "clickhouse-connect", "cohere", "deeplake", "docarray", "duckduckgo-search", "elasticsearch", "faiss-cpu", "google-api-python-client", "google-search-results", "gptcache", "gql", "hnswlib", "html2text", "huggingface_hub", "jina", "jinja2", "jq", "lancedb", "lark", "lxml", "manifest-ml", "networkx", "nlpcloud", "nltk", "nomic", "openai", "opensearch-py", "pdfminer-six", "pexpect", "pgvector", "pinecone-client", "pinecone-text", "protobuf", "psycopg2-binary", "pyowm", "pypdf", "pytesseract", "pyvespa", "qdrant-client", "redis", "sentence-transformers", "spacy", "steamship", "tensorflow-text", "tiktoken", "torch", "transformers", "weaviate-client", "wikipedia", "wolframalpha"]
+all = ["O365", "aleph-alpha-client", "anthropic", "arxiv", "atlassian-python-api", "azure-cosmos", "azure-identity", "beautifulsoup4", "clickhouse-connect", "cohere", "deeplake", "docarray", "duckduckgo-search", "elasticsearch", "faiss-cpu", "google-api-python-client", "google-search-results", "gptcache", "hnswlib", "html2text", "huggingface_hub", "jina", "jinja2", "jq", "lancedb", "lark", "lxml", "manifest-ml", "networkx", "nlpcloud", "nltk", "nomic", "openai", "opensearch-py", "pdfminer-six", "pexpect", "pgvector", "pinecone-client", "pinecone-text", "protobuf", "psycopg2-binary", "pyowm", "pypdf", "pytesseract", "pyvespa", "qdrant-client", "redis", "sentence-transformers", "spacy", "steamship", "tensorflow-text", "tiktoken", "torch", "transformers", "weaviate-client", "wikipedia", "wolframalpha"]
 azure = ["azure-core", "azure-cosmos", "azure-identity", "openai"]
 cohere = ["cohere"]
 embeddings = ["sentence-transformers"]
-extended-testing = ["atlassian-python-api", "beautifulsoup4", "beautifulsoup4", "chardet", "html2text", "jq", "lxml", "pandas", "pdfminer-six", "pymupdf", "pypdf", "pypdfium2", "telethon", "tqdm", "zep-python"]
+extended-testing = ["atlassian-python-api", "beautifulsoup4", "beautifulsoup4", "chardet", "gql", "html2text", "jq", "lxml", "pandas", "pdfminer-six", "pymupdf", "pypdf", "pypdfium2", "requests-toolbelt", "telethon", "tqdm", "zep-python"]
 hnswlib = ["docarray", "hnswlib", "protobuf"]
 in-memory-store = ["docarray"]
 llms = ["anthropic", "cohere", "huggingface_hub", "manifest-ml", "nlpcloud", "openai", "torch", "transformers"]
@@ -10345,4 +10359,4 @@ text-helpers = ["chardet"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.8.1,<4.0"
-content-hash = "cd116e8f127ccca1c6f700ef17863bae2f101384677448276fe0962dc3fc4cf6"
+content-hash = "5202794df913184aee17f9c6c798edbaa102d5b5152cac885a623ebc93d1e2a3"
diff --git a/pyproject.toml b/pyproject.toml
index 918a4093..11f869c5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -90,6 +90,7 @@ pandas = {version = "^2.0.1", optional = true}
 telethon = {version = "^1.28.5", optional = true}
 zep-python = {version="^0.25", optional=true}
 chardet = {version="^5.1.0", optional=true}
+requests-toolbelt = {version = "^1.0.0", optional = true}
 
 
 [tool.poetry.group.docs.dependencies]
@@ -183,7 +184,7 @@ in_memory_store = ["docarray"]
 hnswlib = ["docarray", "protobuf", "hnswlib"]
 embeddings = ["sentence-transformers"]
 azure = ["azure-identity", "azure-cosmos", "openai", "azure-core"]
-all = ["anthropic", "cohere", "openai", "nlpcloud", "huggingface_hub", "jina", "manifest-ml", "elasticsearch", "opensearch-py", "google-search-results", "faiss-cpu", "sentence-transformers", "transformers", "spacy", "nltk", "wikipedia", "beautifulsoup4", "tiktoken", "torch", "jinja2", "pinecone-client", "pinecone-text", "weaviate-client", "redis", "google-api-python-client", "wolframalpha", "qdrant-client", "tensorflow-text", "pypdf", "networkx", "nomic", "aleph-alpha-client", "deeplake", "pgvector", "psycopg2-binary", "pyowm", "pytesseract", "html2text", "atlassian-python-api", "gptcache", "duckduckgo-search", "arxiv", "azure-identity", "clickhouse-connect", "azure-cosmos", "lancedb", "lark", "pexpect", "pyvespa", "O365", "jq", "docarray", "protobuf", "hnswlib", "steamship", "pdfminer-six", "gql", "lxml"]
+all = ["anthropic", "cohere", "openai", "nlpcloud", "huggingface_hub", "jina", "manifest-ml", "elasticsearch", "opensearch-py", "google-search-results", "faiss-cpu", "sentence-transformers", "transformers", "spacy", "nltk", "wikipedia", "beautifulsoup4", "tiktoken", "torch", "jinja2", "pinecone-client", "pinecone-text", "weaviate-client", "redis", "google-api-python-client", "wolframalpha", "qdrant-client", "tensorflow-text", "pypdf", "networkx", "nomic", "aleph-alpha-client", "deeplake", "pgvector", "psycopg2-binary", "pyowm", "pytesseract", "html2text", "atlassian-python-api", "gptcache", "duckduckgo-search", "arxiv", "azure-identity", "clickhouse-connect", "azure-cosmos", "lancedb", "lark", "pexpect", "pyvespa", "O365", "jq", "docarray", "protobuf", "hnswlib", "steamship", "pdfminer-six", "lxml"]
 # An extra used to be able to add extended testing.
 # Please use new-line on formatting to make it easier to add new packages without
@@ -203,6 +204,8 @@ extended_testing = [
     "pandas",
     "telethon",
     "zep-python",
+    "gql",
+    "requests_toolbelt",
     "html2text"
 ]
 
diff --git a/tests/integration_tests/utilities/test_graphql.py b/tests/integration_tests/utilities/test_graphql.py
deleted file mode 100644
index f283df24..00000000
--- a/tests/integration_tests/utilities/test_graphql.py
+++ /dev/null
@@ -1,32 +0,0 @@
-import json
-
-import pytest
-import responses
-
-from langchain.utilities.graphql import GraphQLAPIWrapper
-
-TEST_ENDPOINT = "http://testserver/graphql"
-
-# Mock GraphQL response for testing
-MOCK_RESPONSE = {
-    "data": {"allUsers": [{"id": 1, "name": "Alice", "email": "alice@example.com"}]}
-}
-
-
-@pytest.fixture
-def graphql_wrapper() -> GraphQLAPIWrapper:
-    return GraphQLAPIWrapper(
-        graphql_endpoint=TEST_ENDPOINT,
-        custom_headers={"Authorization": "Bearer testtoken"},
-    )
-
-
-@responses.activate
-def test_run(graphql_wrapper: GraphQLAPIWrapper) -> None:
-    responses.add(responses.POST, TEST_ENDPOINT, json=MOCK_RESPONSE, status=200)
-
-    query = "query { allUsers { id, name, email } }"
-    result = graphql_wrapper.run(query)
-
-    expected_result = json.dumps(MOCK_RESPONSE, indent=2)
-    assert result == expected_result
diff --git a/tests/unit_tests/utilities/test_graphql.py b/tests/unit_tests/utilities/test_graphql.py
new file mode 100644
index 00000000..c0e90916
--- /dev/null
+++ b/tests/unit_tests/utilities/test_graphql.py
@@ -0,0 +1,92 @@
+import json
+
+import pytest
+import responses
+
+from langchain.utilities.graphql import GraphQLAPIWrapper
+
+TEST_ENDPOINT = "http://testserver/graphql"
+
+# Mock GraphQL response for testing
+MOCK_RESPONSE = {
+    "data": {
+        "allUsers": [{"name": "Alice"}],
+        "__schema": {
+            "queryType": {"name": "Query"},
+            "types": [
+                {
+                    "kind": "OBJECT",
+                    "name": "Query",
+                    "fields": [
+                        {
+                            "name": "allUsers",
+                            "args": [],
+                            "type": {
+                                "kind": "NON_NULL",
+                                "name": None,
+                                "ofType": {
+                                    "kind": "OBJECT",
+                                    "name": "allUsers",
+                                    "ofType": None,
+                                },
+                            },
+                        }
+                    ],
+                    "inputFields": None,
+                    "interfaces": [],
+                    "enumValues": None,
+                    "possibleTypes": None,
+                },
+                {
+                    "kind": "SCALAR",
+                    "name": "String",
+                },
+                {
+                    "kind": "OBJECT",
+                    "name": "allUsers",
+                    "description": None,
+                    "fields": [
+                        {
+                            "name": "name",
+                            "description": None,
+                            "args": [],
+                            "type": {
+                                "kind": "NON_NULL",
+                                "name": None,
+                                "ofType": {
+                                    "kind": "SCALAR",
+                                    "name": "String",
+                                    "ofType": None,
+                                },
+                            },
+                        },
+                    ],
+                    "inputFields": None,
+                    "interfaces": [],
+                    "enumValues": None,
+                    "possibleTypes": None,
+                },
+                {
+                    "kind": "SCALAR",
+                    "name": "Boolean",
+                },
+            ],
+        },
+    }
+}
+
+
+@pytest.mark.requires("gql", "requests_toolbelt")
+@responses.activate
+def test_run() -> None:
+    responses.add(responses.POST, TEST_ENDPOINT, json=MOCK_RESPONSE, status=200)
+
+    query = "query { allUsers { name } }"
+    graphql_wrapper = GraphQLAPIWrapper(
+        graphql_endpoint=TEST_ENDPOINT,
+        custom_headers={"Authorization": "Bearer testtoken"},
+    )
+    result = graphql_wrapper.run(query)
+
+    expected_result = json.dumps(MOCK_RESPONSE["data"], indent=2)
+    assert result == expected_result
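
Usage note: after this patch the wrapper builds its own transport and gql Client
inside the (pre=True) root validator, so callers no longer construct or pass a
Client themselves. A minimal sketch of the fixed construction path follows; the
endpoint and header values are illustrative only (borrowed from the unit test
above), and gql plus requests-toolbelt must be installed (`pip install gql
requests-toolbelt`), matching the extended_testing extra.

    from langchain.utilities.graphql import GraphQLAPIWrapper

    # The validator imports gql, builds a RequestsHTTPTransport and a Client
    # (fetching the schema from the endpoint), and fills in the private
    # gql_client / gql_function fields, so only public fields are passed here.
    wrapper = GraphQLAPIWrapper(
        graphql_endpoint="http://testserver/graphql",  # illustrative endpoint
        custom_headers={"Authorization": "Bearer testtoken"},  # optional
    )

    # run() executes the query and returns the result payload serialized
    # as indented JSON (a str), as asserted in the unit test.
    print(wrapper.run("query { allUsers { name } }"))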