trezor-agent/trezor_agent/__main__.py

169 lines
5.8 KiB
Python
Raw Normal View History

2016-02-19 09:35:16 +00:00
"""SSH-agent implementation using hardware authentication devices."""
import argparse
2016-06-11 17:26:10 +00:00
import functools
2016-01-04 17:17:08 +00:00
import logging
2016-01-09 14:06:47 +00:00
import os
2016-04-23 19:08:18 +00:00
import re
import subprocess
2016-01-09 14:06:47 +00:00
import sys
from . import client, device, formats, protocol, server, util
2015-06-16 07:20:11 +00:00
log = logging.getLogger(__name__)
2015-06-16 07:03:48 +00:00
2015-11-27 07:59:06 +00:00
def ssh_args(label):
    """Build SSH command-line arguments for reaching the host named by `label`.

    `label` is parsed with device.interface.string_to_identity().  Optional
    port and user components become ``-p PORT`` and ``-l USER`` (in that
    order), and the host is always the final element.

    Args:
        label: identity string to connect to.

    Returns:
        list of arguments suitable for appending to ``ssh`` / ``mosh``.
    """
    parsed = device.interface.string_to_identity(label)
    result = []
    # (identity key, command-line flag) pairs, emitted in this order.
    for key, flag in (('port', '-p'), ('user', '-l')):
        if key in parsed:
            result.extend([flag, parsed[key]])
    result.append(parsed['host'])
    return result
2015-11-27 07:59:06 +00:00
2016-03-05 08:46:36 +00:00
def create_parser():
    """Return the base argparse.ArgumentParser shared by this tool's commands.

    Defines the common options: -v/--verbose (repeatable), -e/--ecdsa-curve-name
    (defaulting to formats.CURVE_NIST256), --timeout (float, defaulting to
    server.UNIX_SOCKET_TIMEOUT) and --debug.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-v', '--verbose', default=0, action='count')

    # Advertise the supported curves, alphabetically, in the help text.
    supported = ', '.join(sorted(formats.SUPPORTED_CURVES))
    parser.add_argument(
        '-e', '--ecdsa-curve-name', metavar='CURVE',
        default=formats.CURVE_NIST256,
        help='specify ECDSA curve name: ' + supported)
    parser.add_argument(
        '--timeout', default=server.UNIX_SOCKET_TIMEOUT, type=float,
        help='Timeout for accepting SSH client connections')
    parser.add_argument(
        '--debug', default=False, action='store_true',
        help='Log SSH protocol messages for debugging.')
    return parser
2015-07-21 11:38:26 +00:00
2016-03-05 08:51:22 +00:00
2016-03-05 08:46:36 +00:00
def create_agent_parser():
    """Create the argument parser for the SSH agent entry point.

    Extends create_parser() with:
      * a mutually exclusive choice of what to run under the agent:
        -s/--shell, -c/--connect (ssh) or --mosh;
      * the required positional ``identity``;
      * optional trailing ``command`` arguments.
    """
    p = create_parser()
    g = p.add_mutually_exclusive_group()
    g.add_argument('-s', '--shell', default=False, action='store_true',
                   help='run ${SHELL} as subprocess under SSH agent')
    g.add_argument('-c', '--connect', default=False, action='store_true',
                   help='connect to specified host via SSH')
    g.add_argument('--mosh', default=False, action='store_true',
                   help='connect to specified host using Mosh')
    # A positional argument without `nargs` is always required, so a
    # `default` would never be used; omitting it avoids implying that the
    # identity is optional.
    p.add_argument('identity', type=str,
                   help='proto://[user@]host[:port][/path]')
    p.add_argument('command', type=str, nargs='*', metavar='ARGUMENT',
                   help='command to run under the SSH agent')
    return p
def create_git_parser():
    """Create the argument parser for running git commands under the agent.

    Extends create_parser() with -r/--remote (default 'origin'), -t/--test
    and trailing git command arguments.
    """
    parser = create_parser()
    parser.add_argument(
        '-r', '--remote', default='origin',
        help='use this git remote URL to generate SSH identity')
    parser.add_argument(
        '-t', '--test', action='store_true',
        help='test connection using `ssh -T user@host` command')
    parser.add_argument(
        'command', type=str, nargs='*', metavar='ARGUMENT',
        help='Git command to run under the SSH agent')
    return parser
def git_host(remote_name, attributes):
    """Extract the ``user@host`` SSH address of a git remote.

    Runs ``git config --local --list`` in the current directory and, for each
    attribute in `attributes` (tried in order), looks up the value of
    ``remote.<remote_name>.<attribute>``.  The first value of the form
    ``user@host:path`` is returned as ``'user@host'``.

    Args:
        remote_name: name of the git remote, e.g. ``'origin'``.
        attributes: iterable of remote config attribute names to try,
            e.g. ``['url']``.

    Returns:
        ``'user@host'`` string, or None if git could not be run, the command
        failed (e.g. not inside a repository), or no value matched.
    """
    try:
        # universal_newlines=True makes check_output return text on both
        # Python 2 and 3; on Python 3 the default is bytes, which the str
        # regular expressions below cannot search (TypeError).
        output = subprocess.check_output('git config --local --list'.split(),
                                         universal_newlines=True)
    except (subprocess.CalledProcessError, OSError):
        # CalledProcessError: git failed (e.g. no local repository);
        # OSError: the git executable could not be started.
        return None
    for attribute in attributes:
        name = r'remote.{0}.{1}'.format(remote_name, attribute)
        # Anchor to whole lines so that only the exact config key matches.
        pattern = r'^' + re.escape(name) + r'=(.*)$'
        matches = re.findall(pattern, output, re.MULTILINE)
        log.debug('%r: %r', name, matches)
        if not matches:
            continue
        url = matches[0].strip()
        match = re.match(r'(?P<user>.*?)@(?P<host>.*?):(?P<path>.*)', url)
        if match:
            return '{user}@{host}'.format(**match.groupdict())
    return None
2016-02-19 18:51:19 +00:00
2016-11-03 20:00:43 +00:00
def run_server(conn, public_keys, command, debug, timeout):
    """Serve `public_keys` as an SSH agent while `command` runs.

    Shared helper for the agent entry points: builds a protocol.Handler that
    signs with ``conn.sign_ssh_challenge``, opens the agent with
    server.serve() and runs `command` with the environment it provides.

    Args:
        conn: client object exposing ``sign_ssh_challenge``.
        public_keys: keys handed to protocol.Handler as ``keys``.
        command: command passed to server.run_process.
        debug: forwarded to protocol.Handler.
        timeout: forwarded to server.serve.

    Returns:
        Whatever server.run_process returns (presumably the child's exit
        status — confirm), or None if interrupted with Ctrl-C.
    """
    try:
        handler = protocol.Handler(keys=public_keys,
                                   signer=conn.sign_ssh_challenge,
                                   debug=debug)
        with server.serve(handler=handler, timeout=timeout) as env:
            return server.run_process(command=command, environ=env)
    except KeyboardInterrupt:
        log.info('server stopped')
2016-06-11 17:26:10 +00:00
def handle_connection_error(func):
    """Decorate `func` so that an IOError is logged and turned into 1.

    The non-zero return value lets entry points exit with a failure status
    instead of a traceback.  Other exceptions propagate unchanged, and the
    wrapped function's metadata is preserved via functools.wraps.
    """
    @functools.wraps(func)
    def _guarded(*args, **kwargs):
        try:
            result = func(*args, **kwargs)
        except IOError as err:
            log.error('Connection error: %s', err)
            result = 1
        return result
    return _guarded
def parse_config(fname):
    """Parse an identity config file, yielding Identity objects.

    The file is scanned for entries of the form ``<identity|curve>``; each
    entry (matched within a single line) yields one
    device.interface.Identity, in file order.

    Args:
        fname: path of the config file to read.

    Yields:
        device.interface.Identity built from each ``<identity|curve>`` entry.
    """
    # Read everything up-front and close the file right away.  Previously the
    # handle was never closed explicitly.  Reading before the first `yield`
    # also avoids keeping it open while the caller iterates this generator.
    with open(fname) as config_file:
        contents = config_file.read()
    for identity_str, curve_name in re.findall(r'\<(.*?)\|(.*?)\>', contents):
        yield device.interface.Identity(identity_str=identity_str,
                                        curve_name=curve_name)
2016-06-11 17:26:10 +00:00
@handle_connection_error
def run_agent(client_factory=client.Client):
    """Run an SSH agent backed by a hardware authentication device.

    Parses the command line (create_agent_parser), sets up logging, connects
    to the detected device via `client_factory`, and fetches one public key
    per identity.  Identities are read from a config file when the
    ``identity`` argument starts with '/'; otherwise the argument itself is
    the single identity, using the curve given by -e/--ecdsa-curve-name.

    With nothing to run, the public keys are written to stdout.  Otherwise
    the chosen command is run under the agent via run_server: ``ssh`` or
    ``mosh`` to the identity's host (-c/--mosh), ``$SHELL`` (-s), or the
    trailing arguments.

    Args:
        client_factory: callable taking ``device=`` and returning a client
            object (default: client.Client).

    Returns:
        None after printing keys, otherwise the result of run_server.
    """
    args = create_agent_parser().parse_args()
    util.setup_logging(verbosity=args.verbose)

    conn = client_factory(device=device.detect())

    if args.identity.startswith('/'):
        # An absolute path means "load the identities from this config file".
        identities = list(parse_config(fname=args.identity))
    else:
        single = device.interface.Identity(identity_str=args.identity,
                                           curve_name=args.ecdsa_curve_name)
        identities = [single]

    for number, ident in enumerate(identities):
        ident.identity_dict['proto'] = 'ssh'
        log.info('identity #%d: %s', number, ident)

    key_strings = [conn.get_public_key(ident) for ident in identities]

    # -s, -c and --mosh are mutually exclusive (see create_agent_parser).
    if args.shell:
        command = os.environ['SHELL']
    elif args.connect or args.mosh:
        # NOTE(review): ssh_args() parses args.identity as a host label, which
        # looks wrong when args.identity is a config-file path — confirm.
        program = 'ssh' if args.connect else 'mosh'
        command = [program] + ssh_args(args.identity) + args.command
    else:
        command = args.command

    if not command:
        # Nothing to run: just print the public keys.
        for key_string in key_strings:
            sys.stdout.write(key_string)
        return

    parsed_keys = [formats.import_public_key(k) for k in key_strings]
    for parsed, ident in zip(parsed_keys, identities):
        parsed['identity'] = ident
    return run_server(conn=conn, public_keys=parsed_keys, command=command,
                      debug=args.debug, timeout=args.timeout)