factory: refactor for easier testing

This commit is contained in:
Roman Zeyde 2016-01-19 22:52:52 +02:00
parent 9afd07e867
commit 2eff21f96c
2 changed files with 106 additions and 5 deletions

View File

@ -0,0 +1,92 @@
import mock
import pytest
from ..trezor import factory
def test_load():
def single():
return [0]
def nothing():
return []
def double():
return [1, 2]
assert factory.load(loaders=[single]) == 0
assert factory.load(loaders=[single, nothing]) == 0
assert factory.load(loaders=[nothing, single]) == 0
with pytest.raises(IOError):
factory.load(loaders=[])
with pytest.raises(IOError):
factory.load(loaders=[single, single])
with pytest.raises(IOError):
factory.load(loaders=[double])
factory_load_client = factory._load_client # pylint: disable=protected-access
def test_load_nothing():
hid_transport = mock.Mock()
hid_transport.enumerate.return_value = []
result = factory_load_client(
name=None,
client_type=None,
hid_transport=hid_transport,
passphrase_ack=None,
identity_type=None,
required_version=None)
assert list(result) == []
def create_client_type(version):
conn = mock.Mock()
conn.features = mock.Mock()
major, minor, patch = version.split('.')
conn.features.major_version = major
conn.features.minor_version = minor
conn.features.patch_version = patch
conn.features.revision = b'\x12\x34\x56\x78'
client_type = mock.Mock()
client_type.return_value = conn
return client_type
def test_load_single():
hid_transport = mock.Mock()
hid_transport.enumerate.return_value = [0]
for version in ('1.3.4', '1.3.5', '1.4.0', '2.0.0'):
passphrase_ack = mock.Mock()
client_type = create_client_type(version)
result = factory_load_client(
name='DEVICE_NAME',
client_type=client_type,
hid_transport=hid_transport,
passphrase_ack=passphrase_ack,
identity_type=None,
required_version='>=1.3.4')
client_wrapper, = result
assert client_wrapper.connection is client_type.return_value
assert client_wrapper.device_name == 'DEVICE_NAME'
client_wrapper.connection.callback_PassphraseRequest('MESSAGE')
assert passphrase_ack.mock_calls == [mock.call(passphrase='')]
def test_load_old():
hid_transport = mock.Mock()
hid_transport.enumerate.return_value = [0]
for version in ('1.3.3', '1.2.5', '1.1.0', '0.9.9'):
with pytest.raises(ValueError):
next(factory_load_client(
name='DEVICE_NAME',
client_type=create_client_type(version),
hid_transport=hid_transport,
passphrase_ack=None,
identity_type=None,
required_version='>=1.3.4'))

View File

@ -68,11 +68,20 @@ def _load_keepkey():
identity_type=IdentityType,
required_version='>=1.0.4')
LOADERS = [
_load_trezor,
_load_keepkey
]
def load():
devices = list(_load_trezor()) + list(_load_keepkey())
if len(devices) == 1:
return devices[0]
msg = '{:d} devices found'.format(len(devices))
def load(loaders=None):
loaders = loaders if loaders is not None else LOADERS
device_list = []
for loader in loaders:
device_list.extend(loader())
if len(device_list) == 1:
return device_list[0]
msg = '{:d} devices found'.format(len(device_list))
raise IOError(msg)