diff --git a/sshagent/trezor.py b/sshagent/trezor.py index f6953ab..33d378e 100644 --- a/sshagent/trezor.py +++ b/sshagent/trezor.py @@ -7,36 +7,17 @@ import os from . import util from . import formats +from . import trezor_library import logging log = logging.getLogger(__name__) -class TrezorLibrary(object): - - @staticmethod - def client(): - # pylint: disable=import-error - from trezorlib.client import TrezorClient - from trezorlib.transport_hid import HidTransport - devices = HidTransport.enumerate() - if len(devices) != 1: - msg = '{:d} Trezor devices found'.format(len(devices)) - raise IOError(msg) - return TrezorClient(HidTransport(devices[0])) - - @staticmethod - def identity_type(**kwargs): - # pylint: disable=import-error - from trezorlib.types_pb2 import IdentityType - return IdentityType(**kwargs) - - class Client(object): curve_name = 'nist256p1' - def __init__(self, factory=TrezorLibrary): + def __init__(self, factory=trezor_library): self.factory = factory self.client = self.factory.client() f = self.client.features diff --git a/sshagent/trezor_library.py b/sshagent/trezor_library.py new file mode 100644 index 0000000..b857abb --- /dev/null +++ b/sshagent/trezor_library.py @@ -0,0 +1,18 @@ +''' Thin wrapper around trezorlib. ''' + + +def client(): + # pylint: disable=import-error + from trezorlib.client import TrezorClient + from trezorlib.transport_hid import HidTransport + devices = HidTransport.enumerate() + if len(devices) != 1: + msg = '{:d} Trezor devices found'.format(len(devices)) + raise IOError(msg) + return TrezorClient(HidTransport(devices[0])) + + +def identity_type(**kwargs): + # pylint: disable=import-error + from trezorlib.types_pb2 import IdentityType + return IdentityType(**kwargs)