You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

484 lines
17 KiB
Python

import json
import subprocess
import sys
import time
from datetime import datetime
import requests
from nr_wg_mtu_finder.plot import create_heatmap_from_log
# NOTE: MTUFinder's mode is either "peer" or "server".
from nr_wg_mtu_finder.sync_server import run_sync_server
class ReturncodeError(Exception):
    """Raised when an external command exits with a non-zero return code."""
class MTUFinder(object):
    """Benchmark WireGuard bandwidth across a range of server and peer MTUs.

    Constructing an instance immediately runs the workflow for the chosen
    mode:

    - "server": for each message received from the sync server, set the next
      MTU on the WireGuard interface and (re)start an `iperf3 -s` process.
    - "peer": for every server MTU, cycle through all peer MTUs, run iperf3
      upload/download tests against the server and append the results to a
      CSV log. When the server reports SHUTDOWN, a heatmap is created from
      the log and the process exits.
    """

    # Bandwidth (in bits per second) logged for a skipped peer iteration.
    # The log stores Mbps, so this value is written to the CSV as "-1.000".
    _FAILED_BANDWIDTH_BPS = -1000000

    def __init__(
        self,
        mode,
        server_ip,
        server_port,
        interface,
        conf_file,
        mtu_max,
        mtu_min,
        mtu_step,
        peer_skip_errors,
    ):
        """Store the settings and immediately run the selected mode.

        Args:
            mode: Either "server" or "peer".
            server_ip: Address used for the sync server, ping and iperf3.
            server_port: Port of the sync server.
            interface: Interface name passed to `wg-quick up/down`.
            conf_file: Path of the WireGuard conf whose `MTU =` line is edited.
            mtu_max: Largest MTU to test (inclusive).
            mtu_min: Smallest MTU to test.
            mtu_step: Step between two tested MTUs.
            peer_skip_errors: If true, a failed peer iteration is logged and
                skipped instead of aborting the peer loop.

        Raises:
            NotImplementedError: If `mode` is neither "server" nor "peer".
        """
        self.mode = mode
        self.server_ip = server_ip
        self.server_port = server_port
        self.interface = interface
        self.conf_file = conf_file
        self.mtu_max = mtu_max
        self.mtu_min = mtu_min
        self.mtu_step = mtu_step
        self.peer_mtu = None
        self.server_mtu = None
        self.current_mtu = None
        self.peer_skip_errors = peer_skip_errors

        # Take the timestamp once so the CSV and PNG names always match.
        timestamp = datetime.now().strftime("%Y%m%dT%H%M%S")
        self.log_filepath = f"wg_mtu_finder_{self.mode}_{timestamp}.csv"
        self.heatmap_filepath = f"wg_mtu_finder_{self.mode}_{timestamp}.png"

        if self.mode == "server":
            self.run_server_mode()
        elif self.mode == "peer":
            self.run_peer_mode()
        else:
            raise NotImplementedError()

    def create_log(self):
        """Create an empty CSV log file with the headers.

        This log file will be used to store all bandwidth information for each
        MTU test.
        """
        msg = f"Creating log file: {self.log_filepath}"
        print(f"{msg:<50s}", end=": ")
        with open(self.log_filepath, "w") as f:
            f.write(
                "server_mtu,"
                "peer_mtu,"
                "upload_rcv_mbps,"
                "upload_send_mbps,"
                "download_rcv_mbps,"
                "download_send_mbps\n"
            )
        print("SUCCESS")

    @staticmethod
    def handle_returncode(returncode, stdout, stderr):
        """Print the outcome of a finished command and raise on failure.

        Args:
            returncode: Exit code of the command.
            stdout: Captured standard output, printed only on failure.
            stderr: Captured standard error, printed only on failure.

        Raises:
            ReturncodeError: If `returncode` is not 0.
        """
        if returncode == 0:
            print("SUCCESS")
            return

        print(f"FAILED with code {returncode}")
        print("*" * 80)
        print("STDOUT:\n-------")
        print(stdout)
        print("STDERR:\n-------")
        print(stderr)
        print("*" * 80)
        raise ReturncodeError(f"Command failed with return code {returncode}")

    def _run_command(self, msg, command):
        """Print a progress label, run `command` to completion, return stdout.

        Raises:
            ReturncodeError: If the command exits with a non-zero code.
        """
        print(f"{msg:<50s}", end=": ")
        completed = subprocess.run(
            command,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            universal_newlines=True,
        )
        self.handle_returncode(
            returncode=completed.returncode,
            stdout=completed.stdout,
            stderr=completed.stderr,
        )
        return completed.stdout

    def append_log_with_bandwidth_info(
        self, up_rcv_bps, up_snd_bps, down_rcv_bps, down_snd_bps
    ):
        """Append the bandwidth information to the log file.

        All four values are given in bits per second and are written in Mbps
        with three decimals, next to the current server and peer MTU.

        Raises:
            NotImplementedError: If called in server mode.
        """
        if self.mode == "server":
            raise NotImplementedError()

        msg = f"Appending log for MTU: {self.current_mtu}"
        print(f"{msg:<50s}", end=": ")
        with open(self.log_filepath, "a") as f:
            f.write(
                f"{self.server_mtu},"
                f"{self.peer_mtu},"
                f"{up_rcv_bps / 1000000:0.3f},"
                f"{up_snd_bps / 1000000:0.3f},"
                f"{down_rcv_bps / 1000000:0.3f},"
                f"{down_snd_bps / 1000000:0.3f}\n"
            )
        print("SUCCESS")

    def wg_quick_down(self):
        """Spin down the interface using wg-quick."""
        self._run_command("WG Interface Down", ["wg-quick", "down", f"{self.interface}"])

    def wg_quick_up(self):
        """Spin up the interface using wg-quick."""
        self._run_command("WG Interface Up", ["wg-quick", "up", f"{self.interface}"])

    def __validate_conf_file(self):
        """Validate that a line `MTU =` exists in the wireguard conf.

        Raises:
            ValueError: If no line of the conf file starts with "MTU =".
        """
        with open(self.conf_file, "r") as f:
            if any(line.startswith("MTU =") for line in f):
                return

        # If no line starts with "MTU =", then raise an error.
        raise ValueError(
            f"Expected to find a line that begins with 'MTU =' in {self.conf_file} "
            f"file but it was not found. Please check the README file for instructions "
            f"on how to add the missing line to the wg.conf file."
        )

    def update_mtu_in_conf_file(self):
        """Update the MTU setting in the WG Conf.

        Replace every line that starts with 'MTU =' with 'MTU = <current_mtu>'.

        Raises:
            ValueError: If the conf file has no line starting with "MTU =".
            ReturncodeError: If `sed` exits with a non-zero code.
        """
        self.__validate_conf_file()
        # Anchor the pattern to the start of the line, matching the check in
        # __validate_conf_file, so other lines that merely contain the text
        # "MTU" (comments, hooks) are left untouched.
        self._run_command(
            f"Setting MTU to {self.current_mtu} in {self.conf_file}",
            [
                "sed",
                "-i",
                f"s/^MTU =.*/MTU = {self.current_mtu}/",
                f"{self.conf_file}",
            ],
        )

    def _run_iperf3_client(self, msg, extra_args=()):
        """Run a 5 second iperf3 client test and return its bandwidths.

        Returns:
            Tuple `(receiver_bps, sender_bps)` read from the first stream of
            the "end" section of the iperf3 JSON report.
        """
        command = ["iperf3", "-c", f"{self.server_ip}", "-J", "-t", "5", "-i", "5"]
        command.extend(extra_args)
        # _run_command waits for the iperf3 test to be done.
        stdout = self._run_command(msg, command)
        # Load the iperf3 JSON output which results from the -J flag.
        stream = json.loads(stdout)["end"]["streams"][0]
        return (
            stream["receiver"]["bits_per_second"],
            stream["sender"]["bits_per_second"],
        )

    def run_iperf3_upload_test(self):
        """Run iperf3 upload test and return `(receiver_bps, sender_bps)`."""
        return self._run_iperf3_client("Running peer upload")

    def run_iperf3_download_test(self):
        """Run iperf3 download test (`-R`) and return `(receiver_bps, sender_bps)`."""
        return self._run_iperf3_client("Running peer download", ["-R"])

    def __peer_mode__wait_for_server_init(self):
        """Poll the sync server until it reports INITIALIZED or SHUTDOWN.

        Returns:
            Tuple `(server_mtu, server_status)` from the server's response.

        Raises:
            - requests.Timeout, requests.ConnectionError or KeyError if there is
              something wrong with the Flask server running on the WG server.
              Only a connect timeout is retried.
        """
        url = f"http://{self.server_ip}:{self.server_port}/server/status"
        msg = "Waiting for server init and status"
        while True:
            print(f"{msg:<50s}", end=": ")
            try:
                resp = requests.get(url, verify=False, timeout=5)
            except requests.exceptions.ConnectTimeout:
                print("FAILED, ConnectTimeout, Retrying...")
                time.sleep(1)
                continue

            payload = resp.json()
            server_mtu = payload["server_mtu"]
            server_status = payload["server_status"]
            if server_status in ("INITIALIZED", "SHUTDOWN"):
                print(
                    f"SUCCESS, SERVER_MTU: {server_mtu}, "
                    f"SERVER_STATUS: {server_status}"
                )
                return server_mtu, server_status

            print(f"FAILED, SERVER_STATUS: {server_status}, Retrying...")
            time.sleep(1)

    def __peer_mode__send_server_peer_ready(self):
        """Tell the sync server the peer is ready and get back server status.

        Returns:
            Tuple `(server_mtu, server_status)` from the server's response.
        """
        msg = "Send peer ready for next loop to server"
        print(f"{msg:<50s}", end=": ")
        resp = requests.get(
            f"http://{self.server_ip}:{self.server_port}/peer/ready",
            verify=False,
            timeout=5,
        )
        payload = resp.json()
        server_mtu = payload["server_mtu"]
        server_status = payload["server_status"]
        print("SUCCESS")
        return server_mtu, server_status

    def __peer_mode__ping_server(self):
        """Ping server to reestablish connection between peer and server.

        After server interface is spun down and spun up again, the peer is not
        guaranteed to be connected to the server. Therefore we force a ping to
        make sure peer sends packets on this network.

        Raises:
            ReturncodeError: If `ping` exits with a non-zero code.
        """
        self._run_command(
            "Pinging server to establish connection",
            ["ping", "-c", "1", f"{self.server_ip}"],
        )

    def run_peer_mode(self):
        """Run all steps for peer mode.

        IMPORTANT: Peer is the one that logs bandwidth into the log file (csv)
        """
        self.create_log()

        while True:
            # Ping IP address of server to flush connection
            self.__peer_mode__ping_server()

            # Tell server that peer is ready for next loop.
            self.__peer_mode__send_server_peer_ready()

            # Ping IP address of server to flush connection
            self.__peer_mode__ping_server()

            # Start a fresh loop of cycling through all peer MTUs
            # At start, find what the current server_mtu is.
            self.server_mtu, server_status = self.__peer_mode__wait_for_server_init()

            if server_status == "SHUTDOWN":
                print("Server has shutdown... Shutting down peer script.")
                print(f"Check final bandwidth log: {self.log_filepath}")
                create_heatmap_from_log(
                    log_filepath=self.log_filepath,
                    heatmap_filepath=self.heatmap_filepath,
                )
                print(f"Check final bandwidth plot: {self.heatmap_filepath}")
                sys.exit(0)
            elif server_status != "INITIALIZED":
                raise NotImplementedError()

            for current_mtu in range(self.mtu_min, self.mtu_max + 1, self.mtu_step):
                if self.server_mtu is None:
                    raise NotImplementedError()

                self.current_mtu = current_mtu
                self.peer_mtu = current_mtu

                print("-" * 80)
                self.wg_quick_down()
                self.update_mtu_in_conf_file()
                self.wg_quick_up()

                # Wait a short while after interface is spun up.
                time.sleep(1)

                try:
                    # Ping IP address of server to flush connection
                    self.__peer_mode__ping_server()

                    up_rcv_bps, up_snd_bps = self.run_iperf3_upload_test()
                    time.sleep(1)
                    down_rcv_bps, down_snd_bps = self.run_iperf3_download_test()

                    self.append_log_with_bandwidth_info(
                        up_rcv_bps, up_snd_bps, down_rcv_bps, down_snd_bps
                    )
                except ReturncodeError:
                    if self.peer_skip_errors:
                        print(
                            "Caught ReturncodeError: The --peer-skip-errors flag is "
                            "set to True so this Peer MTU iteration will be skipped. "
                            "Continuing with other peer MTUs. Bandwidth for this MTU "
                            "will be recorded as -1 in the log file (csv)."
                        )
                        # The log divides by 1e6, so a plain -1 would be
                        # written as "-0.000"; the sentinel yields "-1.000".
                        failed = self._FAILED_BANDWIDTH_BPS
                        self.append_log_with_bandwidth_info(
                            failed, failed, failed, failed
                        )
                    else:
                        print(
                            "Caught ReturncodeError: The --peer-skip-errors flag is "
                            "set to False so the Peer loop will crash. If you wish "
                            "to skip MTUs that raise this error in the future, set the "
                            "--peer-skip-errors flag to True when running the script."
                        )
                        raise

    def run_iperf3_server_test(self):
        """Start an `iperf3 -s` process in the background and return it.

        The caller is responsible for stopping the returned process.
        """
        msg = "Running iperf3 server"
        print(f"{msg:<50s}", end=": ")
        # The output is never read, so discard it instead of piping it: an
        # unread pipe eventually fills up and blocks the iperf3 server.
        process = subprocess.Popen(
            ["iperf3", "-s"],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
        time.sleep(1)
        print("SUCCESS")
        return process

    @staticmethod
    def _stop_process(process):
        """Terminate `process` and wait for it, killing it if it won't exit."""
        process.terminate()
        try:
            process.wait(timeout=5)
        except subprocess.TimeoutExpired:
            process.kill()
            process.wait()

    @staticmethod
    def _report_sync_server_error(error):
        """Print an exception raised by the sync server worker."""
        print(f"Sync server failed: {error!r}")

    def run_server_mode(self):
        """Run all steps for server mode.

        Starts the sync server in a worker process and then reacts to the
        messages it puts on `from_server_queue`:

        - "INITIALIZE": switch the interface to the next MTU, restart the
          iperf3 server and report INITIALIZED (or SHUTDOWN once every MTU
          has been used) back through `to_server_queue`.
        - "SHUTDOWN": exit the process.
        """
        import multiprocessing as mp

        pool = mp.Pool(1)
        manager = mp.Manager()
        to_server_queue = manager.Queue(5)
        from_server_queue = manager.Queue(5)

        pool.apply_async(
            run_sync_server,
            kwds={
                "host": self.server_ip,
                "port": self.server_port,
                "to_server_queue": to_server_queue,
                "from_server_queue": from_server_queue,
            },
            # Without a callback a crash of the sync server is silent and the
            # loop below just blocks forever on the queue.
            error_callback=self._report_sync_server_error,
        )

        iperf3_server_process = None
        mtu_range_iter = iter(range(self.mtu_min, self.mtu_max + 1, self.mtu_step))

        while True:
            print("-" * 80)
            # Wait for init command from sync server
            sync_server_status = from_server_queue.get(block=True)

            # Any time a message is received from the sync_server, the iperf3
            # server must be terminated. Wait for it to exit so it is reaped
            # and has released its port before a new one is started.
            if iperf3_server_process:
                self._stop_process(iperf3_server_process)
                iperf3_server_process = None

            if sync_server_status == "INITIALIZE":
                # We receive INITIALIZE from the peer but sometimes the
                # connection is spun down too quickly before a response could
                # be sent. Therefore we'll wait for a little while until the
                # request has been handled.
                time.sleep(1)

                try:
                    self.current_mtu = next(mtu_range_iter)
                except StopIteration:
                    # Done with cycling through all MTUs
                    # Send Shutdown signal to the sync_server
                    # And go back to waiting for shutdown signal from sync_server
                    to_server_queue.put(
                        {"server_mtu": self.server_mtu, "server_status": "SHUTDOWN"}
                    )
                    continue

                self.server_mtu = self.current_mtu

                self.wg_quick_down()
                self.update_mtu_in_conf_file()
                self.wg_quick_up()

                iperf3_server_process = self.run_iperf3_server_test()

                # Wait a short while after interface is spun up.
                time.sleep(1)

                to_server_queue.put(
                    {"server_mtu": self.server_mtu, "server_status": "INITIALIZED"}
                )
                # Now wait for the next message from the sync server.
                # The peer learns from the status above that the iperf3 server
                # is ready with the current_mtu and cycles through all of its
                # own MTUs. The next "INITIALIZE" received here advances the
                # server to its next MTU.
            elif sync_server_status == "SHUTDOWN":
                time.sleep(2)
                print("Received 'SHUTDOWN' signal from sync server. Shutting down.")
                sys.exit(0)
            else:
                raise NotImplementedError()

        # Code should not reach here.
        raise NotImplementedError()