mirror of
https://github.com/danielmiessler/fabric
synced 2024-11-10 07:10:31 +00:00
Merge pull request #1 from Snorkell-ai/snorkell_ai/auto_doc_2024-02-07-21-44
[Snorkell.ai] Please review the generated documentation
This commit is contained in:
commit
1bcbe56d06
125
client/utils.py
125
client/utils.py
@ -14,6 +14,21 @@ env_file = os.path.join(config_directory, ".env")
|
|||||||
|
|
||||||
class Standalone:
|
class Standalone:
|
||||||
def __init__(self, args, pattern="", env_file="~/.config/fabric/.env"):
|
def __init__(self, args, pattern="", env_file="~/.config/fabric/.env"):
|
||||||
|
""" Initialize the class with the provided arguments and environment file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
args: The arguments for initialization.
|
||||||
|
pattern: The pattern to be used (default is an empty string).
|
||||||
|
env_file: The path to the environment file (default is "~/.config/fabric/.env").
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
KeyError: If the "OPENAI_API_KEY" is not found in the environment variables.
|
||||||
|
FileNotFoundError: If no API key is found in the environment variables.
|
||||||
|
"""
|
||||||
|
|
||||||
# Expand the tilde to the full path
|
# Expand the tilde to the full path
|
||||||
env_file = os.path.expanduser(env_file)
|
env_file = os.path.expanduser(env_file)
|
||||||
load_dotenv(env_file)
|
load_dotenv(env_file)
|
||||||
@ -32,6 +47,18 @@ class Standalone:
|
|||||||
self.args = args
|
self.args = args
|
||||||
|
|
||||||
def streamMessage(self, input_data: str):
|
def streamMessage(self, input_data: str):
|
||||||
|
""" Stream a message and handle exceptions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_data (str): The input data for the message.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None: If the pattern is not found.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
FileNotFoundError: If the pattern file is not found.
|
||||||
|
"""
|
||||||
|
|
||||||
wisdomFilePath = os.path.join(
|
wisdomFilePath = os.path.join(
|
||||||
config_directory, f"patterns/{self.pattern}/system.md"
|
config_directory, f"patterns/{self.pattern}/system.md"
|
||||||
)
|
)
|
||||||
@ -80,6 +107,18 @@ class Standalone:
|
|||||||
f.write(buffer)
|
f.write(buffer)
|
||||||
|
|
||||||
def sendMessage(self, input_data: str):
|
def sendMessage(self, input_data: str):
|
||||||
|
""" Send a message using the input data and generate a response.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_data (str): The input data to be sent as a message.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
FileNotFoundError: If the specified pattern file is not found.
|
||||||
|
"""
|
||||||
|
|
||||||
wisdomFilePath = os.path.join(
|
wisdomFilePath = os.path.join(
|
||||||
config_directory, f"patterns/{self.pattern}/system.md"
|
config_directory, f"patterns/{self.pattern}/system.md"
|
||||||
)
|
)
|
||||||
@ -118,6 +157,15 @@ class Standalone:
|
|||||||
|
|
||||||
class Update:
|
class Update:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
""" Initialize the object with default values and update patterns.
|
||||||
|
|
||||||
|
This method initializes the object with default values for root_api_url, config_directory, and pattern_directory.
|
||||||
|
It then creates the pattern_directory if it does not exist and calls the update_patterns method to update the patterns.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
OSError: If there is an issue creating the pattern_directory.
|
||||||
|
"""
|
||||||
|
|
||||||
self.root_api_url = "https://api.github.com/repos/danielmiessler/fabric/contents/patterns?ref=main"
|
self.root_api_url = "https://api.github.com/repos/danielmiessler/fabric/contents/patterns?ref=main"
|
||||||
self.config_directory = os.path.expanduser("~/.config/fabric")
|
self.config_directory = os.path.expanduser("~/.config/fabric")
|
||||||
self.pattern_directory = os.path.join(self.config_directory, "patterns")
|
self.pattern_directory = os.path.join(self.config_directory, "patterns")
|
||||||
@ -125,6 +173,12 @@ class Update:
|
|||||||
self.update_patterns() # Call the update process from a method.
|
self.update_patterns() # Call the update process from a method.
|
||||||
|
|
||||||
def update_patterns(self):
|
def update_patterns(self):
|
||||||
|
""" Update the patterns by downloading from the GitHub directory.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPError: If there is an HTTP error while downloading patterns.
|
||||||
|
"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.progress_bar = tqdm(desc="Downloading Patterns…", unit="file")
|
self.progress_bar = tqdm(desc="Downloading Patterns…", unit="file")
|
||||||
self.get_github_directory_contents(
|
self.get_github_directory_contents(
|
||||||
@ -145,6 +199,16 @@ class Update:
|
|||||||
sys.exit() # Exit after handling the error.
|
sys.exit() # Exit after handling the error.
|
||||||
|
|
||||||
def download_file(self, url, local_path):
|
def download_file(self, url, local_path):
|
||||||
|
""" Download a file from the given URL and save it to the local path.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url (str): The URL of the file to be downloaded.
|
||||||
|
local_path (str): The local path where the file will be saved.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPError: If an HTTP error occurs during the download process.
|
||||||
|
"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = requests.get(url)
|
response = requests.get(url)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
@ -156,6 +220,19 @@ class Update:
|
|||||||
sys.exit()
|
sys.exit()
|
||||||
|
|
||||||
def process_item(self, item, local_dir):
|
def process_item(self, item, local_dir):
|
||||||
|
""" Process the given item and save it to the local directory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
item (dict): The item to be processed, containing information about the type, download URL, name, and URL.
|
||||||
|
local_dir (str): The local directory where the item will be saved.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
OSError: If there is an issue creating the new directory using os.makedirs.
|
||||||
|
"""
|
||||||
|
|
||||||
if item["type"] == "file":
|
if item["type"] == "file":
|
||||||
self.download_file(
|
self.download_file(
|
||||||
item["download_url"], os.path.join(local_dir, item["name"])
|
item["download_url"], os.path.join(local_dir, item["name"])
|
||||||
@ -166,6 +243,22 @@ class Update:
|
|||||||
self.get_github_directory_contents(item["url"], new_dir)
|
self.get_github_directory_contents(item["url"], new_dir)
|
||||||
|
|
||||||
def get_github_directory_contents(self, api_url, local_dir):
|
def get_github_directory_contents(self, api_url, local_dir):
|
||||||
|
""" Get the contents of a directory from GitHub API and process each item.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_url (str): The URL of the GitHub API endpoint for the directory.
|
||||||
|
local_dir (str): The local directory where the contents will be processed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPError: If an HTTP error occurs while fetching the directory contents.
|
||||||
|
If the status code is 403, it prints a message about GitHub API rate limit exceeded
|
||||||
|
and closes the progress bar. For any other status code, it prints a message
|
||||||
|
about failing to fetch directory contents due to an HTTP error.
|
||||||
|
"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = requests.get(api_url)
|
response = requests.get(api_url)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
@ -184,22 +277,54 @@ class Update:
|
|||||||
|
|
||||||
class Setup:
|
class Setup:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
""" Initialize the object.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
OSError: If there is an error in creating the pattern directory.
|
||||||
|
"""
|
||||||
|
|
||||||
self.config_directory = os.path.expanduser("~/.config/fabric")
|
self.config_directory = os.path.expanduser("~/.config/fabric")
|
||||||
self.pattern_directory = os.path.join(self.config_directory, "patterns")
|
self.pattern_directory = os.path.join(self.config_directory, "patterns")
|
||||||
os.makedirs(self.pattern_directory, exist_ok=True)
|
os.makedirs(self.pattern_directory, exist_ok=True)
|
||||||
self.env_file = os.path.join(self.config_directory, ".env")
|
self.env_file = os.path.join(self.config_directory, ".env")
|
||||||
|
|
||||||
def api_key(self, api_key):
|
def api_key(self, api_key):
|
||||||
|
""" Set the OpenAI API key in the environment file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_key (str): The API key to be set.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
OSError: If the environment file does not exist or cannot be accessed.
|
||||||
|
"""
|
||||||
|
|
||||||
if not os.path.exists(self.env_file):
|
if not os.path.exists(self.env_file):
|
||||||
with open(self.env_file, "w") as f:
|
with open(self.env_file, "w") as f:
|
||||||
f.write(f"OPENAI_API_KEY={api_key}")
|
f.write(f"OPENAI_API_KEY={api_key}")
|
||||||
print(f"OpenAI API key set to {api_key}")
|
print(f"OpenAI API key set to {api_key}")
|
||||||
|
|
||||||
def patterns(self):
|
def patterns(self):
|
||||||
|
""" Method to update patterns and exit the system.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
|
||||||
Update()
|
Update()
|
||||||
sys.exit()
|
sys.exit()
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
|
""" Execute the Fabric program.
|
||||||
|
|
||||||
|
This method prompts the user for their OpenAI API key, sets the API key in the Fabric object, and then calls the patterns method.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
|
||||||
print("Welcome to Fabric. Let's get started.")
|
print("Welcome to Fabric. Let's get started.")
|
||||||
apikey = input("Please enter your OpenAI API key\n")
|
apikey = input("Please enter your OpenAI API key\n")
|
||||||
self.api_key(apikey.strip())
|
self.api_key(apikey.strip())
|
||||||
|
@ -41,8 +41,32 @@ with open("fabric_api_keys.json", "r") as tokens_file:
|
|||||||
|
|
||||||
# The function to check if the token is valid
|
# The function to check if the token is valid
|
||||||
def auth_required(f):
|
def auth_required(f):
|
||||||
|
""" Decorator function to check if the token is valid.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
f: The function to be decorated
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The decorated function
|
||||||
|
"""
|
||||||
|
|
||||||
@wraps(f)
|
@wraps(f)
|
||||||
def decorated_function(*args, **kwargs):
|
def decorated_function(*args, **kwargs):
|
||||||
|
""" Decorated function to handle authentication token and API endpoint.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
*args: Variable length argument list.
|
||||||
|
**kwargs: Arbitrary keyword arguments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Result of the decorated function.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
KeyError: If 'Authorization' header is not found in the request.
|
||||||
|
TypeError: If 'Authorization' header value is not a string.
|
||||||
|
ValueError: If the authentication token is invalid or expired.
|
||||||
|
"""
|
||||||
|
|
||||||
# Get the authentication token from request header
|
# Get the authentication token from request header
|
||||||
auth_token = request.headers.get("Authorization", "")
|
auth_token = request.headers.get("Authorization", "")
|
||||||
|
|
||||||
@ -65,6 +89,16 @@ def auth_required(f):
|
|||||||
|
|
||||||
# Check for a valid token/user for the given route
|
# Check for a valid token/user for the given route
|
||||||
def check_auth_token(token, route):
|
def check_auth_token(token, route):
|
||||||
|
""" Check if the provided token is valid for the given route and return the corresponding user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token (str): The token to be checked for validity.
|
||||||
|
route (str): The route for which the token validity is to be checked.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The user corresponding to the provided token and route if valid, otherwise returns "Unauthorized: You are not authorized for this API".
|
||||||
|
"""
|
||||||
|
|
||||||
# Check if token is valid for the given route and return corresponding user
|
# Check if token is valid for the given route and return corresponding user
|
||||||
if route in valid_tokens and token in valid_tokens[route]:
|
if route in valid_tokens and token in valid_tokens[route]:
|
||||||
return valid_tokens[route][token]
|
return valid_tokens[route][token]
|
||||||
@ -78,11 +112,32 @@ ALLOWLIST_PATTERN = re.compile(r"^[a-zA-Z0-9\s.,;:!?\-]+$")
|
|||||||
|
|
||||||
# Sanitize the content, sort of. Prompt injection is the main threat so this isn't a huge deal
|
# Sanitize the content, sort of. Prompt injection is the main threat so this isn't a huge deal
|
||||||
def sanitize_content(content):
|
def sanitize_content(content):
|
||||||
|
""" Sanitize the content by removing characters that do not match the ALLOWLIST_PATTERN.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content (str): The content to be sanitized.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The sanitized content.
|
||||||
|
"""
|
||||||
|
|
||||||
return "".join(char for char in content if ALLOWLIST_PATTERN.match(char))
|
return "".join(char for char in content if ALLOWLIST_PATTERN.match(char))
|
||||||
|
|
||||||
|
|
||||||
# Pull the URL content's from the GitHub repo
|
# Pull the URL content's from the GitHub repo
|
||||||
def fetch_content_from_url(url):
|
def fetch_content_from_url(url):
|
||||||
|
""" Fetches content from the given URL.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url (str): The URL from which to fetch content.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The sanitized content fetched from the URL.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
requests.RequestException: If an error occurs while making the request to the URL.
|
||||||
|
"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = requests.get(url)
|
response = requests.get(url)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
@ -99,6 +154,15 @@ def fetch_content_from_url(url):
|
|||||||
@app.route("/extwis", methods=["POST"])
|
@app.route("/extwis", methods=["POST"])
|
||||||
@auth_required # Require authentication
|
@auth_required # Require authentication
|
||||||
def extwis():
|
def extwis():
|
||||||
|
""" Extract wisdom from user input using OpenAI's GPT-4 model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
JSON: A JSON response containing the generated response or an error message.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Exception: If there is an error during the API call.
|
||||||
|
"""
|
||||||
|
|
||||||
data = request.get_json()
|
data = request.get_json()
|
||||||
|
|
||||||
# Warn if there's no input
|
# Warn if there's no input
|
||||||
|
@ -16,6 +16,19 @@ import os
|
|||||||
|
|
||||||
|
|
||||||
def send_request(prompt, endpoint):
|
def send_request(prompt, endpoint):
|
||||||
|
""" Send a request to the specified endpoint of an HTTP-only server.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt (str): The input prompt for the request.
|
||||||
|
endpoint (str): The endpoint to which the request will be sent.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The response from the server.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
KeyError: If the response JSON does not contain the expected "response" key.
|
||||||
|
"""
|
||||||
|
|
||||||
base_url = "http://127.0.0.1:13337"
|
base_url = "http://127.0.0.1:13337"
|
||||||
url = f"{base_url}{endpoint}"
|
url = f"{base_url}{endpoint}"
|
||||||
headers = {
|
headers = {
|
||||||
@ -37,6 +50,15 @@ app.secret_key = "your_secret_key"
|
|||||||
|
|
||||||
@app.route("/favicon.ico")
|
@app.route("/favicon.ico")
|
||||||
def favicon():
|
def favicon():
|
||||||
|
""" Send the favicon.ico file from the static directory.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Response object with the favicon.ico file
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
-
|
||||||
|
"""
|
||||||
|
|
||||||
return send_from_directory(
|
return send_from_directory(
|
||||||
os.path.join(app.root_path, "static"),
|
os.path.join(app.root_path, "static"),
|
||||||
"favicon.ico",
|
"favicon.ico",
|
||||||
@ -46,6 +68,12 @@ def favicon():
|
|||||||
|
|
||||||
@app.route("/", methods=["GET", "POST"])
|
@app.route("/", methods=["GET", "POST"])
|
||||||
def index():
|
def index():
|
||||||
|
""" Process the POST request and send a request to the specified API endpoint.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The rendered HTML template with the response data.
|
||||||
|
"""
|
||||||
|
|
||||||
if request.method == "POST":
|
if request.method == "POST":
|
||||||
prompt = request.form.get("prompt")
|
prompt = request.form.get("prompt")
|
||||||
endpoint = request.form.get("api")
|
endpoint = request.form.get("api")
|
||||||
|
Loading…
Reference in New Issue
Block a user