trezor: use identities instead of labels

nistp521
Roman Zeyde 9 years ago
parent 58c650c84a
commit 3057a3d7a9

@ -1,4 +1,5 @@
import io import io
import struct
import binascii import binascii
from . import util from . import util
@ -21,10 +22,10 @@ class TrezorLibrary(object):
return TrezorClient(HidTransport(devices[0])) return TrezorClient(HidTransport(devices[0]))
@staticmethod @staticmethod
def identity(label, proto='ssh'): def parse_identity(s):
# pylint: disable=import-error # pylint: disable=import-error
from trezorlib.types_pb2 import IdentityType from trezorlib.types_pb2 import IdentityType
return IdentityType(host=label, proto=proto) return IdentityType(**_string_to_identity(s))
class Client(object): class Client(object):
@ -52,19 +53,20 @@ class Client(object):
self.client.close() self.client.close()
def get_public_key(self, label): def get_public_key(self, label):
addr = _get_address(self.factory.identity(label)) log.info('getting %r public key from Trezor...', label)
log.info('getting %r SSH public key from Trezor...', label) identity = self.factory.parse_identity(label)
addr = _get_address(identity)
node = self.client.get_public_node(addr, self.curve_name) node = self.client.get_public_node(addr, self.curve_name)
return node.node.public_key return node.node.public_key
def sign_ssh_challenge(self, label, blob): def sign_ssh_challenge(self, label, blob):
ident = self.factory.identity(label) identity = self.factory.parse_identity(label)
msg = _parse_ssh_blob(blob) msg = _parse_ssh_blob(blob)
request = 'user: "{user}"'.format(**msg) request = 'user: "{user}"'.format(**msg)
log.info('confirm %s connection to %r using Trezor...', log.info('confirm %s connection to %r using Trezor...',
request, label) request, label)
s = self.client.sign_identity(identity=ident, s = self.client.sign_identity(identity=identity,
challenge_hidden=blob, challenge_hidden=blob,
challenge_visual=request, challenge_visual=request,
ecdsa_curve_name=self.curve_name) ecdsa_curve_name=self.curve_name)
@ -77,9 +79,53 @@ class Client(object):
return (r, s) return (r, s)
def _get_address(ident): def _lsplit(s, sep):
index = '\x00' * 4 p = None
addr = index + '{}://{}'.format(ident.proto, ident.host) if sep in s:
p, s = s.split(sep, 1)
return (p, s)
def _rsplit(s, sep):
p = None
if sep in s:
s, p = s.rsplit(sep, 1)
return (s, p)
def _string_to_identity(s):
proto, s = _lsplit(s, '://')
user, s = _lsplit(s, '@')
s, path = _rsplit(s, '/')
host, port = _rsplit(s, ':')
if not proto:
proto = 'ssh'
result = [
('proto', proto), ('user', user), ('host', host),
('port', port), ('path', path)
]
return {k: v for k, v in result if v}
def _identity_to_string(identity):
result = []
if identity.proto:
result.append(identity.proto + '://')
if identity.user:
result.append(identity.user + '@')
result.append(identity.host)
if identity.port:
result.append(':' + identity.port)
if identity.path:
result.append('/' + identity.path)
return ''.join(result)
def _get_address(identity):
index = struct.pack('<L', identity.index)
addr = index + _identity_to_string(identity)
digest = formats.hashfunc(addr).digest() digest = formats.hashfunc(addr).digest()
s = io.BytesIO(bytearray(digest)) s = io.BytesIO(bytearray(digest))

@ -1,3 +1,4 @@
import os
import sys import sys
import argparse import argparse
@ -10,12 +11,14 @@ log = logging.getLogger(__name__)
def main(): def main():
fmt = '%(asctime)s %(levelname)-12s %(message)-100s [%(filename)s]' fmt = '%(asctime)s %(levelname)-12s %(message)-100s [%(filename)s:%(lineno)d]'
p = argparse.ArgumentParser() p = argparse.ArgumentParser()
p.add_argument('-k', '--key-label', p.add_argument('-v', '--verbose', action='count', default=0,
metavar='LABEL', dest='labels', action='append', default=[]) help='increase the the logging verbosity')
p.add_argument('-v', '--verbose', action='count', default=0) p.add_argument('-c', dest='command', type=str, default=None,
p.add_argument('command', type=str, nargs='*') help='command to run under the SSH agent')
p.add_argument('identity', type=str, nargs='*',
help='proto://[user@]host[:port][/path]')
args = p.parse_args() args = p.parse_args()
verbosity = [logging.WARNING, logging.INFO, logging.DEBUG] verbosity = [logging.WARNING, logging.INFO, logging.DEBUG]
@ -24,20 +27,21 @@ def main():
with trezor.Client(factory=trezor.TrezorLibrary) as client: with trezor.Client(factory=trezor.TrezorLibrary) as client:
key_files = [] key_files = []
for label in args.labels: for label in args.identity:
pubkey = client.get_public_key(label=label) pubkey = client.get_public_key(label)
key_file = formats.export_public_key(pubkey=pubkey, label=label) key_file = formats.export_public_key(pubkey=pubkey, label=label)
key_files.append(key_file) key_files.append(key_file)
if not args.command: command = args.command
sys.stdout.write(''.join(key_files)) if not command:
return command = os.environ['SHELL']
log.info('using %r shell', command)
signer = client.sign_ssh_challenge signer = client.sign_ssh_challenge
try: try:
with server.serve(key_files=key_files, signer=signer) as env: with server.serve(key_files=key_files, signer=signer) as env:
return server.run_process(command=args.command, environ=env) return server.run_process(command=command, environ=env)
except KeyboardInterrupt: except KeyboardInterrupt:
log.info('server stopped') log.info('server stopped')

Loading…
Cancel
Save