|
|
|
@ -1,16 +1,23 @@
|
|
|
|
|
"""SSH-agent implementation using hardware authentication devices."""
|
|
|
|
|
import argparse
|
|
|
|
|
import contextlib
|
|
|
|
|
import functools
|
|
|
|
|
import logging
|
|
|
|
|
import os
|
|
|
|
|
import re
|
|
|
|
|
import subprocess
|
|
|
|
|
import sys
|
|
|
|
|
import tempfile
|
|
|
|
|
import threading
|
|
|
|
|
|
|
|
|
|
from .. import client, device, formats, protocol, server, util
|
|
|
|
|
|
|
|
|
|
from .. import device, formats, server, util
|
|
|
|
|
from . import client, protocol
|
|
|
|
|
|
|
|
|
|
log = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
UNIX_SOCKET_TIMEOUT = 0.1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def ssh_args(label):
|
|
|
|
|
"""Create SSH command for connecting specified server."""
|
|
|
|
@ -51,7 +58,7 @@ def create_parser():
|
|
|
|
|
default=formats.CURVE_NIST256,
|
|
|
|
|
help='specify ECDSA curve name: ' + curve_names)
|
|
|
|
|
p.add_argument('--timeout',
|
|
|
|
|
default=server.UNIX_SOCKET_TIMEOUT, type=float,
|
|
|
|
|
default=UNIX_SOCKET_TIMEOUT, type=float,
|
|
|
|
|
help='Timeout for accepting SSH client connections')
|
|
|
|
|
p.add_argument('--debug', default=False, action='store_true',
|
|
|
|
|
help='Log SSH protocol messages for debugging.')
|
|
|
|
@ -110,11 +117,44 @@ def git_host(remote_name, attributes):
|
|
|
|
|
return '{user}@{host}'.format(**match.groupdict())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@contextlib.contextmanager
|
|
|
|
|
def serve(handler, sock_path=None, timeout=UNIX_SOCKET_TIMEOUT):
|
|
|
|
|
"""
|
|
|
|
|
Start the ssh-agent server on a UNIX-domain socket.
|
|
|
|
|
|
|
|
|
|
If no connection is made during the specified timeout,
|
|
|
|
|
retry until the context is over.
|
|
|
|
|
"""
|
|
|
|
|
ssh_version = subprocess.check_output(['ssh', '-V'],
|
|
|
|
|
stderr=subprocess.STDOUT)
|
|
|
|
|
log.debug('local SSH version: %r', ssh_version)
|
|
|
|
|
if sock_path is None:
|
|
|
|
|
sock_path = tempfile.mktemp(prefix='trezor-ssh-agent-')
|
|
|
|
|
|
|
|
|
|
environ = {'SSH_AUTH_SOCK': sock_path, 'SSH_AGENT_PID': str(os.getpid())}
|
|
|
|
|
device_mutex = threading.Lock()
|
|
|
|
|
with server.unix_domain_socket_server(sock_path) as sock:
|
|
|
|
|
sock.settimeout(timeout)
|
|
|
|
|
quit_event = threading.Event()
|
|
|
|
|
handle_conn = functools.partial(server.handle_connection,
|
|
|
|
|
handler=handler,
|
|
|
|
|
mutex=device_mutex)
|
|
|
|
|
kwargs = dict(sock=sock,
|
|
|
|
|
handle_conn=handle_conn,
|
|
|
|
|
quit_event=quit_event)
|
|
|
|
|
with server.spawn(server.server_thread, kwargs):
|
|
|
|
|
try:
|
|
|
|
|
yield environ
|
|
|
|
|
finally:
|
|
|
|
|
log.debug('closing server')
|
|
|
|
|
quit_event.set()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def run_server(conn, command, debug, timeout):
|
|
|
|
|
"""Common code for run_agent and run_git below."""
|
|
|
|
|
try:
|
|
|
|
|
handler = protocol.Handler(conn=conn, debug=debug)
|
|
|
|
|
with server.serve(handler=handler, timeout=timeout) as env:
|
|
|
|
|
with serve(handler=handler, timeout=timeout) as env:
|
|
|
|
|
return server.run_process(command=command, environ=env)
|
|
|
|
|
except KeyboardInterrupt:
|
|
|
|
|
log.info('server stopped')
|
|
|
|
|