diff --git a/client/utils.py b/client/utils.py index ed4fb13..893c714 100644 --- a/client/utils.py +++ b/client/utils.py @@ -14,6 +14,21 @@ env_file = os.path.join(config_directory, ".env") class Standalone: def __init__(self, args, pattern="", env_file="~/.config/fabric/.env"): + """ Initialize the class with the provided arguments and environment file. + + Args: + args: The arguments for initialization. + pattern: The pattern to be used (default is an empty string). + env_file: The path to the environment file (default is "~/.config/fabric/.env"). + + Returns: + None + + Raises: + KeyError: If the "OPENAI_API_KEY" is not found in the environment variables. + FileNotFoundError: If no API key is found in the environment variables. + """ + # Expand the tilde to the full path env_file = os.path.expanduser(env_file) load_dotenv(env_file) @@ -32,6 +47,18 @@ class Standalone: self.args = args def streamMessage(self, input_data: str): + """ Stream a message and handle exceptions. + + Args: + input_data (str): The input data for the message. + + Returns: + None: If the pattern is not found. + + Raises: + FileNotFoundError: If the pattern file is not found. + """ + wisdomFilePath = os.path.join( config_directory, f"patterns/{self.pattern}/system.md" ) @@ -80,6 +107,18 @@ class Standalone: f.write(buffer) def sendMessage(self, input_data: str): + """ Send a message using the input data and generate a response. + + Args: + input_data (str): The input data to be sent as a message. + + Returns: + None + + Raises: + FileNotFoundError: If the specified pattern file is not found. + """ + wisdomFilePath = os.path.join( config_directory, f"patterns/{self.pattern}/system.md" ) @@ -118,6 +157,15 @@ class Standalone: class Update: def __init__(self): + """ Initialize the object with default values and update patterns. + + This method initializes the object with default values for root_api_url, config_directory, and pattern_directory. + It then creates the pattern_directory if it does not exist and calls the update_patterns method to update the patterns. + + Raises: + OSError: If there is an issue creating the pattern_directory. + """ + self.root_api_url = "https://api.github.com/repos/danielmiessler/fabric/contents/patterns?ref=main" self.config_directory = os.path.expanduser("~/.config/fabric") self.pattern_directory = os.path.join(self.config_directory, "patterns") @@ -125,6 +173,12 @@ class Update: self.update_patterns() # Call the update process from a method. def update_patterns(self): + """ Update the patterns by downloading from the GitHub directory. + + Raises: + HTTPError: If there is an HTTP error while downloading patterns. + """ + try: self.progress_bar = tqdm(desc="Downloading Patterns…", unit="file") self.get_github_directory_contents( @@ -145,6 +199,16 @@ class Update: sys.exit() # Exit after handling the error. def download_file(self, url, local_path): + """ Download a file from the given URL and save it to the local path. + + Args: + url (str): The URL of the file to be downloaded. + local_path (str): The local path where the file will be saved. + + Raises: + HTTPError: If an HTTP error occurs during the download process. + """ + try: response = requests.get(url) response.raise_for_status() @@ -156,6 +220,19 @@ class Update: sys.exit() def process_item(self, item, local_dir): + """ Process the given item and save it to the local directory. + + Args: + item (dict): The item to be processed, containing information about the type, download URL, name, and URL. + local_dir (str): The local directory where the item will be saved. + + Returns: + None + + Raises: + OSError: If there is an issue creating the new directory using os.makedirs. + """ + if item["type"] == "file": self.download_file( item["download_url"], os.path.join(local_dir, item["name"]) @@ -166,6 +243,22 @@ class Update: self.get_github_directory_contents(item["url"], new_dir) def get_github_directory_contents(self, api_url, local_dir): + """ Get the contents of a directory from GitHub API and process each item. + + Args: + api_url (str): The URL of the GitHub API endpoint for the directory. + local_dir (str): The local directory where the contents will be processed. + + Returns: + None + + Raises: + HTTPError: If an HTTP error occurs while fetching the directory contents. + If the status code is 403, it prints a message about GitHub API rate limit exceeded + and closes the progress bar. For any other status code, it prints a message + about failing to fetch directory contents due to an HTTP error. + """ + try: response = requests.get(api_url) response.raise_for_status() @@ -184,22 +277,54 @@ class Update: class Setup: def __init__(self): + """ Initialize the object. + + Raises: + OSError: If there is an error in creating the pattern directory. + """ + self.config_directory = os.path.expanduser("~/.config/fabric") self.pattern_directory = os.path.join(self.config_directory, "patterns") os.makedirs(self.pattern_directory, exist_ok=True) self.env_file = os.path.join(self.config_directory, ".env") def api_key(self, api_key): + """ Set the OpenAI API key in the environment file. + + Args: + api_key (str): The API key to be set. + + Returns: + None + + Raises: + OSError: If the environment file does not exist or cannot be accessed. + """ + if not os.path.exists(self.env_file): with open(self.env_file, "w") as f: f.write(f"OPENAI_API_KEY={api_key}") print(f"OpenAI API key set to {api_key}") def patterns(self): + """ Method to update patterns and exit the system. + + Returns: + None + """ + Update() sys.exit() def run(self): + """ Execute the Fabric program. + + This method prompts the user for their OpenAI API key, sets the API key in the Fabric object, and then calls the patterns method. + + Returns: + None + """ + print("Welcome to Fabric. Let's get started.") apikey = input("Please enter your OpenAI API key\n") self.api_key(apikey.strip()) diff --git a/server/fabric_api_server.py b/server/fabric_api_server.py index b4a95b4..8f1cf55 100644 --- a/server/fabric_api_server.py +++ b/server/fabric_api_server.py @@ -41,8 +41,32 @@ with open("fabric_api_keys.json", "r") as tokens_file: # The function to check if the token is valid def auth_required(f): + """ Decorator function to check if the token is valid. + + Args: + f: The function to be decorated + + Returns: + The decorated function + """ + @wraps(f) def decorated_function(*args, **kwargs): + """ Decorated function to handle authentication token and API endpoint. + + Args: + *args: Variable length argument list. + **kwargs: Arbitrary keyword arguments. + + Returns: + Result of the decorated function. + + Raises: + KeyError: If 'Authorization' header is not found in the request. + TypeError: If 'Authorization' header value is not a string. + ValueError: If the authentication token is invalid or expired. + """ + # Get the authentication token from request header auth_token = request.headers.get("Authorization", "") @@ -65,6 +89,16 @@ def auth_required(f): # Check for a valid token/user for the given route def check_auth_token(token, route): + """ Check if the provided token is valid for the given route and return the corresponding user. + + Args: + token (str): The token to be checked for validity. + route (str): The route for which the token validity is to be checked. + + Returns: + str: The user corresponding to the provided token and route if valid, otherwise returns "Unauthorized: You are not authorized for this API". + """ + # Check if token is valid for the given route and return corresponding user if route in valid_tokens and token in valid_tokens[route]: return valid_tokens[route][token] @@ -78,11 +112,32 @@ ALLOWLIST_PATTERN = re.compile(r"^[a-zA-Z0-9\s.,;:!?\-]+$") # Sanitize the content, sort of. Prompt injection is the main threat so this isn't a huge deal def sanitize_content(content): + """ Sanitize the content by removing characters that do not match the ALLOWLIST_PATTERN. + + Args: + content (str): The content to be sanitized. + + Returns: + str: The sanitized content. + """ + return "".join(char for char in content if ALLOWLIST_PATTERN.match(char)) # Pull the URL content's from the GitHub repo def fetch_content_from_url(url): + """ Fetches content from the given URL. + + Args: + url (str): The URL from which to fetch content. + + Returns: + str: The sanitized content fetched from the URL. + + Raises: + requests.RequestException: If an error occurs while making the request to the URL. + """ + try: response = requests.get(url) response.raise_for_status() @@ -99,6 +154,15 @@ def fetch_content_from_url(url): @app.route("/extwis", methods=["POST"]) @auth_required # Require authentication def extwis(): + """ Extract wisdom from user input using OpenAI's GPT-4 model. + + Returns: + JSON: A JSON response containing the generated response or an error message. + + Raises: + Exception: If there is an error during the API call. + """ + data = request.get_json() # Warn if there's no input diff --git a/server/fabric_web_interface/fabric_web_server.py b/server/fabric_web_interface/fabric_web_server.py index dad4a96..eab17b6 100644 --- a/server/fabric_web_interface/fabric_web_server.py +++ b/server/fabric_web_interface/fabric_web_server.py @@ -16,6 +16,19 @@ import os def send_request(prompt, endpoint): + """ Send a request to the specified endpoint of an HTTP-only server. + + Args: + prompt (str): The input prompt for the request. + endpoint (str): The endpoint to which the request will be sent. + + Returns: + str: The response from the server. + + Raises: + KeyError: If the response JSON does not contain the expected "response" key. + """ + base_url = "http://127.0.0.1:13337" url = f"{base_url}{endpoint}" headers = { @@ -37,6 +50,15 @@ app.secret_key = "your_secret_key" @app.route("/favicon.ico") def favicon(): + """ Send the favicon.ico file from the static directory. + + Returns: + Response object with the favicon.ico file + + Raises: + - + """ + return send_from_directory( os.path.join(app.root_path, "static"), "favicon.ico", @@ -46,6 +68,12 @@ def favicon(): @app.route("/", methods=["GET", "POST"]) def index(): + """ Process the POST request and send a request to the specified API endpoint. + + Returns: + str: The rendered HTML template with the response data. + """ + if request.method == "POST": prompt = request.form.get("prompt") endpoint = request.form.get("api")