trezor-agent/sshagent/trezor.py

157 lines
4.6 KiB
Python
Raw Normal View History

2015-06-06 14:52:10 +00:00
import io
import struct
import binascii
2015-06-06 14:52:10 +00:00
2015-06-15 15:13:10 +00:00
from . import util
from . import formats
2015-06-06 14:52:10 +00:00
2015-06-15 15:13:10 +00:00
import logging
2015-06-06 14:52:10 +00:00
# Module-level logger used by the code in this file.
log = logging.getLogger(__name__)
class TrezorLibrary(object):
    """Default factory binding ``Client`` to the real ``trezorlib`` package.

    ``trezorlib`` is imported inside each method, so it is only required
    when one of these factory methods is actually called.
    """

    @staticmethod
    def client():
        """Open a ``TrezorClient`` on the single HID-attached device.

        Raises:
            ValueError: when the number of enumerated devices is not one.
        """
        # pylint: disable=import-error
        from trezorlib.client import TrezorClient
        from trezorlib.transport_hid import HidTransport
        found = HidTransport.enumerate()
        count = len(found)
        if count != 1:
            raise ValueError('{:d} Trezor devices found'.format(count))
        transport = HidTransport(found[0])
        return TrezorClient(transport)

    @staticmethod
    def parse_identity(s):
        """Turn an identity string *s* into a ``trezorlib`` ``IdentityType``."""
        # pylint: disable=import-error
        from trezorlib.types_pb2 import IdentityType
        fields = _string_to_identity(s)
        return IdentityType(**fields)
class Client(object):
    """Wrapper around a connected Trezor used to serve SSH identities.

    Meant to be used as a context manager: leaving the ``with`` block clears
    the device session and closes the connection.
    """

    # ECDSA curve requested both when exporting public keys and when signing.
    curve_name = 'nist256p1'

    def __init__(self, factory=TrezorLibrary):
        """Connect through ``factory.client()`` and log the device features.

        Args:
            factory: object providing ``client()`` and ``parse_identity(label)``;
                defaults to ``TrezorLibrary``.
        """
        self.factory = factory
        self.client = self.factory.client()
        f = self.client.features
        log.debug('connected to Trezor %s', f.device_id)
        log.debug('label : %s', f.label)
        log.debug('vendor : %s', f.vendor)
        version = [f.major_version, f.minor_version, f.patch_version]
        log.debug('version : %s', '.'.join([str(v) for v in version]))
        log.debug('revision : %s', binascii.hexlify(f.revision))

    def __enter__(self):
        return self

    def __exit__(self, *args):
        """Clear the device session, then always close the connection."""
        log.info('disconnected from Trezor')
        try:
            self.client.clear_session()
        finally:
            # Close even if clear_session() fails, so the transport is not leaked.
            self.client.close()

    def get_public_key(self, label):
        """Fetch the public key for identity *label* and export it.

        The label is normalized (parsed, then rendered back to a string), and
        the normalized form is what is logged and passed to
        ``formats.export_public_key``.
        """
        identity = self.factory.parse_identity(label)
        label = _identity_to_string(identity)  # update label after parsing
        log.info('getting "%s" public key from Trezor...', label)
        addr = _get_address(identity)
        node = self.client.get_public_node(addr, self.curve_name)
        pubkey = node.node.public_key
        return formats.export_public_key(pubkey=pubkey, label=label)

    def sign_ssh_challenge(self, label, blob):
        """Have the device sign the SSH challenge *blob* for identity *label*.

        The user name parsed from *blob* is sent as ``challenge_visual`` and
        the raw blob as ``challenge_hidden``.

        Returns:
            tuple: the ECDSA signature as integers ``(r, s)``.

        Raises:
            ValueError: if the device's signature is not 65 bytes long or
                does not start with a zero byte.
        """
        identity = self.factory.parse_identity(label)
        msg = _parse_ssh_blob(blob)
        request = 'user: "{user}"'.format(**msg)
        log.info('confirm %s connection to %r using Trezor...',
                 request, label)
        result = self.client.sign_identity(identity=identity,
                                           challenge_hidden=blob,
                                           challenge_visual=request,
                                           ecdsa_curve_name=self.curve_name)
        signature = result.signature
        # Explicit checks instead of ``assert``, which is stripped under -O.
        if len(signature) != 65:
            raise ValueError(
                'unexpected signature length: {:d}'.format(len(signature)))
        # Slice rather than index: on Python 3 ``bytes[0]`` is an int and
        # would never equal b'\x00'.
        if signature[:1] != b'\x00':
            raise ValueError(
                'unexpected signature prefix: {!r}'.format(signature[:1]))
        sig = signature[1:]
        r = util.bytes2num(sig[:32])
        s = util.bytes2num(sig[32:])
        return (r, s)
def _lsplit(s, sep):
    """Split *s* once, at the first occurrence of *sep*.

    Returns ``(prefix, rest)``. When *sep* does not occur, the prefix is
    ``None`` and the rest is *s* unchanged.
    """
    head, found, tail = s.partition(sep)
    if not found:
        return (None, s)
    return (head, tail)
def _rsplit(s, sep):
    """Split *s* once, at the last occurrence of *sep*.

    Returns ``(rest, suffix)``. When *sep* does not occur, the rest is *s*
    unchanged and the suffix is ``None``.
    """
    head, found, tail = s.rpartition(sep)
    return (head, tail) if found else (s, None)
def _string_to_identity(s):
    """Parse an identity string into keyword fields for ``IdentityType``.

    Accepted form: ``[proto://][user@]host[:port][/path]``.

    Returns:
        dict: the non-empty fields among 'proto', 'user', 'host', 'port'
        and 'path'. 'proto' defaults to 'ssh' when missing or empty.
    """
    proto = user = port = path = None
    rest = s
    if '://' in rest:
        proto, rest = rest.split('://', 1)
    if '@' in rest:
        user, rest = rest.split('@', 1)
    # The path starts at the FIRST '/' after the authority, so it may itself
    # contain '/' and ':' characters without leaking into host or port.
    if '/' in rest:
        rest, path = rest.split('/', 1)
    # The port follows the last ':' of the remaining ``host[:port]`` part.
    host = rest
    if ':' in rest:
        host, port = rest.rsplit(':', 1)
    if not proto:
        proto = 'ssh'  # otherwise, Trezor will use SECP256K1 curve
    result = [
        ('proto', proto), ('user', user), ('host', host),
        ('port', port), ('path', path)
    ]
    return {k: v for k, v in result if v}
def _identity_to_string(identity):
    """Render *identity* as ``[proto://][user@]host[:port][/path]``.

    Optional parts are included only when truthy; the host is always
    included.
    """
    pieces = [
        identity.proto + '://' if identity.proto else '',
        identity.user + '@' if identity.user else '',
        identity.host,
        ':' + identity.port if identity.port else '',
        '/' + identity.path if identity.path else '',
    ]
    return ''.join(pieces)
def _get_address(identity):
    """Derive the hardened BIP32 address for *identity*.

    The identity index is packed as a little-endian 32-bit integer and
    prepended to the identity URI. The result is hashed with
    ``formats.hashfunc``, and four little-endian 32-bit words are read from
    the digest via ``util.recv``. The returned path is ``[13] + those
    words``, with every element hardened.
    """
    index = struct.pack('<L', identity.index)
    uri = _identity_to_string(identity)
    if not isinstance(uri, bytes):
        # Text URI (Python 3, or unicode protobuf fields on Python 2): encode
        # it so it can be concatenated with the packed index bytes.
        uri = uri.encode('utf-8')
    addr = index + uri
    digest = formats.hashfunc(addr).digest()
    s = io.BytesIO(bytearray(digest))
    hardened = 0x80000000
    address_n = [13] + list(util.recv(s, '<LLLL'))
    return [(hardened | value) for value in address_n]
2015-06-16 07:04:02 +00:00
def _parse_ssh_blob(data):
    """Decode the SSH challenge blob the agent has been asked to sign.

    Returns an empty dict when *data* is falsy. Otherwise, returns a dict
    with 'nonce', 'user', 'conn', 'auth', 'key_type' and 'pubkey', each read
    with ``util.read_frame`` in that order. Two single bytes in the stream
    are skipped without inspection. The embedded public key is also parsed
    with ``formats.parse_pubkey`` so its fingerprint can be logged.
    """
    if not data:
        return {}
    stream = io.BytesIO(data)
    parsed = {}
    parsed['nonce'] = util.read_frame(stream)
    # NOTE(review): presumably the SSH_MSG_USERAUTH_REQUEST message-type byte
    # (RFC 4252, section 7); it is skipped without validation.
    stream.read(1)
    for key in ('user', 'conn', 'auth'):
        parsed[key] = util.read_frame(stream)
    # NOTE(review): presumably the publickey request's boolean flag; it is
    # skipped without validation.
    stream.read(1)
    for key in ('key_type', 'pubkey'):
        parsed[key] = util.read_frame(stream)

    log.debug('%s: user %r via %r (%r)',
              parsed['conn'], parsed['user'], parsed['auth'],
              parsed['key_type'])
    log.debug('nonce: %s', binascii.hexlify(parsed['nonce']))
    pubkey = formats.parse_pubkey(parsed['pubkey'])
    log.debug('fingerprint: %s', pubkey['fingerprint'])
    return parsed