From 257992d04c0935455cd63a84c9993a3cf2d06e4b Mon Sep 17 00:00:00 2001 From: Roman Zeyde Date: Fri, 5 May 2017 11:22:00 +0300 Subject: [PATCH] ssh: move related code to a separate subdirectory --- libagent/server.py | 37 ------------------ libagent/ssh/__init__.py | 46 +++++++++++++++++++++-- libagent/{ => ssh}/client.py | 0 libagent/{ => ssh}/protocol.py | 0 libagent/ssh/tests/__init__.py | 1 + libagent/{ => ssh}/tests/test_client.py | 0 libagent/{ => ssh}/tests/test_protocol.py | 0 libagent/tests/test_server.py | 9 +---- 8 files changed, 46 insertions(+), 47 deletions(-) rename libagent/{ => ssh}/client.py (100%) rename libagent/{ => ssh}/protocol.py (100%) create mode 100644 libagent/ssh/tests/__init__.py rename libagent/{ => ssh}/tests/test_client.py (100%) rename libagent/{ => ssh}/tests/test_protocol.py (100%) diff --git a/libagent/server.py b/libagent/server.py index bb8935f..54b063c 100644 --- a/libagent/server.py +++ b/libagent/server.py @@ -1,19 +1,15 @@ """UNIX-domain socket server for ssh-agent implementation.""" import contextlib -import functools import logging import os import socket import subprocess -import tempfile import threading from . import util log = logging.getLogger(__name__) -UNIX_SOCKET_TIMEOUT = 0.1 - def remove_file(path, remove=os.remove, exists=os.path.exists): """Remove file, and raise OSError if still exists.""" @@ -114,39 +110,6 @@ def spawn(func, kwargs): t.join() -@contextlib.contextmanager -def serve(handler, sock_path=None, timeout=UNIX_SOCKET_TIMEOUT): - """ - Start the ssh-agent server on a UNIX-domain socket. - - If no connection is made during the specified timeout, - retry until the context is over. - """ - ssh_version = subprocess.check_output(['ssh', '-V'], - stderr=subprocess.STDOUT) - log.debug('local SSH version: %r', ssh_version) - if sock_path is None: - sock_path = tempfile.mktemp(prefix='trezor-ssh-agent-') - - environ = {'SSH_AUTH_SOCK': sock_path, 'SSH_AGENT_PID': str(os.getpid())} - device_mutex = threading.Lock() - with unix_domain_socket_server(sock_path) as sock: - sock.settimeout(timeout) - quit_event = threading.Event() - handle_conn = functools.partial(handle_connection, - handler=handler, - mutex=device_mutex) - kwargs = dict(sock=sock, - handle_conn=handle_conn, - quit_event=quit_event) - with spawn(server_thread, kwargs): - try: - yield environ - finally: - log.debug('closing server') - quit_event.set() - - def run_process(command, environ): """ Run the specified process and wait until it finishes. diff --git a/libagent/ssh/__init__.py b/libagent/ssh/__init__.py index 8122058..2c12003 100644 --- a/libagent/ssh/__init__.py +++ b/libagent/ssh/__init__.py @@ -1,16 +1,23 @@ """SSH-agent implementation using hardware authentication devices.""" import argparse +import contextlib import functools import logging import os import re import subprocess import sys +import tempfile +import threading -from .. import client, device, formats, protocol, server, util + +from .. import device, formats, server, util +from . import client, protocol log = logging.getLogger(__name__) +UNIX_SOCKET_TIMEOUT = 0.1 + def ssh_args(label): """Create SSH command for connecting specified server.""" @@ -51,7 +58,7 @@ def create_parser(): default=formats.CURVE_NIST256, help='specify ECDSA curve name: ' + curve_names) p.add_argument('--timeout', - default=server.UNIX_SOCKET_TIMEOUT, type=float, + default=UNIX_SOCKET_TIMEOUT, type=float, help='Timeout for accepting SSH client connections') p.add_argument('--debug', default=False, action='store_true', help='Log SSH protocol messages for debugging.') @@ -110,11 +117,44 @@ def git_host(remote_name, attributes): return '{user}@{host}'.format(**match.groupdict()) +@contextlib.contextmanager +def serve(handler, sock_path=None, timeout=UNIX_SOCKET_TIMEOUT): + """ + Start the ssh-agent server on a UNIX-domain socket. + + If no connection is made during the specified timeout, + retry until the context is over. + """ + ssh_version = subprocess.check_output(['ssh', '-V'], + stderr=subprocess.STDOUT) + log.debug('local SSH version: %r', ssh_version) + if sock_path is None: + sock_path = tempfile.mktemp(prefix='trezor-ssh-agent-') + + environ = {'SSH_AUTH_SOCK': sock_path, 'SSH_AGENT_PID': str(os.getpid())} + device_mutex = threading.Lock() + with server.unix_domain_socket_server(sock_path) as sock: + sock.settimeout(timeout) + quit_event = threading.Event() + handle_conn = functools.partial(server.handle_connection, + handler=handler, + mutex=device_mutex) + kwargs = dict(sock=sock, + handle_conn=handle_conn, + quit_event=quit_event) + with server.spawn(server.server_thread, kwargs): + try: + yield environ + finally: + log.debug('closing server') + quit_event.set() + + def run_server(conn, command, debug, timeout): """Common code for run_agent and run_git below.""" try: handler = protocol.Handler(conn=conn, debug=debug) - with server.serve(handler=handler, timeout=timeout) as env: + with serve(handler=handler, timeout=timeout) as env: return server.run_process(command=command, environ=env) except KeyboardInterrupt: log.info('server stopped') diff --git a/libagent/client.py b/libagent/ssh/client.py similarity index 100% rename from libagent/client.py rename to libagent/ssh/client.py diff --git a/libagent/protocol.py b/libagent/ssh/protocol.py similarity index 100% rename from libagent/protocol.py rename to libagent/ssh/protocol.py diff --git a/libagent/ssh/tests/__init__.py b/libagent/ssh/tests/__init__.py new file mode 100644 index 0000000..e171647 --- /dev/null +++ b/libagent/ssh/tests/__init__.py @@ -0,0 +1 @@ +"""Unit-tests for this package.""" diff --git a/libagent/tests/test_client.py b/libagent/ssh/tests/test_client.py similarity index 100% rename from libagent/tests/test_client.py rename to libagent/ssh/tests/test_client.py diff --git a/libagent/tests/test_protocol.py b/libagent/ssh/tests/test_protocol.py similarity index 100% rename from libagent/tests/test_protocol.py rename to libagent/ssh/tests/test_protocol.py diff --git a/libagent/tests/test_server.py b/libagent/tests/test_server.py index c680470..8e6717e 100644 --- a/libagent/tests/test_server.py +++ b/libagent/tests/test_server.py @@ -8,7 +8,8 @@ import threading import mock import pytest -from .. import protocol, server, util +from .. import server, util +from ..ssh import protocol def test_socket(): @@ -117,12 +118,6 @@ def test_run(): server.run_process([''], environ={}) -def test_serve_main(): - handler = protocol.Handler(conn=empty_device()) - with server.serve(handler=handler, sock_path=None): - pass - - def test_remove(): path = 'foo.bar'