2015-06-06 14:52:10 +00:00
|
|
|
import io
|
2015-07-04 18:15:09 +00:00
|
|
|
import re
|
2015-07-03 13:09:49 +00:00
|
|
|
import struct
|
2015-06-15 06:50:51 +00:00
|
|
|
import binascii
|
2015-06-06 14:52:10 +00:00
|
|
|
|
2015-06-15 15:13:10 +00:00
|
|
|
from . import util
|
|
|
|
from . import formats
|
2015-06-06 14:52:10 +00:00
|
|
|
|
2015-06-15 15:13:10 +00:00
|
|
|
import logging
|
2015-06-06 14:52:10 +00:00
|
|
|
# Module logger used for the debug/info messages emitted throughout this file.
log = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
2015-06-17 13:29:12 +00:00
|
|
|
class TrezorLibrary(object):
    """Factory for the objects that come from the ``trezorlib`` package.

    ``trezorlib`` is only imported inside the methods below, so importing this
    module does not import ``trezorlib``; the import happens on first use.
    """

    @staticmethod
    def client():
        """Open a connection to the single attached Trezor device.

        :returns: a ``trezorlib`` TrezorClient over the device's HID transport.
        :raises ValueError: if the number of enumerated devices is not one.
        """
        # pylint: disable=import-error
        from trezorlib.client import TrezorClient
        from trezorlib.transport_hid import HidTransport

        candidates = HidTransport.enumerate()
        found = len(candidates)
        if found == 1:
            return TrezorClient(HidTransport(candidates[0]))
        raise ValueError('{:d} Trezor devices found'.format(found))

    @staticmethod
    def parse_identity(s):
        """Convert an identity string into a ``trezorlib`` IdentityType message.

        Parsing is delegated to the module-level ``_string_to_identity``; the
        fields it returns become keyword arguments of the message.
        """
        # pylint: disable=import-error
        from trezorlib.types_pb2 import IdentityType
        fields = _string_to_identity(s)
        return IdentityType(**fields)
|
2015-06-17 13:29:12 +00:00
|
|
|
|
|
|
|
|
|
|
|
class Client(object):
    """Connection to one Trezor device for SSH public-key operations.

    Usable as a context manager: leaving the ``with`` block closes the
    underlying device client.
    """

    # Curve name sent with both get_public_node() and sign_identity(), so the
    # exported public key and the produced signatures use the same curve.
    curve_name = 'nist256p1'

    def __init__(self, factory=TrezorLibrary):
        """Connect to the device and log the features it reports.

        :param factory: object providing ``client()`` and
            ``parse_identity()``; defaults to TrezorLibrary.
        """
        self.factory = factory
        self.client = self.factory.client()

        f = self.client.features
        log.debug('connected to Trezor %s', f.device_id)
        log.debug('label : %s', f.label)
        log.debug('vendor : %s', f.vendor)
        version = [f.major_version, f.minor_version, f.patch_version]
        log.debug('version : %s', '.'.join([str(v) for v in version]))
        log.debug('revision : %s', binascii.hexlify(f.revision))

    def __enter__(self):
        return self

    def __exit__(self, *args):
        """Close the device client. Returns None, so exceptions propagate."""
        log.info('disconnected from Trezor')
        self.client.close()

    def get_identity(self, label):
        """Parse *label* into an identity message using the factory."""
        return self.factory.parse_identity(label)

    def get_public_key(self, identity):
        """Fetch the public key for *identity* from the device and export it.

        The key is derived at the path computed by ``_get_address`` on
        ``curve_name``, then formatted by ``formats.export_public_key`` with
        the identity string as its label.
        """
        label = _identity_to_string(identity)
        log.info('getting "%s" public key from Trezor...', label)
        addr = _get_address(identity)
        node = self.client.get_public_node(addr, self.curve_name)

        pubkey = node.node.public_key
        return formats.export_public_key(pubkey=pubkey, label=label)

    def sign_ssh_challenge(self, identity, blob):
        """Have the device sign an SSH authentication challenge.

        :param identity: IdentityType message whose ``proto`` must be 'ssh'.
        :param blob: serialized challenge. It is parsed locally (for logging
            and to check the public key it names) and is sent to the device
            unchanged as ``challenge_hidden``.
        :returns: ``(r, s)`` integers taken from the device's signature.
        :raises AssertionError: (only while assertions are enabled) if the
            device's key differs from the key in *blob*, or the signature is
            not one 0x00 byte followed by 64 bytes.
        """
        label = _identity_to_string(identity)
        msg = _parse_ssh_blob(blob)

        log.info('please confirm user %s connection to "%s" using Trezor...',
                 msg['user'], label)

        assert identity.proto == 'ssh'
        visual = identity.path  # not signed when proto='ssh'
        result = self.client.sign_identity(identity=identity,
                                           challenge_hidden=blob,
                                           challenge_visual=visual,
                                           ecdsa_curve_name=self.curve_name)

        # The device must have signed with the key named in the challenge blob.
        public_key_blob = formats.decompress_pubkey(result.public_key)
        assert public_key_blob == msg['public_key']['blob']

        # Layout checked here: a leading 0x00 byte, then 32 bytes of r and
        # 32 bytes of s. Slice (not index) the first byte: indexing bytes
        # yields an int on Python 3, which never equals b'\x00'.
        assert len(result.signature) == 65
        assert result.signature[:1] == b'\x00'

        sig = result.signature[1:]
        r = util.bytes2num(sig[:32])
        s = util.bytes2num(sig[32:])
        return (r, s)
|
|
|
|
|
|
|
|
|
2015-07-04 18:57:57 +00:00
|
|
|
# Splits an identity URL of the form [proto://][user@]host[:port][/path].
# Every part except the host is optional; absent parts match as None.
_identity_regexp = re.compile(
    r'^'
    r'(?:(?P<proto>.*)://)?'  # optional scheme, e.g. "ssh://"
    r'(?:(?P<user>.*)@)?'     # optional user name before "@"
    r'(?P<host>.*?)'          # host; lazy so port/path can claim the tail
    r'(?::(?P<port>\w*))?'    # optional ":port" (word characters only)
    r'(?P<path>/.*)?'         # optional path, starting with "/"
    r'$'
)
|
2015-07-03 13:09:49 +00:00
|
|
|
|
|
|
|
def _string_to_identity(s):
    """Parse an identity string ``[proto://][user@]host[:port][/path]``.

    A missing protocol defaults to 'ssh'. Only non-empty fields are kept, so
    the result can be passed directly as keyword arguments (see
    TrezorLibrary.parse_identity).

    :param s: identity string.
    :returns: dict with a subset of the keys proto, user, host, port, path.
    :raises ValueError: if *s* does not match ``_identity_regexp`` (for
        example when it contains an embedded line break, which ``.`` in the
        pattern does not match).
    """
    m = _identity_regexp.match(s)
    if m is None:
        # Without this check, m.groupdict() would fail with an opaque
        # AttributeError on None.
        raise ValueError('invalid identity: {!r}'.format(s))
    result = m.groupdict()
    if not result.get('proto'):
        result['proto'] = 'ssh'  # otherwise, Trezor will use SECP256K1 curve

    log.debug('parsed identity: %s', result)
    return {k: v for k, v in result.items() if v}
|
2015-07-03 13:09:49 +00:00
|
|
|
|
|
|
|
|
|
|
|
def _identity_to_string(identity):
    """Render an SSH identity back into its URL-style string.

    Produces ``proto://[user@]host[:port][path]``; each optional part is
    emitted only when the corresponding field is truthy. The result is used
    both as a human-readable label and as input to address derivation.

    :param identity: object with proto/user/host/port/path attributes; proto
        must be 'ssh' (asserted).
    :returns: the joined identity string.
    """
    assert identity.proto == 'ssh'
    head = identity.proto + '://'
    user_part = identity.user + '@' if identity.user else ''
    host_part = identity.host
    port_part = ':' + identity.port if identity.port else ''
    path_part = identity.path if identity.path else ''
    return ''.join((head, user_part, host_part, port_part, path_part))
|
|
|
|
|
|
|
|
|
|
|
|
def _get_address(identity):
    """Derive the hardened key path used for *identity*.

    Steps:
    - Pack ``identity.index`` as 4 little-endian bytes.
    - Prepend it to the UTF-8 encoded identity string.
    - Hash the result with ``formats.hashfunc``.
    - Read four little-endian 32-bit values from the start of the digest and
      append them to the fixed first component 13.
    - Set the hardened bit (0x80000000) on every component.

    NOTE(review): this appears to follow SLIP-0013 -- confirm against the
    firmware before changing anything here.

    :returns: list of ints, passed as the path to get_public_node().
    """
    index = struct.pack('<L', identity.index)
    text = _identity_to_string(identity)
    if not isinstance(text, bytes):
        # struct.pack returns bytes. Concatenating text would raise TypeError
        # on Python 3, and on Python 2 (unicode protobuf fields) it would
        # implicitly ASCII-decode the packed index. Encode explicitly instead;
        # ASCII identities hash to the same bytes as before.
        text = text.encode('utf-8')
    addr = index + text
    log.debug('address string: %r', addr)

    digest = formats.hashfunc(addr).digest()
    s = io.BytesIO(bytearray(digest))

    hardened = 0x80000000
    address_n = [13] + list(util.recv(s, '<LLLL'))
    return [(hardened | value) for value in address_n]
|
2015-06-16 07:04:02 +00:00
|
|
|
|
|
|
|
|
|
|
|
def _parse_ssh_blob(data):
    """Split a serialized SSH public-key authentication request into fields.

    Assuming util.read_frame reads one length-prefixed SSH "string", the
    order consumed below appears to match the data signed for the
    "publickey" method in RFC 4252, section 7 -- NOTE(review): confirm:

        string   session identifier     -> 'nonce'
        byte     message code           (skipped, not validated)
        string   user name              -> 'user'
        string   service name           -> 'conn'
        string   method name            -> 'auth'
        boolean  TRUE                   (skipped, not validated)
        string   public key algorithm   -> 'key_type'
        string   public key blob        -> 'public_key' (formats.parse_pubkey)

    The whole input must be consumed (asserted).

    :param data: raw request bytes.
    :returns: dict with the keys above, or an empty dict when *data* is
        empty or otherwise falsy.
    """
    if not data:
        return {}

    stream = io.BytesIO(data)
    fields = {}
    fields['nonce'] = util.read_frame(stream)
    stream.read(1)  # presumably the request message code; skipped unchecked
    for key in ('user', 'conn', 'auth'):
        fields[key] = util.read_frame(stream)
    stream.read(1)  # presumably the boolean TRUE flag; skipped unchecked
    fields['key_type'] = util.read_frame(stream)
    fields['public_key'] = formats.parse_pubkey(util.read_frame(stream))
    assert not stream.read()  # nothing may follow the public key

    log.debug('%s: user %r via %r (%r)',
              fields['conn'], fields['user'], fields['auth'],
              fields['key_type'])
    log.debug('nonce: %s', binascii.hexlify(fields['nonce']))
    log.debug('fingerprint: %s', fields['public_key']['fingerprint'])
    return fields
|