You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
484 lines
17 KiB
Python
484 lines
17 KiB
Python
import json
|
|
import subprocess
|
|
import sys
|
|
import time
|
|
from datetime import datetime
|
|
|
|
import requests
|
|
|
|
from nr_wg_mtu_finder.plot import create_heatmap_from_log
|
|
|
|
# Set to either client or server
|
|
from nr_wg_mtu_finder.sync_server import run_sync_server
|
|
|
|
|
|
class ReturncodeError(Exception):
    """Raised when an external command exits with a nonzero return code."""
|
|
|
|
|
|
class MTUFinder(object):
    """Sweep WireGuard MTU values and record iperf3 bandwidth for each one.

    The same class drives both ends of the test, selected by ``mode``:

    * ``"server"``: for each MTU in the range, reconfigure the WireGuard
      interface, start an ``iperf3`` server and report readiness through the
      sync server (``run_sync_server``).
    * ``"peer"``: for each server MTU, cycle through every peer MTU, run an
      upload and a download ``iperf3`` test and append the results to a CSV
      log. The peer also renders the final heatmap.

    NOTE: constructing an instance immediately runs the selected mode; the
    constructor only returns (or exits the process) when the run is over.
    """

    # Value passed for every bandwidth column when a peer MTU iteration fails
    # and ``peer_skip_errors`` is set.
    _FAILED_BANDWIDTH = -1

    def __init__(
        self,
        mode,
        server_ip,
        server_port,
        interface,
        conf_file,
        mtu_max,
        mtu_min,
        mtu_step,
        peer_skip_errors,
    ):
        """Store the configuration and run the requested mode.

        Args:
            mode: Either ``"server"`` or ``"peer"``.
            server_ip: Address used for iperf3, ping and the sync server.
            server_port: Port of the sync server.
            interface: WireGuard interface name passed to ``wg-quick``.
            conf_file: Path of the WireGuard conf whose ``MTU =`` line is edited.
            mtu_max: Largest MTU to test (inclusive).
            mtu_min: Smallest MTU to test.
            mtu_step: Increment between tested MTUs.
            peer_skip_errors: In peer mode, log a failed MTU iteration and keep
                going instead of re-raising ``ReturncodeError``.

        Raises:
            NotImplementedError: If ``mode`` is neither ``"server"`` nor ``"peer"``.
        """
        self.mode = mode

        self.server_ip = server_ip
        self.server_port = server_port
        self.interface = interface
        self.conf_file = conf_file

        self.mtu_max = mtu_max
        self.mtu_min = mtu_min
        self.mtu_step = mtu_step

        self.peer_mtu = None
        self.server_mtu = None
        self.current_mtu = None

        self.peer_skip_errors = peer_skip_errors

        # Take the timestamp once so the CSV log and the heatmap always share
        # the same file stem, even when the clock ticks between the two.
        timestamp = datetime.now().strftime("%Y%m%dT%H%M%S")
        self.log_filepath = f"wg_mtu_finder_{self.mode}_{timestamp}.csv"
        self.heatmap_filepath = f"wg_mtu_finder_{self.mode}_{timestamp}.png"

        if self.mode == "server":
            self.run_server_mode()
        elif self.mode == "peer":
            self.run_peer_mode()
        else:
            raise NotImplementedError()

    def create_log(self):
        """Create an empty CSV log file with the headers.

        This log file will be used to store all bandwidth information for each
        MTU test. An existing file at ``self.log_filepath`` is overwritten.
        """
        msg = f"Creating log file: {self.log_filepath}"
        print(f"{msg:<50s}", end=": ")
        with open(self.log_filepath, "w") as f:
            f.write(
                "server_mtu,"
                "peer_mtu,"
                "upload_rcv_mbps,"
                "upload_send_mbps,"
                "download_rcv_mbps,"
                "download_send_mbps\n"
            )
        print("SUCCESS")

    @staticmethod
    def handle_returncode(returncode, stdout, stderr):
        """Report the outcome of an external command.

        Prints ``SUCCESS`` when ``returncode`` is 0. Otherwise prints the
        return code together with the captured output and raises.

        Raises:
            ReturncodeError: If ``returncode`` is nonzero.
        """
        if returncode == 0:
            print("SUCCESS")
            return

        print(f"FAILED with code {returncode}")
        print("*" * 80)
        print("STDOUT:\n-------")
        print(stdout)
        print("STDERR:\n-------")
        print(stderr)
        print("*" * 80)
        raise ReturncodeError()

    def _run_command(self, command):
        """Run ``command`` to completion and return its captured stdout.

        The outcome is reported through ``handle_returncode``, which raises
        ``ReturncodeError`` when the command exits with a nonzero code.
        """
        process = subprocess.Popen(
            command,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            universal_newlines=True,
        )
        stdout, stderr = process.communicate()
        self.handle_returncode(
            returncode=process.returncode, stdout=stdout, stderr=stderr
        )
        return stdout

    @staticmethod
    def _format_mbps(bps):
        """Format a bits-per-second value as Mbps with three decimals.

        Negative inputs are the "test failed" sentinel and are written as-is
        instead of being scaled; dividing -1 by 1e6 would render as ``-0.000``
        and hide the failure marker in the log.
        """
        value = float(bps) if bps < 0 else bps / 1000000
        return f"{value:0.3f}"

    def append_log_with_bandwidth_info(
        self, up_rcv_bps, up_snd_bps, down_rcv_bps, down_snd_bps
    ):
        """Append one row of bandwidth results to the log file.

        Args:
            up_rcv_bps: Upload test, receiver bits per second.
            up_snd_bps: Upload test, sender bits per second.
            down_rcv_bps: Download test, receiver bits per second.
            down_snd_bps: Download test, sender bits per second.

        A negative value marks a failed test and is recorded unscaled
        (``-1`` is logged as ``-1.000``).

        Raises:
            NotImplementedError: In server mode; only the peer writes the log.
        """
        if self.mode == "server":
            raise NotImplementedError()

        msg = f"Appending log for MTU: {self.current_mtu}"
        print(f"{msg:<50s}", end=": ")

        row = [
            f"{self.server_mtu}",
            f"{self.peer_mtu}",
            self._format_mbps(up_rcv_bps),
            self._format_mbps(up_snd_bps),
            self._format_mbps(down_rcv_bps),
            self._format_mbps(down_snd_bps),
        ]
        with open(self.log_filepath, "a") as f:
            f.write(",".join(row) + "\n")

        print("SUCCESS")

    def wg_quick_down(self):
        """Spin down the interface using wg-quick.

        Raises:
            ReturncodeError: If ``wg-quick`` exits with a nonzero code.
        """
        msg = "WG Interface Down"
        print(f"{msg:<50s}", end=": ")
        self._run_command(["wg-quick", "down", f"{self.interface}"])

    def wg_quick_up(self):
        """Spin up the interface using wg-quick.

        Raises:
            ReturncodeError: If ``wg-quick`` exits with a nonzero code.
        """
        msg = "WG Interface Up"
        print(f"{msg:<50s}", end=": ")
        self._run_command(["wg-quick", "up", f"{self.interface}"])

    def __validate_conf_file(self):
        """Validate that a line `MTU =` exists in the wireguard conf.

        Raises:
            ValueError: If no line of ``self.conf_file`` starts with ``MTU =``.
        """
        with open(self.conf_file, "r") as f:
            if any(line.startswith("MTU =") for line in f):
                return

        # If no line starts with "MTU = ", then raise an error.
        raise ValueError(
            f"Expected to find a line that begins with 'MTU =' in {self.conf_file} "
            f"file but it was not found. Please check the README file for instructions "
            f"on how to add the missing line to the wg.conf file."
        )

    def update_mtu_in_conf_file(self):
        """Update the MTU setting in the WG Conf.

        Replace the line that starts with 'MTU =' with 'MTU = <current_mtu>'.

        Raises:
            ValueError: If the conf file has no line starting with ``MTU =``.
            ReturncodeError: If ``sed`` exits with a nonzero code.
        """
        self.__validate_conf_file()

        msg = f"Setting MTU to {self.current_mtu} in {self.conf_file}"
        print(f"{msg:<50s}", end=": ")
        # Anchor on the same prefix the validation checks. An unanchored
        # "MTU.*" would also rewrite comments or hook commands that merely
        # mention MTU somewhere in the line.
        self._run_command(
            [
                "sed",
                "-i",
                f"s/^MTU =.*/MTU = {self.current_mtu}/",
                f"{self.conf_file}",
            ]
        )

    def _run_iperf3_client(self, extra_args=()):
        """Run a 5 second iperf3 client test against the server.

        Returns:
            Tuple ``(receiver_bps, sender_bps)`` read from the first stream of
            the ``end`` section of iperf3's JSON report.

        Raises:
            ReturncodeError: If ``iperf3`` exits with a nonzero code.
        """
        command = ["iperf3", "-c", f"{self.server_ip}", "-J", "-t", "5", "-i", "5"]
        command.extend(extra_args)

        # load iperf3 output json which results from the -J flag
        report = json.loads(self._run_command(command))
        stream = report["end"]["streams"][0]
        return (
            stream["receiver"]["bits_per_second"],
            stream["sender"]["bits_per_second"],
        )

    def run_iperf3_upload_test(self):
        """Run iperf3 upload test.

        Returns:
            Tuple ``(receiver_bps, sender_bps)``.

        Raises:
            ReturncodeError: If ``iperf3`` exits with a nonzero code.
        """
        msg = "Running peer upload"
        print(f"{msg:<50s}", end=": ")
        return self._run_iperf3_client()

    def run_iperf3_download_test(self):
        """Run iperf3 download test (same as upload, with iperf3's ``-R`` flag).

        Returns:
            Tuple ``(receiver_bps, sender_bps)``.

        Raises:
            ReturncodeError: If ``iperf3`` exits with a nonzero code.
        """
        msg = "Running peer download"
        print(f"{msg:<50s}", end=": ")
        return self._run_iperf3_client(extra_args=("-R",))

    def __peer_mode__wait_for_server_init(self):
        """Poll the sync server until it reports INITIALIZED or SHUTDOWN.

        Returns:
            Tuple ``(server_mtu, server_status)`` taken from the JSON response.

        Raises:
            - requests.Timeout, requests.ConnectionError or KeyError if there is
              something wrong with the Flask server running on the WG server.
              Only ``ConnectTimeout`` is retried.
        """
        url = f"http://{self.server_ip}:{self.server_port}/server/status"
        while True:
            msg = "Waiting for server init and status"
            print(f"{msg:<50s}", end=": ")
            try:
                resp = requests.get(url, verify=False, timeout=5)
            except requests.exceptions.ConnectTimeout:
                print("FAILED, ConnectTimeout, Retrying...")
                time.sleep(1)
                continue

            payload = resp.json()
            server_mtu = payload["server_mtu"]
            server_status = payload["server_status"]

            if server_status in ("INITIALIZED", "SHUTDOWN"):
                print(
                    f"SUCCESS, SERVER_MTU: {server_mtu}, "
                    f"SERVER_STATUS: {server_status}"
                )
                return server_mtu, server_status

            print(f"FAILED, SERVER_STATUS: {server_status}, Retrying...")
            time.sleep(1)

    def __peer_mode__send_server_peer_ready(self):
        """Tell the sync server the peer is ready for the next loop.

        Returns:
            Tuple ``(server_mtu, server_status)`` taken from the JSON response.
        """
        msg = "Send peer ready for next loop to server"
        print(f"{msg:<50s}", end=": ")
        resp = requests.get(
            f"http://{self.server_ip}:{self.server_port}/peer/ready",
            verify=False,
            timeout=5,
        )
        payload = resp.json()
        server_mtu = payload["server_mtu"]
        server_status = payload["server_status"]
        print("SUCCESS")
        return server_mtu, server_status

    def __peer_mode__ping_server(self):
        """Ping server to reestablish connection between peer and server.

        After server interface is spun down and spun up again, the peer is not
        guaranteed to be connected to the server. Therefore we force a ping to
        make sure peer sends packets on this network.

        Raises:
            ReturncodeError: If ``ping`` exits with a nonzero code.
        """
        msg = "Pinging server to establish connection"
        print(f"{msg:<50s}", end=": ")
        self._run_command(["ping", "-c", "1", f"{self.server_ip}"])

    def run_peer_mode(self):
        """Run all steps for peer mode.

        IMPORTANT: Peer is the one that logs bandwidth into the log file (csv)

        Loops until the server reports SHUTDOWN, then renders the heatmap and
        exits the process with status 0.

        Raises:
            ReturncodeError: If a command fails and ``peer_skip_errors`` is
                false, or if it fails outside the per-MTU test step.
            NotImplementedError: On an unknown server status or a missing
                server MTU.
        """
        self.create_log()
        while True:
            # Ping IP address of server to flush connection
            self.__peer_mode__ping_server()

            # Tell server that peer is ready for next loop.
            self.__peer_mode__send_server_peer_ready()

            # Ping IP address of server to flush connection
            self.__peer_mode__ping_server()

            # Start a fresh loop of cycling through all peer MTUs
            # At start, find what the current server_mtu is.
            self.server_mtu, server_status = self.__peer_mode__wait_for_server_init()

            if server_status == "SHUTDOWN":
                print("Server has shutdown... Shutting down peer script.")
                print(f"Check final bandwidth log: {self.log_filepath}")
                create_heatmap_from_log(
                    log_filepath=self.log_filepath,
                    heatmap_filepath=self.heatmap_filepath,
                )
                print(f"Check final bandwidth plot: {self.heatmap_filepath}")
                sys.exit(0)
            elif server_status != "INITIALIZED":
                raise NotImplementedError()

            for current_mtu in range(self.mtu_min, self.mtu_max + 1, self.mtu_step):
                if self.server_mtu is None:
                    raise NotImplementedError()

                self.current_mtu = current_mtu
                self.peer_mtu = current_mtu

                print("-" * 80)
                self.wg_quick_down()
                self.update_mtu_in_conf_file()
                self.wg_quick_up()

                # Wait a short while after interface is spun up.
                time.sleep(1)

                try:
                    # Ping IP address of server to flush connection
                    self.__peer_mode__ping_server()

                    up_rcv_bps, up_snd_bps = self.run_iperf3_upload_test()
                    time.sleep(1)
                    down_rcv_bps, down_snd_bps = self.run_iperf3_download_test()

                    self.append_log_with_bandwidth_info(
                        up_rcv_bps, up_snd_bps, down_rcv_bps, down_snd_bps
                    )
                except ReturncodeError:
                    if not self.peer_skip_errors:
                        print(
                            "Caught ReturncodeError: The --peer-skip-errors flag is "
                            "set to False so the Peer loop will crash. If you wish "
                            "to skip MTUs that raise this error in the future, set the "
                            "--peer-skip-errors flag to True when running the script."
                        )
                        raise

                    print(
                        "Caught ReturncodeError: The --peer-skip-errors flag is "
                        "set to True so this Peer MTU iteration will be skipped. "
                        "Continuing with other peer MTUs. Bandwidth for this MTU "
                        "will be recorded as -1 in the log file (csv)."
                    )
                    failed = self._FAILED_BANDWIDTH
                    self.append_log_with_bandwidth_info(failed, failed, failed, failed)

    def run_iperf3_server_test(self):
        """Start an iperf3 server in the background.

        Returns:
            The ``subprocess.Popen`` handle of the running ``iperf3 -s``
            process; the caller is responsible for terminating it.
        """
        msg = "Running iperf3 server"
        print(f"{msg:<50s}", end=": ")
        # The server's output is never read. Sending it to pipes would let the
        # pipe buffer fill up over a long sweep and block iperf3, so discard it.
        process = subprocess.Popen(
            ["iperf3", "-s"],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )

        # Give the server a moment to start listening.
        time.sleep(1)
        print("SUCCESS")

        return process

    @staticmethod
    def _stop_process(process):
        """Terminate ``process`` and wait for it to exit.

        Waiting reaps the child and ensures it is gone before a replacement is
        started. Falls back to ``kill`` if it does not exit within 5 seconds.
        """
        process.terminate()
        try:
            process.wait(timeout=5)
        except subprocess.TimeoutExpired:
            process.kill()
            process.wait()

    def run_server_mode(self):
        """Run all steps for server mode.

        Starts the sync server in a worker process and then reacts to the
        messages it puts on the queue: ``INITIALIZE`` moves on to the next
        server MTU (or announces SHUTDOWN once the range is exhausted) and
        ``SHUTDOWN`` exits the process with status 0.

        Raises:
            ReturncodeError: If reconfiguring the interface fails.
            NotImplementedError: On an unknown message from the sync server.
        """
        import multiprocessing as mp

        pool = mp.Pool(1)
        manager = mp.Manager()

        to_server_queue = manager.Queue(5)
        from_server_queue = manager.Queue(5)

        pool.apply_async(
            run_sync_server,
            kwds={
                "host": self.server_ip,
                "port": self.server_port,
                "to_server_queue": to_server_queue,
                "from_server_queue": from_server_queue,
            },
        )

        iperf3_server_process = None
        mtu_range_iter = iter(range(self.mtu_min, self.mtu_max + 1, self.mtu_step))

        while True:
            print("-" * 80)
            # Wait for init command from sync server
            sync_server_status = from_server_queue.get(block=True)

            # Any time a message is received from the sync_server, the iperf3
            # server must be terminated.
            if iperf3_server_process:
                self._stop_process(iperf3_server_process)
                iperf3_server_process = None

            if sync_server_status == "INITIALIZE":
                # We receive INITIALIZE from the peer but sometimes the
                # connection is spun down too quickly before a response could
                # be sent. Therefore we'll wait for a little while until the
                # request has been handled.
                time.sleep(1)

                try:
                    self.current_mtu = next(mtu_range_iter)
                except StopIteration:
                    # Done with cycling through all MTUs
                    # Send Shutdown signal to the sync_server
                    # And go back to waiting for shutdown signal from sync_server
                    to_server_queue.put(
                        {"server_mtu": self.server_mtu, "server_status": "SHUTDOWN"}
                    )
                    continue

                self.server_mtu = self.current_mtu

                self.wg_quick_down()
                self.update_mtu_in_conf_file()
                self.wg_quick_up()

                iperf3_server_process = self.run_iperf3_server_test()

                # Wait a short while after interface is spun up.
                time.sleep(1)

                to_server_queue.put(
                    {"server_mtu": self.server_mtu, "server_status": "INITIALIZED"}
                )

                # Now wait for peer to ping our server
                # Peer will get a response that tells it that the iperf3 server
                # is ready with the current_mtu.
                # Peer will start cycling through all of its MTUs
                # Peer will send another "init" command if it needs the server
                # to move on to the next server MTU.

            elif sync_server_status == "SHUTDOWN":
                time.sleep(2)
                print("Received 'SHUTDOWN' signal from sync server. Shutting down.")
                sys.exit(0)

            else:
                raise NotImplementedError()
|