add base_netns functionality

Allows to specify the netns in which the WireGuard interface
is initialized. This allows for multi-hop VPNs.
pull/10/head v2.2.1
Jendrik Weise 1 year ago committed by Daniel
parent 41665ca136
commit 6486b2fad1

@ -65,6 +65,8 @@ Full YAML example:
~~~ yaml ~~~ yaml
# name of the network namespace # name of the network namespace
name: ns-example name: ns-example
# namespace where the interface is initialized, defaults to the main/default namespace
base_netns: null
# if false, the netns itself won't be created or deleted, just the interfaces inside it # if false, the netns itself won't be created or deleted, just the interfaces inside it
managed: true managed: true
# list of dns servers, if empty dns servers from default netns will be used # list of dns servers, if empty dns servers from default netns will be used

@ -4,7 +4,6 @@ from argparse import ArgumentParser, RawDescriptionHelpFormatter
from pathlib import Path from pathlib import Path
from typing import Any, Optional from typing import Any, Optional
import dataclasses import dataclasses
import itertools
import json import json
import os import os
import subprocess import subprocess
@ -135,6 +134,7 @@ class Peer:
@dataclasses.dataclass @dataclasses.dataclass
class Interface: class Interface:
name: str name: str
base_netns: str
private_key: str private_key: str
public_key: Optional[str] = None public_key: Optional[str] = None
address: list[str] = dataclasses.field(default_factory=list) address: list[str] = dataclasses.field(default_factory=list)
@ -144,10 +144,10 @@ class Interface:
peers: list[Peer] = dataclasses.field(default_factory=list) peers: list[Peer] = dataclasses.field(default_factory=list)
@classmethod @classmethod
def from_dict(cls, data: dict[str, Any]) -> Interface: def from_dict(cls, data: dict[str, Any], base_netns=None) -> Interface:
peers = data.pop('peers', list()) peers = data.pop('peers', list())
peers = [Peer.from_dict({key.replace('-', '_'): value for key, value in peer.items()}) for peer in peers] peers = [Peer.from_dict({key.replace('-', '_'): value for key, value in peer.items()}) for peer in peers]
return cls(**data, peers=peers) return cls(**data, peers=peers, base_netns=base_netns)
def setup(self, namespace: Namespace) -> Interface: def setup(self, namespace: Namespace) -> Interface:
self._create(namespace) self._create(namespace)
@ -160,8 +160,8 @@ class Interface:
return self return self
def _create(self, namespace: Namespace) -> None: def _create(self, namespace: Namespace) -> None:
ip('link', 'add', self.name, 'type', 'wireguard') ip('link', 'add', self.name, 'type', 'wireguard', netns=self.base_netns)
ip('link', 'set', self.name, 'netns', namespace.name) ip('link', 'set', self.name, 'netns', namespace.name, netns=self.base_netns)
def _configure_wireguard(self, namespace: Namespace) -> None: def _configure_wireguard(self, namespace: Namespace) -> None:
wg('set', self.name, 'listen-port', self.listen_port, netns=namespace.name) wg('set', self.name, 'listen-port', self.listen_port, netns=namespace.name)
@ -170,26 +170,26 @@ class Interface:
def _assign_addresses(self, namespace: Namespace) -> None: def _assign_addresses(self, namespace: Namespace) -> None:
for address in self.address: for address in self.address:
ip('-n', namespace.name, '-6' if ':' in address else '-4', 'address', 'add', address, 'dev', self.name) ip('-6' if ':' in address else '-4', 'address', 'add', address, 'dev', self.name, netns=namespace.name)
def _bring_up(self, namespace: Namespace) -> None: def _bring_up(self, namespace: Namespace) -> None:
ip('-n', namespace.name, 'link', 'set', 'dev', self.name, 'mtu', self.mtu, 'up') ip('link', 'set', 'dev', self.name, 'mtu', self.mtu, 'up', netns=namespace.name)
def _create_routes(self, namespace: Namespace): def _create_routes(self, namespace: Namespace):
for peer in self.peers: for peer in self.peers:
networks = peer.routes if peer.routes is not None else peer.allowed_ips networks = peer.routes if peer.routes is not None else peer.allowed_ips
for network in networks: for network in networks:
ip('-n', namespace.name, '-6' if ':' in network else '-4', 'route', 'add', network, 'dev', self.name) ip('-6' if ':' in network else '-4', 'route', 'add', network, 'dev', self.name, netns=namespace.name)
def teardown(self, namespace: Namespace, check=True) -> Interface: def teardown(self, namespace: Namespace, check=True) -> Interface:
if self.exists(namespace): if self.exists(namespace):
ip('-n', namespace.name, 'link', 'set', self.name, 'down', check=check) ip('link', 'set', self.name, 'down', check=check, netns=namespace.name)
ip('-n', namespace.name, 'link', 'delete', self.name, check=check) ip('link', 'delete', self.name, check=check, netns=namespace.name)
return self return self
def exists(self, namespace: Namespace) -> bool: def exists(self, namespace: Namespace) -> bool:
try: try:
ip('-n', namespace.name, 'link', 'show', self.name, capture=True) ip('link', 'show', self.name, capture=True, netns=namespace.name)
return True return True
except Exception: except Exception:
return False return False
@ -290,7 +290,8 @@ class Namespace:
scriptlets = {key: data.pop(key, None) for key in ['pre_up', 'post_up', 'pre_down', 'post_down']} scriptlets = {key: data.pop(key, None) for key in ['pre_up', 'post_up', 'pre_down', 'post_down']}
scriptlets = {key: Scriptlet.from_value(value) for key, value in scriptlets.items() if value is not None} scriptlets = {key: Scriptlet.from_value(value) for key, value in scriptlets.items() if value is not None}
interfaces = data.pop('interfaces', list()) interfaces = data.pop('interfaces', list())
interfaces = [Interface.from_dict({key.replace('-', '_'): value for key, value in interface.items()}) for interface in interfaces] base_netns = data.pop('base_netns', None)
interfaces = [Interface.from_dict({key.replace('-', '_'): value for key, value in interface.items()}, base_netns=base_netns) for interface in interfaces]
return cls(**data, **scriptlets, interfaces=interfaces) return cls(**data, **scriptlets, interfaces=interfaces)
def setup(self) -> Namespace: def setup(self) -> Namespace:
@ -323,7 +324,7 @@ class Namespace:
def _create(self) -> None: def _create(self) -> None:
ip('netns', 'add', self.name) ip('netns', 'add', self.name)
ip('-n', self.name, 'link', 'set', 'dev', 'lo', 'up') ip('link', 'set', 'dev', 'lo', 'up', netns=self.name)
def _delete(self, check=True) -> None: def _delete(self, check=True) -> None:
ip('netns', 'delete', self.name, check=check) ip('netns', 'delete', self.name, check=check)
@ -359,8 +360,8 @@ def ip_netns_exec(*args, netns: str = None, stdin: str = None, check=True, captu
return ip('netns', 'exec', netns, *args, stdin=stdin, check=check, capture=capture) return ip('netns', 'exec', netns, *args, stdin=stdin, check=check, capture=capture)
def ip(*args, stdin: str = None, check=True, capture=False) -> str: def ip(*args, stdin: str = None, netns=None, check=True, capture=False) -> str:
return run('ip', *args, stdin=stdin, check=check, capture=capture) return run('ip', *([] if netns is None else ['-n', netns]), *args, stdin=stdin, check=check, capture=capture)
def host_eval(*args, stdin: str = None, check=True, capture=False) -> str: def host_eval(*args, stdin: str = None, check=True, capture=False) -> str:

Loading…
Cancel
Save