Added and formatted with black

pull/6/head
nitred 2 years ago
parent 9ec992545b
commit 1e39b00a4f

@ -1 +1 @@
__version__ = "0.1.0"
__version__ = "0.2.0"

@ -2,15 +2,11 @@ import argparse
import signal
import sys
import time
from distutils.util import strtobool
from pydantic import BaseModel, StrictInt, StrictStr, root_validator
from typing_extensions import Literal
from pydantic import BaseModel, StrictStr
from nr_wg_mtu_finder.plot import create_heatmap_from_log
from .mtu_finder import MTUFinder
def signal_handler(sig, frame):
"""Handle ctrl+c interrupt.
@ -39,7 +35,7 @@ def setup_args():
"""Setup args."""
parser = argparse.ArgumentParser(
description=(
"nr-wg-mtu-finder-plot - "
"nr-wg-mtu-finder-heatmap - "
"Generate a heatmap file (png) from a log file (csv) that was created "
"by the `nr-wg-mtu-finder` script. This is useful in case the original "
"script file crashed midway."

@ -1,13 +1,15 @@
import json
import subprocess
import time
import sys
import json
import time
from datetime import datetime
import requests
from nr_wg_mtu_finder.plot import create_heatmap_from_log
# Set to either client or server
from nr_wg_mtu_finder.sync_server import run_sync_server
from nr_wg_mtu_finder.plot import plot_log
class ReturncodeError(Exception):
@ -45,10 +47,10 @@ class MTUFinder(object):
self.peer_skip_errors = peer_skip_errors
self.log_filename = (
self.log_filepath = (
f"wg_mtu_finder_{self.mode}_{datetime.now().strftime('%Y%m%dT%H%M%S')}.csv"
)
self.plot_filename = (
self.heatmap_filepath = (
f"wg_mtu_finder_{self.mode}_{datetime.now().strftime('%Y%m%dT%H%M%S')}.png"
)
@ -64,9 +66,9 @@ class MTUFinder(object):
This log file will be used to store all bandwidth information for each MTU test.
"""
msg = f"Creating log file: {self.log_filename}"
msg = f"Creating log file: {self.log_filepath}"
print(f"{msg:<50s}", end=": ")
with open(self.log_filename, "w") as f:
with open(self.log_filepath, "w") as f:
f.write(
f"server_mtu,"
f"peer_mtu,"
@ -92,7 +94,9 @@ class MTUFinder(object):
print(f"*" * 80)
raise ReturncodeError()
def append_log_with_bandwidth_info(self, up_rcv_bps, up_snd_bps, down_rcv_bps, down_snd_bps):
def append_log_with_bandwidth_info(
self, up_rcv_bps, up_snd_bps, down_rcv_bps, down_snd_bps
):
"""Append the bandwidth information to the log file."""
if self.mode == "server":
raise NotImplementedError()
@ -100,7 +104,7 @@ class MTUFinder(object):
msg = f"Appending log for MTU: {self.current_mtu}"
print(f"{msg:<50s}", end=": ")
with open(self.log_filename, "a") as f:
with open(self.log_filepath, "a") as f:
f.write(
f"{self.server_mtu},"
f"{self.peer_mtu},"
@ -123,7 +127,9 @@ class MTUFinder(object):
universal_newlines=True,
)
stdout, stderr = process.communicate()
self.handle_returncode(returncode=process.returncode, stdout=stdout, stderr=stderr)
self.handle_returncode(
returncode=process.returncode, stdout=stdout, stderr=stderr
)
def wg_quick_up(self):
"""Spin up the interface using wg-quick."""
@ -136,7 +142,9 @@ class MTUFinder(object):
universal_newlines=True,
)
stdout, stderr = process.communicate()
self.handle_returncode(returncode=process.returncode, stdout=stdout, stderr=stderr)
self.handle_returncode(
returncode=process.returncode, stdout=stdout, stderr=stderr
)
def __validate_conf_file(self):
"""Validate that a line `MTU =` exists in the wireguard conf."""
@ -147,7 +155,9 @@ class MTUFinder(object):
# If no line starts with "MTU = ", then raise an error.
raise ValueError(
f"Expected to find a line that begins with 'MTU =' in {self.conf_file} file."
f"Expected to find a line that begins with 'MTU =' in {self.conf_file} "
f"file but it was not found. Please check the README file for instructions "
f"on how to add the missing line to the wg.conf file."
)
def update_mtu_in_conf_file(self):
@ -167,7 +177,9 @@ class MTUFinder(object):
)
stdout, stderr = process.communicate()
self.handle_returncode(returncode=process.returncode, stdout=stdout, stderr=stderr)
self.handle_returncode(
returncode=process.returncode, stdout=stdout, stderr=stderr
)
def run_iperf3_upload_test(self):
"""Run iperf3 upload test."""
@ -176,12 +188,17 @@ class MTUFinder(object):
command = ["iperf3", "-c", f"{self.server_ip}", "-J", "-t", "5", "-i", "5"]
# print(f"command: {' '.join(command)}")
process = subprocess.Popen(
command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True,
command,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
universal_newlines=True,
)
# Wait iperf3 test to be done.
stdout, stderr = process.communicate()
self.handle_returncode(returncode=process.returncode, stdout=stdout, stderr=stderr)
self.handle_returncode(
returncode=process.returncode, stdout=stdout, stderr=stderr
)
# load iperf3 output json which results from the -J flag
output = json.loads(stdout)
@ -203,7 +220,9 @@ class MTUFinder(object):
# Wait iperf3 test to be done.
stdout, stderr = process.communicate()
self.handle_returncode(returncode=process.returncode, stdout=stdout, stderr=stderr)
self.handle_returncode(
returncode=process.returncode, stdout=stdout, stderr=stderr
)
# load iperf3 output json which results from the -J flag
output = json.loads(stdout)
@ -216,8 +235,8 @@ class MTUFinder(object):
"""Get server mtu.
Raises:
- requests.Timeout, requests.ConnectionError or KeyError if there is something wrong
with the Flask server running on the WG server.
- requests.Timeout, requests.ConnectionError or KeyError if there is
something wrong with the Flask server running on the WG server.
"""
while True:
msg = f"Waiting for server init and status"
@ -229,10 +248,16 @@ class MTUFinder(object):
timeout=5,
)
server_mtu, server_status = resp.json()["server_mtu"], resp.json()["server_status"]
server_mtu, server_status = (
resp.json()["server_mtu"],
resp.json()["server_status"],
)
if (server_status == "INITIALIZED") or (server_status == "SHUTDOWN"):
print(f"SUCCESS, SERVER_MTU: {server_mtu}, SERVER_STATUS: {server_status}")
print(
f"SUCCESS, SERVER_MTU: {server_mtu}, "
f"SERVER_STATUS: {server_status}"
)
return server_mtu, server_status
else:
print(f"FAILED, SERVER_STATUS: {server_status}, Retrying...")
@ -248,18 +273,23 @@ class MTUFinder(object):
msg = f"Send peer ready for next loop to server"
print(f"{msg:<50s}", end=": ")
resp = requests.get(
f"http://{self.server_ip}:{self.server_port}/peer/ready", verify=False, timeout=5
f"http://{self.server_ip}:{self.server_port}/peer/ready",
verify=False,
timeout=5,
)
server_mtu, server_status = (
resp.json()["server_mtu"],
resp.json()["server_status"],
)
server_mtu, server_status = resp.json()["server_mtu"], resp.json()["server_status"]
print("SUCCESS")
return server_mtu, server_status
def __peer_mode__ping_server(self):
"""Ping server to reestablish connection between peer and server.
After server interface is spun down and spun up again, the peer is not guaranteed to be
connected to the server. Therefore we force a ping to make sure peer sends packets
on this network.
After server interface is spun down and spun up again, the peer is not
guaranteed to be connected to the server. Therefore we force a ping to make
sure peer sends packets on this network.
"""
msg = f"Pinging server to establish connection"
print(f"{msg:<50s}", end=": ")
@ -271,12 +301,14 @@ class MTUFinder(object):
)
stdout, stderr = process.communicate()
self.handle_returncode(returncode=process.returncode, stdout=stdout, stderr=stderr)
self.handle_returncode(
returncode=process.returncode, stdout=stdout, stderr=stderr
)
def run_peer_mode(self):
"""Run all steps for peer mode.
1. Peer is the one that logs bandwidth
IMPORTANT: Peer is the one that logs bandwidth into the log file (csv)
"""
self.create_log()
while True:
@ -297,9 +329,12 @@ class MTUFinder(object):
pass
elif server_status == "SHUTDOWN":
print(f"Server has shutdown... Shutting down peer script.")
print(f"Check final bandwidth log: {self.log_filename}")
plot_log(log_filename=self.log_filename, plot_filename=self.plot_filename)
print(f"Check final bandwidth plot: {self.plot_filename}")
print(f"Check final bandwidth log: {self.log_filepath}")
create_heatmap_from_log(
log_filepath=self.log_filepath,
heatmap_filepath=self.heatmap_filepath,
)
print(f"Check final bandwidth plot: {self.heatmap_filepath}")
sys.exit(0)
else:
raise NotImplementedError()
@ -333,12 +368,19 @@ class MTUFinder(object):
except ReturncodeError:
if self.peer_skip_errors:
print(
"Caught ReturncodeError. --peer-skip-errors is set to True. "
"Continuing with other peer MTUs. Bandwidth for this MTU will be "
"recorded as -1."
"Caught ReturncodeError: The --peer-skip-errors flag is "
"set to True so this Peer MTU iteration will be skipped. "
"Continuing with other peer MTUs. Bandwidth for this MTU "
"will be recorded as -1 in the log file (csv)."
)
self.append_log_with_bandwidth_info(-1000000, -1000000, -1000000, -1000000)
self.append_log_with_bandwidth_info(-1, -1, -1, -1)
else:
print(
"Caught ReturncodeError: The --peer-skip-errors flag is "
"set to False so the Peer loop will crash. If you wish "
"to skip MTUs that raise this error in the future, set the "
"--peer-skip-errors flag to True when running the script."
)
raise
def run_iperf3_server_test(self):
@ -392,9 +434,9 @@ class MTUFinder(object):
iperf3_server_process.terminate()
if sync_server_status == "INITIALIZE":
# We receive INITIALIZE from the peer but sometimes the connection is spun down too
# quickly before a response could be sent. Therefore we'll wait for a little while
# until the request has been handled.
# We receive INITIALIZE from the peer but sometimes the connection is
# spun down too quickly before a response could be sent. Therefore
# we'll wait for a little while until the request has been handled.
time.sleep(1)
try:
@ -419,11 +461,13 @@ class MTUFinder(object):
# Wait a short while after interface is spun up.
time.sleep(1)
to_server_queue.put({"server_mtu": self.server_mtu, "server_status": "INITIALIZED"})
to_server_queue.put(
{"server_mtu": self.server_mtu, "server_status": "INITIALIZED"}
)
# Now wait for peer to ping our server
# Peer will get a response that tells it that the iperf3 server is ready with
# the current_mtu
# Peer will get a response that tells it that the iperf3 server is
# ready with the current_mtu.
# Peer will start cycling through all of its MTUs
# Peer will send another "init" command if it needs the server to

@ -1,15 +1,19 @@
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import seaborn as sns
def plot_log(log_filename, plot_filename):
def create_heatmap_from_log(log_filepath, heatmap_filepath):
f, axes = plt.subplots(nrows=2, ncols=2, figsize=(12, 12))
df = pd.read_csv(log_filename)
df = pd.read_csv(log_filepath)
ax = axes[0, 0]
dfx = df.pivot(index="server_mtu", columns="peer_mtu", values="upload_rcv_mbps",)
dfx = df.pivot(
index="server_mtu",
columns="peer_mtu",
values="upload_rcv_mbps",
)
sns.heatmap(
dfx.values,
linewidth=0.5,
@ -74,4 +78,9 @@ def plot_log(log_filename, plot_filename):
f.suptitle("Peer MTU vs Server MTU Bandwidth (Mbps)")
f.tight_layout()
f.savefig(plot_filename, dpi=300)
f.savefig(heatmap_filepath, dpi=300)
print(
f"create_heatmap_from_log: Done generating heatmap from log file. Heatmap "
f"can be found at '{heatmap_filepath}'"
)

@ -1,15 +1,18 @@
import traceback
from typing import Optional
from flask import Flask, jsonify, request
from typing_extensions import Literal
from typing import Optional
status: Literal["NOT_INITIALIZED", "INITIALIZED", "SHUTDOWN"] = "NOT_INITIALIZED"
mtu: Optional[int] = None
def run_sync_server(host, port, to_server_queue, from_server_queue):
"""Run a temporary flask/http server that returns server mtu and status and restarts server."""
"""Run a flask/http server which is used to synchronize with the Peer script.
1. Peer can request the flask/http server for Server MTU and its status.
2. Peer can request the flask/http server to shutdown once the Peer is finished.
"""
app = Flask(__name__)
def shutdown_server():
@ -53,12 +56,4 @@ def run_sync_server(host, port, to_server_queue, from_server_queue):
from_server_queue.put("INITIALIZE")
return jsonify({"server_mtu": mtu, "server_status": status})
# @app.route("/server/shutdown", methods=["GET"])
# def server_restart():
# from_server_queue.put("shutdown")
# shutdown_server()
# return jsonify(jsonify({"server_mtu": mtu, "server_status": "shutdown"}))
# Blocking call until the server received a GET request on /server/shutdown after which
# the flask server is shutdown
app.run(host=host, port=port)

Loading…
Cancel
Save