diff --git a/trezor_agent/__main__.py b/trezor_agent/__main__.py index fd86baf..0a291c3 100644 --- a/trezor_agent/__main__.py +++ b/trezor_agent/__main__.py @@ -7,14 +7,14 @@ import re import subprocess import sys -from . import client, formats, protocol, server, util +from . import client, device, formats, protocol, server, util log = logging.getLogger(__name__) def ssh_args(label): """Create SSH command for connecting specified server.""" - identity = util.string_to_identity(label, identity_type=dict) + identity = device.interface.string_to_identity(label) args = [] if 'port' in identity: @@ -125,27 +125,28 @@ def run_agent(client_factory=client.Client): args = create_agent_parser().parse_args() util.setup_logging(verbosity=args.verbose) - with client_factory(curve=args.ecdsa_curve_name) as conn: - label = args.identity - command = args.command + d = device.detect(identity_str=args.identity, + curve_name=args.ecdsa_curve_name) + conn = client_factory(device=d) - public_key = conn.get_public_key(label=label) + command = args.command + public_key = conn.get_public_key() - if args.connect: - command = ssh_args(label) + args.command - log.debug('SSH connect: %r', command) + if args.connect: + command = ssh_args(args.identity) + args.command + log.debug('SSH connect: %r', command) - use_shell = bool(args.shell) - if use_shell: - command = os.environ['SHELL'] - log.debug('using shell: %r', command) + use_shell = bool(args.shell) + if use_shell: + command = os.environ['SHELL'] + log.debug('using shell: %r', command) - if not command: - sys.stdout.write(public_key) - return + if not command: + sys.stdout.write(public_key) + return - return run_server(conn=conn, public_key=public_key, command=command, - debug=args.debug, timeout=args.timeout) + return run_server(conn=conn, public_key=public_key, command=command, + debug=args.debug, timeout=args.timeout) @handle_connection_error diff --git a/trezor_agent/client.py b/trezor_agent/client.py index b4624c1..2e5b074 100644 --- a/trezor_agent/client.py +++ b/trezor_agent/client.py @@ -3,11 +3,10 @@ Connection to hardware authentication device. It is used for getting SSH public keys and ECDSA signing of server requests. """ -import binascii import io import logging -from . import factory, formats, util +from . import formats, util log = logging.getLogger(__name__) @@ -15,79 +14,36 @@ log = logging.getLogger(__name__) class Client(object): """Client wrapper for SSH authentication device.""" - def __init__(self, loader=factory.load, curve=formats.CURVE_NIST256): + def __init__(self, device): """Connect to hardware device.""" - client_wrapper = loader() - self.client = client_wrapper.connection - self.identity_type = client_wrapper.identity_type - self.device_name = client_wrapper.device_name - self.call_exception = client_wrapper.call_exception - self.curve = curve + device.identity_dict['proto'] = 'ssh' + self.device = device - def __enter__(self): - """Start a session, and test connection.""" - msg = 'Hello World!' - assert self.client.ping(msg) == msg - return self + def get_public_key(self): + """Get SSH public key from the device.""" + with self.device: + pubkey = self.device.pubkey() - def __exit__(self, *args): - """Keep the session open (doesn't forget PIN).""" - log.info('disconnected from %s', self.device_name) - self.client.close() + vk = formats.decompress_pubkey(pubkey=pubkey, + curve_name=self.device.curve_name) + return formats.export_public_key(vk=vk, + label=self.device.identity_str()) - def get_identity(self, label, index=0): - """Parse label string into Identity protobuf.""" - identity = util.string_to_identity(label, self.identity_type) - identity.proto = 'ssh' - identity.index = index - return identity - - def get_public_key(self, label): - """Get SSH public key corresponding to specified by label.""" - identity = self.get_identity(label=label) - label = util.identity_to_string(identity) # canonize key label - log.info('getting "%s" public key (%s) from %s...', - label, self.curve, self.device_name) - addr = util.get_bip32_address(identity) - node = self.client.get_public_node(n=addr, - ecdsa_curve_name=self.curve) - - pubkey = node.node.public_key - vk = formats.decompress_pubkey(pubkey=pubkey, curve_name=self.curve) - return formats.export_public_key(vk=vk, label=label) - - def sign_ssh_challenge(self, label, blob): - """Sign given blob using a private key, specified by the label.""" - identity = self.get_identity(label=label) + def sign_ssh_challenge(self, blob): + """Sign given blob using a private key on the device.""" msg = _parse_ssh_blob(blob) log.debug('%s: user %r via %r (%r)', msg['conn'], msg['user'], msg['auth'], msg['key_type']) - log.debug('nonce: %s', binascii.hexlify(msg['nonce'])) + log.debug('nonce: %r', msg['nonce']) log.debug('fingerprint: %s', msg['public_key']['fingerprint']) log.debug('hidden challenge size: %d bytes', len(blob)) log.info('please confirm user "%s" login to "%s" using %s...', - msg['user'].decode('ascii'), label, self.device_name) - - try: - result = self.client.sign_identity(identity=identity, - challenge_hidden=blob, - challenge_visual='', - ecdsa_curve_name=self.curve) - except self.call_exception as e: - code, msg = e.args - log.warning('%s error #%s: %s', self.device_name, code, msg) - raise IOError(msg) # close current connection, keep server open - - verifying_key = formats.decompress_pubkey(pubkey=result.public_key, - curve_name=self.curve) - key_type, blob = formats.serialize_verifying_key(verifying_key) - assert blob == msg['public_key']['blob'] - assert key_type == msg['key_type'] - assert len(result.signature) == 65 - assert result.signature[:1] == bytearray([0]) + msg['user'].decode('ascii'), self.device.identity_str(), + self.device) - return result.signature[1:] + with self.device: + return self.device.sign(blob=blob) def _parse_ssh_blob(data): diff --git a/trezor_agent/device/interface.py b/trezor_agent/device/interface.py index b175efe..52bec1a 100644 --- a/trezor_agent/device/interface.py +++ b/trezor_agent/device/interface.py @@ -76,7 +76,6 @@ class Device(object): def __init__(self, identity_str, curve_name): """Configure for specific identity and elliptic curve usage.""" self.identity_dict = string_to_identity(identity_str) - assert curve_name in formats.SUPPORTED_CURVES self.curve_name = curve_name self.conn = None diff --git a/trezor_agent/protocol.py b/trezor_agent/protocol.py index efa2fea..e1a763f 100644 --- a/trezor_agent/protocol.py +++ b/trezor_agent/protocol.py @@ -7,7 +7,6 @@ for more details. The server's source code can be found here: https://github.com/openssh/openssh-portable/blob/master/authfd.c """ -import binascii import io import logging @@ -138,13 +137,13 @@ class Handler(object): else: raise KeyError('key not found') - log.debug('signing %d-byte blob', len(blob)) label = key['name'].decode('ascii') # label should be a string + log.debug('signing %d-byte blob with "%s" key', len(blob), label) try: - signature = self.signer(label=label, blob=blob) + signature = self.signer(blob=blob) except IOError: return failure() - log.debug('signature: %s', binascii.hexlify(signature)) + log.debug('signature: %r', signature) try: sig_bytes = key['verifier'](sig=signature, msg=blob) diff --git a/trezor_agent/tests/test_client.py b/trezor_agent/tests/test_client.py index 6a8273f..b3f4bad 100644 --- a/trezor_agent/tests/test_client.py +++ b/trezor_agent/tests/test_client.py @@ -3,7 +3,7 @@ import io import mock import pytest -from .. import client, factory, formats, util +from .. import client, device, formats, util ADDR = [2147483661, 2810943954, 3938368396, 3454558782, 3848009040] CURVE = 'nist256p1' @@ -15,29 +15,23 @@ PUBKEY_TEXT = ('ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzd' 'VUfhvrGljR2Z/CMRONY6ejB+9PnpUOPuzYqi8= ssh://localhost:22\n') -class FakeConnection(object): +class MockDevice(device.interface.Device): # pylint: disable=abstract-method - def __init__(self): - self.closed = False + def connect(self): # pylint: disable=no-self-use + return mock.Mock() def close(self): - self.closed = True + self.conn = None - def clear_session(self): - self.closed = True + def pubkey(self, ecdh=False): # pylint: disable=unused-argument + assert self.conn + return PUBKEY - def get_public_node(self, n, ecdsa_curve_name=b'secp256k1'): - assert not self.closed - assert n == ADDR - assert ecdsa_curve_name in {'secp256k1', 'nist256p1'} - result = mock.Mock(spec=[]) - result.node = mock.Mock(spec=[]) - result.node.public_key = PUBKEY - return result - - def ping(self, msg): - assert not self.closed - return msg + def sign(self, blob): + """Sign given blob and return the signature (as bytes).""" + assert self.conn + assert blob == BLOB + return SIG def identity_type(**kwargs): @@ -50,13 +44,6 @@ def identity_type(**kwargs): return result -def load_client(): - return factory.ClientWrapper(connection=FakeConnection(), - identity_type=identity_type, - device_name='DEVICE_NAME', - call_exception=Exception) - - BLOB = (b'\x00\x00\x00 \xce\xe0\xc9\xd5\xceu/\xe8\xc5\xf2\xbfR+x\xa1\xcf\xb0' b'\x8e;R\xd3)m\x96\x1b\xb4\xd8s\xf1\x99\x16\xaa2\x00\x00\x00\x05roman' b'\x00\x00\x00\x0essh-connection\x00\x00\x00\tpublickey' @@ -66,71 +53,33 @@ BLOB = (b'\x00\x00\x00 \xce\xe0\xc9\xd5\xceu/\xe8\xc5\xf2\xbfR+x\xa1\xcf\xb0' b'\xdd\xbc+\xfar~\x9dAis4\xc1\x10yeT~\x1b\xeb\x1aX\xd1\xd9\x9f\xc21' b'\x13\x8dc\xa7\xa3\x07\xefO\x9e\x95\x0e>\xec\xd8\xaa/') -SIG = (b'\x00R\x19T\xf2\x84$\xef#\x0e\xee\x04X\xc6\xc3\x99T`\xd1\xd8\xf7!' +SIG = (b'R\x19T\xf2\x84$\xef#\x0e\xee\x04X\xc6\xc3\x99T`\xd1\xd8\xf7!' b'\x862@cx\xb8\xb9i@1\x1b3#\x938\x86]\x97*Y\xb2\x02Xa\xdf@\xecK' b'\xdc\xf0H\xab\xa8\xac\xa7? \x8f=C\x88N\xe2') def test_ssh_agent(): - label = 'localhost:22' - c = client.Client(loader=load_client) - ident = c.get_identity(label=label) - assert ident.host == 'localhost' - assert ident.proto == 'ssh' - assert ident.port == '22' - assert ident.user is None - assert ident.path is None - assert ident.index == 0 - - with c: - assert c.get_public_key(label) == PUBKEY_TEXT - - def ssh_sign_identity(identity, challenge_hidden, - challenge_visual, ecdsa_curve_name): - assert (util.identity_to_string(identity) == - util.identity_to_string(ident)) - assert challenge_hidden == BLOB - assert challenge_visual == '' - assert ecdsa_curve_name == 'nist256p1' - - result = mock.Mock(spec=[]) - result.public_key = PUBKEY - result.signature = SIG - return result - - c.client.sign_identity = ssh_sign_identity - signature = c.sign_ssh_challenge(label=label, blob=BLOB) - - key = formats.import_public_key(PUBKEY_TEXT) - serialized_sig = key['verifier'](sig=signature, msg=BLOB) - - stream = io.BytesIO(serialized_sig) - r = util.read_frame(stream) - s = util.read_frame(stream) - assert not stream.read() - assert r[:1] == b'\x00' - assert s[:1] == b'\x00' - assert r[1:] + s[1:] == SIG[1:] - - c.client.call_exception = ValueError - - # pylint: disable=unused-argument - def cancel_sign_identity(identity, challenge_hidden, - challenge_visual, ecdsa_curve_name): - raise c.client.call_exception(42, 'ERROR') - - c.client.sign_identity = cancel_sign_identity - with pytest.raises(IOError): - c.sign_ssh_challenge(label=label, blob=BLOB) - - -def test_utils(): - identity = mock.Mock(spec=[]) - identity.proto = 'https' - identity.user = 'user' - identity.host = 'host' - identity.port = '443' - identity.path = '/path' - - url = 'https://user@host:443/path' - assert util.identity_to_string(identity) == url + identity_str = 'localhost:22' + c = client.Client(device=MockDevice(identity_str=identity_str, + curve_name=CURVE)) + assert c.get_public_key() == PUBKEY_TEXT + signature = c.sign_ssh_challenge(blob=BLOB) + + key = formats.import_public_key(PUBKEY_TEXT) + serialized_sig = key['verifier'](sig=signature, msg=BLOB) + + stream = io.BytesIO(serialized_sig) + r = util.read_frame(stream) + s = util.read_frame(stream) + assert not stream.read() + assert r[:1] == b'\x00' + assert s[:1] == b'\x00' + assert r[1:] + s[1:] == SIG + + # pylint: disable=unused-argument + def cancel_sign(blob): + raise IOError(42, 'ERROR') + + c.device.sign = cancel_sign + with pytest.raises(IOError): + c.sign_ssh_challenge(blob=BLOB) diff --git a/trezor_agent/tests/test_protocol.py b/trezor_agent/tests/test_protocol.py index 541fecb..17a1001 100644 --- a/trezor_agent/tests/test_protocol.py +++ b/trezor_agent/tests/test_protocol.py @@ -28,8 +28,7 @@ def test_unsupported(): assert reply == b'\x00\x00\x00\x01\x05' -def ecdsa_signer(label, blob): - assert label == 'ssh://localhost' +def ecdsa_signer(blob): assert blob == NIST256_BLOB return NIST256_SIG @@ -49,8 +48,7 @@ def test_sign_missing(): def test_sign_wrong(): - def wrong_signature(label, blob): - assert label == 'ssh://localhost' + def wrong_signature(blob): assert blob == NIST256_BLOB return b'\x00' * 64 @@ -62,7 +60,7 @@ def test_sign_wrong(): def test_sign_cancel(): - def cancel_signature(label, blob): # pylint: disable=unused-argument + def cancel_signature(blob): # pylint: disable=unused-argument raise IOError() key = formats.import_public_key(NIST256_KEY) @@ -79,8 +77,7 @@ ED25519_BLOB = b'''\x00\x00\x00 i3\xae}yk\\\xa1L\xb9\xe1\xbf\xbc\x8e\x87\r\x0e\x ED25519_SIG = b'''\x8eb)\xa6\xe9P\x83VE\xfbq\xc6\xbf\x1dV3\xe3.*)://)?', - r'(?:(?P.*)@)?', - r'(?P.*?)', - r'(?::(?P\w*))?', - r'(?P/.*)?', - '$' -])) - - -def string_to_identity(s, identity_type): - """Parse string into Identity protobuf.""" - m = _identity_regexp.match(s) - result = m.groupdict() - log.debug('parsed identity: %s', result) - kwargs = {k: v for k, v in result.items() if v} - return identity_type(**kwargs) - - -def identity_to_string(identity): - """Dump Identity protobuf into its string representation.""" - result = [] - if identity.proto: - result.append(identity.proto + '://') - if identity.user: - result.append(identity.user + '@') - result.append(identity.host) - if identity.port: - result.append(':' + identity.port) - if identity.path: - result.append(identity.path) - return ''.join(result) - - -def get_bip32_address(identity, ecdh=False): - """Compute BIP32 derivation address according to SLIP-0013/0017.""" - index = struct.pack('