GitHub toolkit improvements (#8121)

Fixes an issue with the github tool where the API returned special
objects but the tool was expecting dictionaries.

Also added proper docstrings to the GitHubAPIWrapper methods and a (very
basic) integration test.

Maintainer responsibilities:
  - Agents / Tools / Toolkits: @hinthornw

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
pull/8162/head
Gordon Clark 1 year ago committed by GitHub
parent 33fd6184ba
commit 80b3ec5869
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -6,28 +6,6 @@ To use this tool, you must first set as environment variables:
GITHUB_API_TOKEN GITHUB_API_TOKEN
GITHUB_REPOSITORY -> format: {owner}/{repo} GITHUB_REPOSITORY -> format: {owner}/{repo}
TODO: remove below
Below is a sample script that uses the Github tool:
```python
from langchain.agents import AgentType
from langchain.agents import initialize_agent
from langchain.agents.agent_toolkits.github.toolkit import GitHubToolkit
from langchain.llms import OpenAI
from langchain.utilities.github import GitHubAPIWrapper
llm = OpenAI(temperature=0)
github = GitHubAPIWrapper()
toolkit = GitHubToolkit.from_github_api_wrapper(github)
agent = initialize_agent(
toolkit.get_tools(), llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
)
agent.run(
"{{Enter a prompt here to direct the agent}}"
)
```
""" """
from typing import Optional from typing import Optional

@ -2,6 +2,7 @@
import json import json
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from github.Issue import Issue
from pydantic import BaseModel, Extra, root_validator from pydantic import BaseModel, Extra, root_validator
from langchain.utils import get_from_dict_or_env from langchain.utils import get_from_dict_or_env
@ -41,6 +42,7 @@ class GitHubAPIWrapper(BaseModel):
try: try:
from github import Auth, GithubIntegration from github import Auth, GithubIntegration
except ImportError: except ImportError:
raise ImportError( raise ImportError(
"PyGithub is not installed. " "PyGithub is not installed. "
@ -69,135 +71,193 @@ class GitHubAPIWrapper(BaseModel):
return values return values
def parse_issues(self, issues: List[dict]) -> List[dict]: def parse_issues(self, issues: List[Issue]) -> List[dict]:
"""
Extracts title and number from each Issue and puts them in a dictionary
Parameters:
issues(List[Issue]): A list of Github Issue objects
Returns:
List[dict]: A list of dictionaries, each holding an issue's title and number
"""
parsed = [] parsed = []
for issue in issues: for issue in issues:
title = issue["title"] title = issue.title
number = issue["number"] number = issue.number
parsed.append({"title": title, "number": number}) parsed.append({"title": title, "number": number})
return parsed return parsed
def get_issues(self) -> str: def get_issues(self) -> str:
"""
Fetches all open issues from the repo
Returns:
str: A plaintext report containing the number of issues
and each issue's title and number.
"""
issues = self.github_repo_instance.get_issues(state="open") issues = self.github_repo_instance.get_issues(state="open")
parsed_issues = self.parse_issues(issues) if issues.totalCount > 0:
parsed_issues_str = ( parsed_issues = self.parse_issues(issues)
"Found " + str(len(parsed_issues)) + " issues:\n" + str(parsed_issues) parsed_issues_str = (
) "Found " + str(len(parsed_issues)) + " issues:\n" + str(parsed_issues)
return parsed_issues_str )
return parsed_issues_str
else:
return "No open issues available"
def get_issue(self, issue_number: int) -> Dict[str, Any]: def get_issue(self, issue_number: int) -> Dict[str, Any]:
"""
Fetches a specific issue and its first 10 comments
Parameters:
issue_number(int): The number for the github issue
Returns:
dict: A dictionary containing the issue's title,
body, and comments as a string
"""
issue = self.github_repo_instance.get_issue(number=issue_number) issue = self.github_repo_instance.get_issue(number=issue_number)
# If there are too many comments
# We can't add them all to context so for now we'll just skip
if issue.get_comments().totalCount > 10:
return {
"message": (
"There are too many comments to add them all to context. "
"Please visit the issue on GitHub to see them all."
)
}
page = 0 page = 0
comments = [] comments: List[dict] = []
while True: while len(comments) <= 10:
comments_page = issue.get_comments().get_page(page) comments_page = issue.get_comments().get_page(page)
if len(comments_page) == 0: if len(comments_page) == 0:
break break
for comment in comments_page: for comment in comments_page:
comments.append( comments.append({"body": comment.body, "user": comment.user.login})
{"body": comment["body"], "user": comment["user"]["login"]}
)
page += 1 page += 1
return { return {
"title": issue["title"], "title": issue.title,
"body": issue["body"], "body": issue.body,
"comments": str(comments), "comments": str(comments),
} }
def comment_on_issue(self, comment_query: str) -> str: def comment_on_issue(self, comment_query: str) -> str:
# comment_query is a string which contains the issue number and the comment """
# the issue number is the first word in the string Adds a comment to a github issue
# the comment is the rest of the string Parameters:
comment_query(str): a string which contains the issue number,
two newlines, and the comment.
for example: "1\n\nWorking on it now"
adds the comment "Working on it now" to issue 1
Returns:
str: A success or failure message
"""
issue_number = int(comment_query.split("\n\n")[0]) issue_number = int(comment_query.split("\n\n")[0])
comment = comment_query[len(str(issue_number)) + 2 :] comment = comment_query[len(str(issue_number)) + 2 :]
try:
issue = self.github_repo_instance.get_issue(number=issue_number) issue = self.github_repo_instance.get_issue(number=issue_number)
issue.create_comment(comment) issue.create_comment(comment)
return "Commented on issue " + str(issue_number) return "Commented on issue " + str(issue_number)
except Exception as e:
return "Unable to make comment due to error:\n" + str(e)
def create_file(self, file_query: str) -> str: def create_file(self, file_query: str) -> str:
# file_query is a string which contains the file path and the file contents """
# the file path is the first line in the string Creates a new file on the Github repo
# the file contents is the rest of the string Parameters:
file_query(str): a string which contains the file path
and the file contents. The file path is the first line
in the string, and the contents are the rest of the string.
For example, "hello_world.md\n# Hello World!"
Returns:
str: A success or failure message
"""
file_path = file_query.split("\n")[0] file_path = file_query.split("\n")[0]
file_contents = file_query[len(file_path) + 2 :] file_contents = file_query[len(file_path) + 2 :]
try:
self.github_repo_instance.create_file( exists = self.github_repo_instance.get_contents(file_path)
path=file_path, if exists is None:
message="Create " + file_path, self.github_repo_instance.create_file(
content=file_contents, path=file_path,
branch=self.github_branch, message="Create " + file_path,
) content=file_contents,
return "Created file " + file_path branch=self.github_branch,
)
return "Created file " + file_path
else:
return f"File already exists at {file_path}. Use update_file instead"
except Exception as e:
return "Unable to make file due to error:\n" + str(e)
def read_file(self, file_path: str) -> str: def read_file(self, file_path: str) -> str:
# file_path is a string which contains the file path """
Reads a file from the github repo
Parameters:
file_path(str): the file path
Returns:
str: The file decoded as a string
"""
file = self.github_repo_instance.get_contents(file_path) file = self.github_repo_instance.get_contents(file_path)
return file.decoded_content.decode("utf-8") return file.decoded_content.decode("utf-8")
def update_file(self, file_query: str) -> str: def update_file(self, file_query: str) -> str:
# file_query is a string which contains the file path and the file contents """
# the file path is the first line in the string Updates a file with new content.
# the old file contents is wrapped in OLD <<<< and >>>> OLD Parameters:
# the new file contents is wrapped in NEW <<<< and >>>> NEW file_query(str): Contains the file path and the file contents.
The old file contents is wrapped in OLD <<<< and >>>> OLD
# for example: The new file contents is wrapped in NEW <<<< and >>>> NEW
For example:
# /test/test.txt /test/hello.txt
# OLD <<<< OLD <<<<
# old contents Hello Earth!
# >>>> OLD >>>> OLD
# NEW <<<< NEW <<<<
# new contents Hello Mars!
# >>>> NEW >>>> NEW
Returns:
# the old contents will be replaced with the new contents A success or failure message
file_path = file_query.split("\n")[0] """
old_file_contents = file_query.split("OLD <<<<")[1].split(">>>> OLD")[0].strip() try:
new_file_contents = file_query.split("NEW <<<<")[1].split(">>>> NEW")[0].strip() file_path = file_query.split("\n")[0]
old_file_contents = (
file_content = self.read_file(file_path) file_query.split("OLD <<<<")[1].split(">>>> OLD")[0].strip()
updated_file_content = file_content.replace( )
old_file_contents, new_file_contents new_file_contents = (
) file_query.split("NEW <<<<")[1].split(">>>> NEW")[0].strip()
)
if file_content == updated_file_content: file_content = self.read_file(file_path)
return ( updated_file_content = file_content.replace(
"File content was not updated because the old content was not found. " old_file_contents, new_file_contents
"It may be helpful to use the read_file action to get "
"the current file contents."
) )
self.github_repo_instance.update_file( if file_content == updated_file_content:
path=file_path, return (
message="Update " + file_path, "File content was not updated because old content was not found."
content=updated_file_content, "It may be helpful to use the read_file action to get "
branch=self.github_branch, "the current file contents."
sha=self.github_repo_instance.get_contents(file_path).sha, )
)
return "Updated file " + file_path self.github_repo_instance.update_file(
path=file_path,
message="Update " + file_path,
content=updated_file_content,
branch=self.github_branch,
sha=self.github_repo_instance.get_contents(file_path).sha,
)
return "Updated file " + file_path
except Exception as e:
return "Unable to update file due to error:\n" + str(e)
def delete_file(self, file_path: str) -> str: def delete_file(self, file_path: str) -> str:
# file_path is a string which contains the file path """
file = self.github_repo_instance.get_contents(file_path) Deletes a file from the repo
self.github_repo_instance.delete_file( Parameters:
path=file_path, file_path(str): Where the file is
message="Delete " + file_path, Returns:
branch=self.github_branch, str: Success or failure message
sha=file.sha, """
) try:
return "Deleted file " + file_path file = self.github_repo_instance.get_contents(file_path)
self.github_repo_instance.delete_file(
path=file_path,
message="Delete " + file_path,
branch=self.github_branch,
sha=file.sha,
)
return "Deleted file " + file_path
except Exception as e:
return "Unable to delete file due to error:\n" + str(e)
def run(self, mode: str, query: str) -> str: def run(self, mode: str, query: str) -> str:
if mode == "get_issues": if mode == "get_issues":

@ -2291,7 +2291,7 @@ name = "deprecated"
version = "1.2.14" version = "1.2.14"
description = "Python @deprecated decorator to deprecate old python classes, functions or methods." description = "Python @deprecated decorator to deprecate old python classes, functions or methods."
category = "main" category = "main"
optional = true optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
files = [ files = [
{file = "Deprecated-1.2.14-py2.py3-none-any.whl", hash = "sha256:6fac8b097794a90302bdbb17b9b815e732d3c4720583ff1b198499d78470466c"}, {file = "Deprecated-1.2.14-py2.py3-none-any.whl", hash = "sha256:6fac8b097794a90302bdbb17b9b815e732d3c4720583ff1b198499d78470466c"},
@ -4398,6 +4398,7 @@ optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, !=3.6.*" python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, !=3.6.*"
files = [ files = [
{file = "jsonpointer-2.4-py2.py3-none-any.whl", hash = "sha256:15d51bba20eea3165644553647711d150376234112651b4f1811022aecad7d7a"}, {file = "jsonpointer-2.4-py2.py3-none-any.whl", hash = "sha256:15d51bba20eea3165644553647711d150376234112651b4f1811022aecad7d7a"},
{file = "jsonpointer-2.4.tar.gz", hash = "sha256:585cee82b70211fa9e6043b7bb89db6e1aa49524340dde8ad6b63206ea689d88"},
] ]
[[package]] [[package]]
@ -7999,6 +8000,24 @@ files = [
[package.dependencies] [package.dependencies]
typing-extensions = "*" typing-extensions = "*"
[[package]]
name = "pygithub"
version = "1.59.0"
description = "Use the full Github API v3"
category = "dev"
optional = false
python-versions = ">=3.7"
files = [
{file = "PyGithub-1.59.0-py3-none-any.whl", hash = "sha256:126bdbae72087d8d038b113aab6b059b4553cb59348e3024bb1a1cae406ace9e"},
{file = "PyGithub-1.59.0.tar.gz", hash = "sha256:6e05ff49bac3caa7d1d6177a10c6e55a3e20c85b92424cc198571fd0cf786690"},
]
[package.dependencies]
deprecated = "*"
pyjwt = {version = ">=2.4.0", extras = ["crypto"]}
pynacl = ">=1.4.0"
requests = ">=2.14.0"
[[package]] [[package]]
name = "pygments" name = "pygments"
version = "2.15.1" version = "2.15.1"
@ -8204,6 +8223,33 @@ files = [
{file = "PyMuPDF-1.22.3.tar.gz", hash = "sha256:5ecd928e96e63092571020973aa145b57b75707f3a3df97c742e563112615891"}, {file = "PyMuPDF-1.22.3.tar.gz", hash = "sha256:5ecd928e96e63092571020973aa145b57b75707f3a3df97c742e563112615891"},
] ]
[[package]]
name = "pynacl"
version = "1.5.0"
description = "Python binding to the Networking and Cryptography (NaCl) library"
category = "dev"
optional = false
python-versions = ">=3.6"
files = [
{file = "PyNaCl-1.5.0-cp36-abi3-macosx_10_10_universal2.whl", hash = "sha256:401002a4aaa07c9414132aaed7f6836ff98f59277a234704ff66878c2ee4a0d1"},
{file = "PyNaCl-1.5.0-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_24_aarch64.whl", hash = "sha256:52cb72a79269189d4e0dc537556f4740f7f0a9ec41c1322598799b0bdad4ef92"},
{file = "PyNaCl-1.5.0-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a36d4a9dda1f19ce6e03c9a784a2921a4b726b02e1c736600ca9c22029474394"},
{file = "PyNaCl-1.5.0-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:0c84947a22519e013607c9be43706dd42513f9e6ae5d39d3613ca1e142fba44d"},
{file = "PyNaCl-1.5.0-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:06b8f6fa7f5de8d5d2f7573fe8c863c051225a27b61e6860fd047b1775807858"},
{file = "PyNaCl-1.5.0-cp36-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:a422368fc821589c228f4c49438a368831cb5bbc0eab5ebe1d7fac9dded6567b"},
{file = "PyNaCl-1.5.0-cp36-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:61f642bf2378713e2c2e1de73444a3778e5f0a38be6fee0fe532fe30060282ff"},
{file = "PyNaCl-1.5.0-cp36-abi3-win32.whl", hash = "sha256:e46dae94e34b085175f8abb3b0aaa7da40767865ac82c928eeb9e57e1ea8a543"},
{file = "PyNaCl-1.5.0-cp36-abi3-win_amd64.whl", hash = "sha256:20f42270d27e1b6a29f54032090b972d97f0a1b0948cc52392041ef7831fee93"},
{file = "PyNaCl-1.5.0.tar.gz", hash = "sha256:8ac7448f09ab85811607bdd21ec2464495ac8b7c66d146bf545b0f08fb9220ba"},
]
[package.dependencies]
cffi = ">=1.4.1"
[package.extras]
docs = ["sphinx (>=1.6.5)", "sphinx-rtd-theme"]
tests = ["hypothesis (>=3.27.0)", "pytest (>=3.2.1,!=3.3.0)"]
[[package]] [[package]]
name = "pynvml" name = "pynvml"
version = "11.5.0" version = "11.5.0"
@ -12431,4 +12477,4 @@ text-helpers = ["chardet"]
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = ">=3.8.1,<4.0" python-versions = ">=3.8.1,<4.0"
content-hash = "0a2d2f7b59b7bbc1a92478c37bf5b3ae146242dd34d94e5df6d2c7bc89ebe05b" content-hash = "c9c8d7972feb1de7b227f222f53b9a3f9b497c9aa5d4ef7666a9eafe0db423d6"

@ -193,6 +193,7 @@ mastodon-py = "^1.8.1"
momento = "^1.5.0" momento = "^1.5.0"
# Please do not add any dependencies in the test_integration group # Please do not add any dependencies in the test_integration group
# See instructions above ^^ # See instructions above ^^
pygithub = "^1.59.0"
[tool.poetry.group.lint.dependencies] [tool.poetry.group.lint.dependencies]
ruff = "^0.0.249" ruff = "^0.0.249"

@ -0,0 +1,21 @@
"""Integration test for Github Wrapper."""
import pytest
from langchain.utilities.github import GitHubAPIWrapper
# Make sure you have set the following env variables:
# GITHUB_REPOSITORY
# GITHUB_BRANCH
# GITHUB_APP_ID
# GITHUB_PRIVATE_KEY
@pytest.fixture
def api_client() -> GitHubAPIWrapper:
    """Provide a GitHubAPIWrapper instance to the integration tests.

    The wrapper is built with no explicit arguments, so it relies on the
    environment variables noted in the comment at the top of this module.
    """
    wrapper = GitHubAPIWrapper()
    return wrapper
def test_get_open_issues(api_client: GitHubAPIWrapper) -> None:
    """Fetch the open issues and check the form of the returned report.

    ``get_issues`` returns a plaintext report: either a summary beginning
    with "Found " or the fixed message "No open issues available". A bare
    non-empty check would pass for any string at all, so the report is
    checked against those two forms instead.
    """
    report = api_client.get_issues()
    assert isinstance(report, str)
    assert report.startswith("Found ") or report == "No open issues available"
Loading…
Cancel
Save