From 80b3ec58690d61927a47854747cd9ea94e43f931 Mon Sep 17 00:00:00 2001 From: Gordon Clark Date: Mon, 24 Jul 2023 11:17:53 +0800 Subject: [PATCH] GitHub toolkit improvements (#8121) Fixes an issue with the github tool where the API returned special objects but the tool was expecting dictionaries. Also added proper docstrings to the GitHubAPIWrapper methods and a (very basic) integration test. Maintainer responsibilities: - Agents / Tools / Toolkits: @hinthornw --------- Co-authored-by: Harrison Chase --- libs/langchain/langchain/tools/github/tool.py | 22 -- libs/langchain/langchain/utilities/github.py | 240 +++++++++++------- libs/langchain/poetry.lock | 50 +++- libs/langchain/pyproject.toml | 1 + .../utilities/test_github.py | 21 ++ 5 files changed, 220 insertions(+), 114 deletions(-) create mode 100644 libs/langchain/tests/integration_tests/utilities/test_github.py diff --git a/libs/langchain/langchain/tools/github/tool.py b/libs/langchain/langchain/tools/github/tool.py index f0d506e374..8d3ca9ce65 100644 --- a/libs/langchain/langchain/tools/github/tool.py +++ b/libs/langchain/langchain/tools/github/tool.py @@ -6,28 +6,6 @@ To use this tool, you must first set as environment variables: GITHUB_API_TOKEN GITHUB_REPOSITORY -> format: {owner}/{repo} -TODO: remove below -Below is a sample script that uses the Github tool: - -```python -from langchain.agents import AgentType -from langchain.agents import initialize_agent -from langchain.agents.agent_toolkits.github.toolkit import GitHubToolkit -from langchain.llms import OpenAI -from langchain.utilities.github import GitHubAPIWrapper - -llm = OpenAI(temperature=0) -github = GitHubAPIWrapper() -toolkit = GitHubToolkit.from_github_api_wrapper(github) -agent = initialize_agent( - toolkit.get_tools(), llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True -) - -agent.run( - "{{Enter a prompt here to direct the agent}}" -) - -``` """ from typing import Optional diff --git a/libs/langchain/langchain/utilities/github.py 
b/libs/langchain/langchain/utilities/github.py index eaf4b98b2f..8a8b30b426 100644 --- a/libs/langchain/langchain/utilities/github.py +++ b/libs/langchain/langchain/utilities/github.py @@ -2,6 +2,7 @@ import json from typing import Any, Dict, List, Optional +from github.Issue import Issue from pydantic import BaseModel, Extra, root_validator from langchain.utils import get_from_dict_or_env @@ -41,6 +42,7 @@ class GitHubAPIWrapper(BaseModel): try: from github import Auth, GithubIntegration + except ImportError: raise ImportError( "PyGithub is not installed. " @@ -69,135 +71,193 @@ class GitHubAPIWrapper(BaseModel): return values - def parse_issues(self, issues: List[dict]) -> List[dict]: + def parse_issues(self, issues: List[Issue]) -> List[dict]: + """ + Extracts title and number from each Issue and puts them in a dictionary + Parameters: + issues(List[Issue]): A list of Github Issue objects + Returns: + List[dict]: A dictionary of issue titles and numbers + """ parsed = [] for issue in issues: - title = issue["title"] - number = issue["number"] + title = issue.title + number = issue.number parsed.append({"title": title, "number": number}) return parsed def get_issues(self) -> str: + """ + Fetches all open issues from the repo + + Returns: + str: A plaintext report containing the number of issues + and each issue's title and number. 
+ """ issues = self.github_repo_instance.get_issues(state="open") - parsed_issues = self.parse_issues(issues) - parsed_issues_str = ( - "Found " + str(len(parsed_issues)) + " issues:\n" + str(parsed_issues) - ) - return parsed_issues_str + if issues.totalCount > 0: + parsed_issues = self.parse_issues(issues) + parsed_issues_str = ( + "Found " + str(len(parsed_issues)) + " issues:\n" + str(parsed_issues) + ) + return parsed_issues_str + else: + return "No open issues available" def get_issue(self, issue_number: int) -> Dict[str, Any]: + """ + Fetches a specific issue and its first 10 comments + Parameters: + issue_number(int): The number for the github issue + Returns: + dict: A doctionary containing the issue's title, + body, and comments as a string + """ issue = self.github_repo_instance.get_issue(number=issue_number) - - # If there are too many comments - # We can't add them all to context so for now we'll just skip - if issue.get_comments().totalCount > 10: - return { - "message": ( - "There are too many comments to add them all to context. " - "Please visit the issue on GitHub to see them all." 
- ) - } page = 0 - comments = [] - while True: + comments: List[dict] = [] + while len(comments) <= 10: comments_page = issue.get_comments().get_page(page) if len(comments_page) == 0: break for comment in comments_page: - comments.append( - {"body": comment["body"], "user": comment["user"]["login"]} - ) + comments.append({"body": comment.body, "user": comment.user.login}) page += 1 return { - "title": issue["title"], - "body": issue["body"], + "title": issue.title, + "body": issue.body, "comments": str(comments), } def comment_on_issue(self, comment_query: str) -> str: - # comment_query is a string which contains the issue number and the comment - # the issue number is the first word in the string - # the comment is the rest of the string + """ + Adds a comment to a github issue + Parameters: + comment_query(str): a string which contains the issue number, + two newlines, and the comment. + for example: "1\n\nWorking on it now" + adds the comment "working on it now" to issue 1 + Returns: + str: A success or failure message + """ issue_number = int(comment_query.split("\n\n")[0]) comment = comment_query[len(str(issue_number)) + 2 :] - - issue = self.github_repo_instance.get_issue(number=issue_number) - issue.create_comment(comment) - return "Commented on issue " + str(issue_number) + try: + issue = self.github_repo_instance.get_issue(number=issue_number) + issue.create_comment(comment) + return "Commented on issue " + str(issue_number) + except Exception as e: + return "Unable to make comment due to error:\n" + str(e) def create_file(self, file_query: str) -> str: - # file_query is a string which contains the file path and the file contents - # the file path is the first line in the string - # the file contents is the rest of the string + """ + Creates a new file on the Github repo + Parameters: + file_query(str): a string which contains the file path + and the file contents. 
The file path is the first line + in the string, and the contents are the rest of the string. + For example, "hello_world.md\n# Hello World!" + Returns: + str: A success or failure message + """ file_path = file_query.split("\n")[0] file_contents = file_query[len(file_path) + 2 :] - - self.github_repo_instance.create_file( - path=file_path, - message="Create " + file_path, - content=file_contents, - branch=self.github_branch, - ) - return "Created file " + file_path + try: + exists = self.github_repo_instance.get_contents(file_path) + if exists is None: + self.github_repo_instance.create_file( + path=file_path, + message="Create " + file_path, + content=file_contents, + branch=self.github_branch, + ) + return "Created file " + file_path + else: + return f"File already exists at {file_path}. Use update_file instead" + except Exception as e: + return "Unable to make file due to error:\n" + str(e) def read_file(self, file_path: str) -> str: - # file_path is a string which contains the file path + """ + Reads a file from the github repo + Parameters: + file_path(str): the file path + Returns: + str: The file decoded as a string + """ file = self.github_repo_instance.get_contents(file_path) return file.decoded_content.decode("utf-8") def update_file(self, file_query: str) -> str: - # file_query is a string which contains the file path and the file contents - # the file path is the first line in the string - # the old file contents is wrapped in OLD <<<< and >>>> OLD - # the new file contents is wrapped in NEW <<<< and >>>> NEW - - # for example: - - # /test/test.txt - # OLD <<<< - # old contents - # >>>> OLD - # NEW <<<< - # new contents - # >>>> NEW - - # the old contents will be replaced with the new contents - file_path = file_query.split("\n")[0] - old_file_contents = file_query.split("OLD <<<<")[1].split(">>>> OLD")[0].strip() - new_file_contents = file_query.split("NEW <<<<")[1].split(">>>> NEW")[0].strip() - - file_content = self.read_file(file_path) - 
updated_file_content = file_content.replace( - old_file_contents, new_file_contents - ) + """ + Updates a file with new content. + Parameters: + file_query(str): Contains the file path and the file contents. + The old file contents is wrapped in OLD <<<< and >>>> OLD + The new file contents is wrapped in NEW <<<< and >>>> NEW + For example: + /test/hello.txt + OLD <<<< + Hello Earth! + >>>> OLD + NEW <<<< + Hello Mars! + >>>> NEW + Returns: + A success or failure message + """ + try: + file_path = file_query.split("\n")[0] + old_file_contents = ( + file_query.split("OLD <<<<")[1].split(">>>> OLD")[0].strip() + ) + new_file_contents = ( + file_query.split("NEW <<<<")[1].split(">>>> NEW")[0].strip() + ) - if file_content == updated_file_content: - return ( - "File content was not updated because the old content was not found. " - "It may be helpful to use the read_file action to get " - "the current file contents." + file_content = self.read_file(file_path) + updated_file_content = file_content.replace( + old_file_contents, new_file_contents ) - self.github_repo_instance.update_file( - path=file_path, - message="Update " + file_path, - content=updated_file_content, - branch=self.github_branch, - sha=self.github_repo_instance.get_contents(file_path).sha, - ) - return "Updated file " + file_path + if file_content == updated_file_content: + return ( + "File content was not updated because old content was not found." + "It may be helpful to use the read_file action to get " + "the current file contents." 
+ ) + + self.github_repo_instance.update_file( + path=file_path, + message="Update " + file_path, + content=updated_file_content, + branch=self.github_branch, + sha=self.github_repo_instance.get_contents(file_path).sha, + ) + return "Updated file " + file_path + except Exception as e: + return "Unable to update file due to error:\n" + str(e) def delete_file(self, file_path: str) -> str: - # file_path is a string which contains the file path - file = self.github_repo_instance.get_contents(file_path) - self.github_repo_instance.delete_file( - path=file_path, - message="Delete " + file_path, - branch=self.github_branch, - sha=file.sha, - ) - return "Deleted file " + file_path + """ + Deletes a file from the repo + Parameters: + file_path(str): Where the file is + Returns: + str: Success or failure message + """ + try: + file = self.github_repo_instance.get_contents(file_path) + self.github_repo_instance.delete_file( + path=file_path, + message="Delete " + file_path, + branch=self.github_branch, + sha=file.sha, + ) + return "Deleted file " + file_path + except Exception as e: + return "Unable to delete file due to error:\n" + str(e) def run(self, mode: str, query: str) -> str: if mode == "get_issues": diff --git a/libs/langchain/poetry.lock b/libs/langchain/poetry.lock index 9f2dcf7805..db85bdffad 100644 --- a/libs/langchain/poetry.lock +++ b/libs/langchain/poetry.lock @@ -2291,7 +2291,7 @@ name = "deprecated" version = "1.2.14" description = "Python @deprecated decorator to deprecate old python classes, functions or methods." 
category = "main" -optional = true +optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ {file = "Deprecated-1.2.14-py2.py3-none-any.whl", hash = "sha256:6fac8b097794a90302bdbb17b9b815e732d3c4720583ff1b198499d78470466c"}, @@ -4398,6 +4398,7 @@ optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, !=3.6.*" files = [ {file = "jsonpointer-2.4-py2.py3-none-any.whl", hash = "sha256:15d51bba20eea3165644553647711d150376234112651b4f1811022aecad7d7a"}, + {file = "jsonpointer-2.4.tar.gz", hash = "sha256:585cee82b70211fa9e6043b7bb89db6e1aa49524340dde8ad6b63206ea689d88"}, ] [[package]] @@ -7999,6 +8000,24 @@ files = [ [package.dependencies] typing-extensions = "*" +[[package]] +name = "pygithub" +version = "1.59.0" +description = "Use the full Github API v3" +category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "PyGithub-1.59.0-py3-none-any.whl", hash = "sha256:126bdbae72087d8d038b113aab6b059b4553cb59348e3024bb1a1cae406ace9e"}, + {file = "PyGithub-1.59.0.tar.gz", hash = "sha256:6e05ff49bac3caa7d1d6177a10c6e55a3e20c85b92424cc198571fd0cf786690"}, +] + +[package.dependencies] +deprecated = "*" +pyjwt = {version = ">=2.4.0", extras = ["crypto"]} +pynacl = ">=1.4.0" +requests = ">=2.14.0" + [[package]] name = "pygments" version = "2.15.1" @@ -8204,6 +8223,33 @@ files = [ {file = "PyMuPDF-1.22.3.tar.gz", hash = "sha256:5ecd928e96e63092571020973aa145b57b75707f3a3df97c742e563112615891"}, ] +[[package]] +name = "pynacl" +version = "1.5.0" +description = "Python binding to the Networking and Cryptography (NaCl) library" +category = "dev" +optional = false +python-versions = ">=3.6" +files = [ + {file = "PyNaCl-1.5.0-cp36-abi3-macosx_10_10_universal2.whl", hash = "sha256:401002a4aaa07c9414132aaed7f6836ff98f59277a234704ff66878c2ee4a0d1"}, + {file = "PyNaCl-1.5.0-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_24_aarch64.whl", hash = 
"sha256:52cb72a79269189d4e0dc537556f4740f7f0a9ec41c1322598799b0bdad4ef92"}, + {file = "PyNaCl-1.5.0-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a36d4a9dda1f19ce6e03c9a784a2921a4b726b02e1c736600ca9c22029474394"}, + {file = "PyNaCl-1.5.0-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:0c84947a22519e013607c9be43706dd42513f9e6ae5d39d3613ca1e142fba44d"}, + {file = "PyNaCl-1.5.0-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:06b8f6fa7f5de8d5d2f7573fe8c863c051225a27b61e6860fd047b1775807858"}, + {file = "PyNaCl-1.5.0-cp36-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:a422368fc821589c228f4c49438a368831cb5bbc0eab5ebe1d7fac9dded6567b"}, + {file = "PyNaCl-1.5.0-cp36-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:61f642bf2378713e2c2e1de73444a3778e5f0a38be6fee0fe532fe30060282ff"}, + {file = "PyNaCl-1.5.0-cp36-abi3-win32.whl", hash = "sha256:e46dae94e34b085175f8abb3b0aaa7da40767865ac82c928eeb9e57e1ea8a543"}, + {file = "PyNaCl-1.5.0-cp36-abi3-win_amd64.whl", hash = "sha256:20f42270d27e1b6a29f54032090b972d97f0a1b0948cc52392041ef7831fee93"}, + {file = "PyNaCl-1.5.0.tar.gz", hash = "sha256:8ac7448f09ab85811607bdd21ec2464495ac8b7c66d146bf545b0f08fb9220ba"}, +] + +[package.dependencies] +cffi = ">=1.4.1" + +[package.extras] +docs = ["sphinx (>=1.6.5)", "sphinx-rtd-theme"] +tests = ["hypothesis (>=3.27.0)", "pytest (>=3.2.1,!=3.3.0)"] + [[package]] name = "pynvml" version = "11.5.0" @@ -12431,4 +12477,4 @@ text-helpers = ["chardet"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "0a2d2f7b59b7bbc1a92478c37bf5b3ae146242dd34d94e5df6d2c7bc89ebe05b" +content-hash = "c9c8d7972feb1de7b227f222f53b9a3f9b497c9aa5d4ef7666a9eafe0db423d6" diff --git a/libs/langchain/pyproject.toml b/libs/langchain/pyproject.toml index a04cba0beb..88b2e0f7d0 100644 --- a/libs/langchain/pyproject.toml +++ b/libs/langchain/pyproject.toml @@ -193,6 +193,7 @@ mastodon-py = 
"^1.8.1" momento = "^1.5.0" # Please do not add any dependencies in the test_integration group # See instructions above ^^ +pygithub = "^1.59.0" [tool.poetry.group.lint.dependencies] ruff = "^0.0.249" diff --git a/libs/langchain/tests/integration_tests/utilities/test_github.py b/libs/langchain/tests/integration_tests/utilities/test_github.py new file mode 100644 index 0000000000..48997887a2 --- /dev/null +++ b/libs/langchain/tests/integration_tests/utilities/test_github.py @@ -0,0 +1,21 @@ +"""Integration test for Github Wrapper.""" +import pytest + +from langchain.utilities.github import GitHubAPIWrapper + +# Make sure you have set the following env variables: +# GITHUB_REPOSITORY +# GITHUB_BRANCH +# GITHUB_APP_ID +# GITHUB_PRIVATE_KEY + + +@pytest.fixture +def api_client() -> GitHubAPIWrapper: + return GitHubAPIWrapper() + + +def test_get_open_issues(api_client: GitHubAPIWrapper) -> None: + """Basic test to fetch issues""" + issues = api_client.get_issues() + assert len(issues) != 0