diff --git a/README.md b/README.md index c30bfc1..6b2f2af 100644 --- a/README.md +++ b/README.md @@ -55,6 +55,7 @@ Install the following on both the WG server and WG peer 1. The project assumes that you already have a working WG installation on both the WG peer and WG server. 1. The project assumes that you already have a WG interface like `wg0`. 1. The project assumes that you already have a WG conf file like `/etc/wireguard/wg0.conf`. ***Take a backup of these files***. +1. Before running the following scripts, the WG interface is expected to be active/online such that the peer is able to ping the server. 1. Start the WG server script before the WG peer script ### On the WG Server diff --git a/nr_wg_mtu_finder/main.py b/nr_wg_mtu_finder/main.py index d03ba8c..9e52af8 100644 --- a/nr_wg_mtu_finder/main.py +++ b/nr_wg_mtu_finder/main.py @@ -5,6 +5,7 @@ import sys from pydantic import BaseModel, StrictStr, StrictInt, root_validator from typing_extensions import Literal from .mtu_finder import MTUFinder +from distutils.util import strtobool def signal_handler(sig, frame): @@ -30,6 +31,8 @@ class ArgsModel(BaseModel): server_ip: StrictStr server_port: int = 5000 + peer_skip_errors: bool = True + interface: StrictStr = "wg0" conf_file: StrictStr = "/etc/wireguard/wg0.conf" @@ -94,6 +97,13 @@ def setup_args(): required=False, default="/etc/wireguard/wg0.conf", ) + parser.add_argument( + "--peer-skip-errors", + help="Skip errors when an expected error occurs in peer mode during MTU loop.", + required=False, + default=True, + type=strtobool, + ) args = parser.parse_args() return args diff --git a/nr_wg_mtu_finder/mtu_finder.py b/nr_wg_mtu_finder/mtu_finder.py index 8e127e7..66bf26d 100644 --- a/nr_wg_mtu_finder/mtu_finder.py +++ b/nr_wg_mtu_finder/mtu_finder.py @@ -10,9 +10,22 @@ from nr_wg_mtu_finder.sync_server import run_sync_server from nr_wg_mtu_finder.plot import plot_log +class ReturncodeError(Exception): + pass + + class MTUFinder(object): def __init__( - self, mode, server_ip, server_port, interface, conf_file, mtu_max, mtu_min, mtu_step + self, + mode, + server_ip, + server_port, + interface, + conf_file, + mtu_max, + mtu_min, + mtu_step, + peer_skip_errors, ): """Init.""" self.mode = mode @@ -30,6 +43,8 @@ class MTUFinder(object): self.server_mtu = None self.current_mtu = None + self.peer_skip_errors = peer_skip_errors + self.log_filename = ( f"wg_mtu_finder_{self.mode}_{datetime.now().strftime('%Y%m%dT%H%M%S')}.csv" ) @@ -62,6 +77,21 @@ class MTUFinder(object): ) print("SUCCESS") + @staticmethod + def handle_returncode(returncode, stdout, stderr): + """Handle status code.""" + if returncode == 0: + print("SUCCESS") + else: + print(f"FAILED with code {returncode}") + print(f"*" * 80) + print(f"STDOUT:\n-------") + print(stdout) + print(f"STDERR:\n-------") + print(stderr) + print(f"*" * 80) + raise ReturncodeError() + def append_log_with_bandwidth_info(self, up_rcv_bps, up_snd_bps, down_rcv_bps, down_snd_bps): """Append the bandwidth information to the log file.""" if self.mode == "server": @@ -93,7 +123,7 @@ class MTUFinder(object): universal_newlines=True, ) stdout, stderr = process.communicate() - print("SUCCESS" if process.returncode == 0 else f"FAILED with code {process.returncode}") + self.handle_returncode(returncode=process.returncode, stdout=stdout, stderr=stderr) def wg_quick_up(self): """Spin up the interface using wg-quick.""" @@ -106,7 +136,7 @@ class MTUFinder(object): universal_newlines=True, ) stdout, stderr = process.communicate() - print("SUCCESS" if process.returncode == 0 else f"FAILED with code {process.returncode}") + self.handle_returncode(returncode=process.returncode, stdout=stdout, stderr=stderr) def __validate_conf_file(self): """Validate that a line `MTU =` exists in the wireguard conf.""" @@ -137,7 +167,7 @@ class MTUFinder(object): ) stdout, stderr = process.communicate() - print("SUCCESS" if process.returncode == 0 else f"FAILED with code {process.returncode}") + self.handle_returncode(returncode=process.returncode, stdout=stdout, stderr=stderr) def run_iperf3_upload_test(self): """Run iperf3 upload test.""" @@ -151,8 +181,7 @@ class MTUFinder(object): # Wait iperf3 test to be done. stdout, stderr = process.communicate() - - print("SUCCESS" if process.returncode == 0 else f"FAILED with code {process.returncode}") + self.handle_returncode(returncode=process.returncode, stdout=stdout, stderr=stderr) # load iperf3 output json which results from the -J flag output = json.loads(stdout) @@ -174,8 +203,7 @@ class MTUFinder(object): # Wait iperf3 test to be done. stdout, stderr = process.communicate() - - print("SUCCESS" if process.returncode == 0 else f"FAILED with code {process.returncode}") + self.handle_returncode(returncode=process.returncode, stdout=stdout, stderr=stderr) # load iperf3 output json which results from the -J flag output = json.loads(stdout) @@ -243,8 +271,7 @@ class MTUFinder(object): ) stdout, stderr = process.communicate() - - print("SUCCESS" if process.returncode == 0 else f"FAILED with code {process.returncode}") + self.handle_returncode(returncode=process.returncode, stdout=stdout, stderr=stderr) def run_peer_mode(self): """Run all steps for peer mode. @@ -291,13 +318,28 @@ class MTUFinder(object): # Wait a short while after interface is spun up. time.sleep(1) - up_rcv_bps, up_snd_bps = self.run_iperf3_upload_test() - time.sleep(1) - down_rcv_bps, down_snd_bps = self.run_iperf3_download_test() - self.append_log_with_bandwidth_info( - up_rcv_bps, up_snd_bps, down_rcv_bps, down_snd_bps - ) + try: + # Ping IP address of server to flush connection + self.__peer_mode__ping_server() + + up_rcv_bps, up_snd_bps = self.run_iperf3_upload_test() + time.sleep(1) + down_rcv_bps, down_snd_bps = self.run_iperf3_download_test() + + self.append_log_with_bandwidth_info( + up_rcv_bps, up_snd_bps, down_rcv_bps, down_snd_bps + ) + except ReturncodeError: + if self.peer_skip_errors: + print( + "Caught ReturncodeError. --peer-skip-errors is set to True. " + "Continuing with other peer MTUs. Bandwidth for this MTU will be " + "recorded as -1." + ) + self.append_log_with_bandwidth_info(-1000000, -1000000, -1000000, -1000000) + else: + raise def run_iperf3_server_test(self): """Run iperf3 upload test.""" diff --git a/nr_wg_mtu_finder/plot.py b/nr_wg_mtu_finder/plot.py index d3f51bb..3e38b9f 100644 --- a/nr_wg_mtu_finder/plot.py +++ b/nr_wg_mtu_finder/plot.py @@ -75,10 +75,3 @@ def plot_log(log_filename, plot_filename): f.suptitle("Peer MTU vs Server MTU Bandwidth (Mbps)") f.tight_layout() f.savefig(plot_filename, dpi=300) - - -if __name__ == "__main__": - plot_log( - "/home/nitred/projects/group-nr/nr-wg-mtu-finder/examples/example.csv", - "/home/nitred/projects/group-nr/nr-wg-mtu-finder/examples/example.png", - )