add session request to TransportV2, add @session helper

This commit is contained in:
Jan Pochyla 2016-09-13 12:25:06 +02:00
parent 5fc6dc3155
commit 4d3e4574ef
2 changed files with 130 additions and 121 deletions

View File

@ -88,6 +88,18 @@ class expect(object):
return ret return ret
return wrapped_f return wrapped_f
def session(f):
# Decorator wraps a BaseClient method
# with session activation / deactivation
def wrapped_f(*args, **kwargs):
client = args[0]
try:
client.transport.session_begin()
return f(*args, **kwargs)
finally:
client.transport.session_end()
return wrapped_f
def normalize_nfc(txt): def normalize_nfc(txt):
if sys.version_info[0] < 3: if sys.version_info[0] < 3:
if isinstance(txt, unicode): if isinstance(txt, unicode):
@ -112,33 +124,22 @@ class BaseClient(object):
def cancel(self): def cancel(self):
self.transport.write(proto.Cancel()) self.transport.write(proto.Cancel())
@session
def call_raw(self, msg): def call_raw(self, msg):
try: self.transport.write(msg)
self.transport.session_begin() return self.transport.read_blocking()
self.transport.write(msg)
resp = self.transport.read_blocking()
finally:
self.transport.session_end()
return resp
@session
def call(self, msg): def call(self, msg):
try: resp = self.call_raw(msg)
self.transport.session_begin() handler_name = "callback_%s" % resp.__class__.__name__
handler = getattr(self, handler_name, None)
resp = self.call_raw(msg) if handler != None:
handler_name = "callback_%s" % resp.__class__.__name__ msg = handler(resp)
handler = getattr(self, handler_name, None) if msg == None:
raise Exception("Callback %s must return protobuf message, not None" % handler)
if handler != None: resp = self.call(msg)
msg = handler(resp)
if msg == None:
raise Exception("Callback %s must return protobuf message, not None" % handler)
resp = self.call(msg)
finally:
self.transport.session_end()
return resp return resp
@ -423,6 +424,7 @@ class ProtocolMixin(object):
n = self._convert_prime(n) n = self._convert_prime(n)
return self.call(proto.EthereumGetAddress(address_n=n, show_display=show_display)) return self.call(proto.EthereumGetAddress(address_n=n, show_display=show_display))
@session
def ethereum_sign_tx(self, n, nonce, gas_price, gas_limit, to, value, data=None): def ethereum_sign_tx(self, n, nonce, gas_price, gas_limit, to, value, data=None):
def int_to_big_endian(value): def int_to_big_endian(value):
import rlp.utils import rlp.utils
@ -432,35 +434,30 @@ class ProtocolMixin(object):
n = self._convert_prime(n) n = self._convert_prime(n)
try: msg = proto.EthereumSignTx(
self.transport.session_begin() address_n=n,
nonce=int_to_big_endian(nonce),
gas_price=int_to_big_endian(gas_price),
gas_limit=int_to_big_endian(gas_limit),
value=int_to_big_endian(value))
msg = proto.EthereumSignTx( if to:
address_n=n, msg.to = to
nonce=int_to_big_endian(nonce),
gas_price=int_to_big_endian(gas_price),
gas_limit=int_to_big_endian(gas_limit),
value=int_to_big_endian(value))
if to: if data:
msg.to = to msg.data_length = len(data)
data, chunk = data[1024:], data[:1024]
msg.data_initial_chunk = chunk
if data: response = self.call(msg)
msg.data_length = len(data)
data, chunk = data[1024:], data[:1024]
msg.data_initial_chunk = chunk
response = self.call(msg) while response.HasField('data_length'):
data_length = response.data_length
data, chunk = data[data_length:], data[:data_length]
response = self.call(proto.EthereumTxAck(data_chunk=chunk))
while response.HasField('data_length'): return response.signature_v, response.signature_r, response.signature_s
data_length = response.data_length
data, chunk = data[data_length:], data[:data_length]
response = self.call(proto.EthereumTxAck(data_chunk=chunk))
return response.signature_v, response.signature_r, response.signature_s
finally:
self.transport.session_end()
@field('entropy') @field('entropy')
@expect(proto.Entropy) @expect(proto.Entropy)
@ -634,88 +631,83 @@ class ProtocolMixin(object):
return txes return txes
@session
def sign_tx(self, coin_name, inputs, outputs, debug_processor=None): def sign_tx(self, coin_name, inputs, outputs, debug_processor=None):
start = time.time() start = time.time()
txes = self._prepare_sign_tx(coin_name, inputs, outputs) txes = self._prepare_sign_tx(coin_name, inputs, outputs)
try: # Prepare and send initial message
self.transport.session_begin() tx = proto.SignTx()
tx.inputs_count = len(inputs)
tx.outputs_count = len(outputs)
tx.coin_name = coin_name
res = self.call(tx)
# Prepare and send initial message # Prepare structure for signatures
tx = proto.SignTx() signatures = [None] * len(inputs)
tx.inputs_count = len(inputs) serialized_tx = b''
tx.outputs_count = len(outputs)
tx.coin_name = coin_name
res = self.call(tx)
# Prepare structure for signatures counter = 0
signatures = [None] * len(inputs) while True:
serialized_tx = b'' counter += 1
counter = 0 if isinstance(res, proto.Failure):
while True: raise CallException("Signing failed")
counter += 1
if isinstance(res, proto.Failure): if not isinstance(res, proto.TxRequest):
raise CallException("Signing failed") raise CallException("Unexpected message")
if not isinstance(res, proto.TxRequest): # If there's some part of signed transaction, let's add it
raise CallException("Unexpected message") if res.HasField('serialized') and res.serialized.HasField('serialized_tx'):
log("RECEIVED PART OF SERIALIZED TX (%d BYTES)" % len(res.serialized.serialized_tx))
serialized_tx += res.serialized.serialized_tx
# If there's some part of signed transaction, let's add it if res.HasField('serialized') and res.serialized.HasField('signature_index'):
if res.HasField('serialized') and res.serialized.HasField('serialized_tx'): if signatures[res.serialized.signature_index] != None:
log("RECEIVED PART OF SERIALIZED TX (%d BYTES)" % len(res.serialized.serialized_tx)) raise Exception("Signature for index %d already filled" % res.serialized.signature_index)
serialized_tx += res.serialized.serialized_tx signatures[res.serialized.signature_index] = res.serialized.signature
if res.HasField('serialized') and res.serialized.HasField('signature_index'): if res.request_type == types.TXFINISHED:
if signatures[res.serialized.signature_index] != None: # Device didn't ask for more information, finish workflow
raise Exception("Signature for index %d already filled" % res.serialized.signature_index) break
signatures[res.serialized.signature_index] = res.serialized.signature
if res.request_type == types.TXFINISHED: # Device asked for one more information, let's process it.
# Device didn't ask for more information, finish workflow current_tx = txes[res.details.tx_hash]
break
# Device asked for one more information, let's process it. if res.request_type == types.TXMETA:
current_tx = txes[res.details.tx_hash] msg = types.TransactionType()
msg.version = current_tx.version
msg.lock_time = current_tx.lock_time
msg.inputs_cnt = len(current_tx.inputs)
if res.details.tx_hash:
msg.outputs_cnt = len(current_tx.bin_outputs)
else:
msg.outputs_cnt = len(current_tx.outputs)
res = self.call(proto.TxAck(tx=msg))
continue
if res.request_type == types.TXMETA: elif res.request_type == types.TXINPUT:
msg = types.TransactionType() msg = types.TransactionType()
msg.version = current_tx.version msg.inputs.extend([current_tx.inputs[res.details.request_index], ])
msg.lock_time = current_tx.lock_time res = self.call(proto.TxAck(tx=msg))
msg.inputs_cnt = len(current_tx.inputs) continue
if res.details.tx_hash:
msg.outputs_cnt = len(current_tx.bin_outputs)
else:
msg.outputs_cnt = len(current_tx.outputs)
res = self.call(proto.TxAck(tx=msg))
continue
elif res.request_type == types.TXINPUT: elif res.request_type == types.TXOUTPUT:
msg = types.TransactionType() msg = types.TransactionType()
msg.inputs.extend([current_tx.inputs[res.details.request_index], ]) if res.details.tx_hash:
res = self.call(proto.TxAck(tx=msg)) msg.bin_outputs.extend([current_tx.bin_outputs[res.details.request_index], ])
continue else:
msg.outputs.extend([current_tx.outputs[res.details.request_index], ])
elif res.request_type == types.TXOUTPUT: if debug_processor != None:
msg = types.TransactionType() # If debug_processor function is provided,
if res.details.tx_hash: # pass thru it the request and prepared response.
msg.bin_outputs.extend([current_tx.bin_outputs[res.details.request_index], ]) # This is useful for unit tests, see test_msg_signtx
else: msg = debug_processor(res, msg)
msg.outputs.extend([current_tx.outputs[res.details.request_index], ])
if debug_processor != None: res = self.call(proto.TxAck(tx=msg))
# If debug_processor function is provided, continue
# pass thru it the request and prepared response.
# This is useful for unit tests, see test_msg_signtx
msg = debug_processor(res, msg)
res = self.call(proto.TxAck(tx=msg))
continue
finally:
self.transport.session_end()
if None in signatures: if None in signatures:
raise Exception("Some signatures are missing!") raise Exception("Some signatures are missing!")
@ -753,6 +745,7 @@ class ProtocolMixin(object):
@field('message') @field('message')
@expect(proto.Success) @expect(proto.Success)
@session
def reset_device(self, display_random, strength, passphrase_protection, pin_protection, label, language): def reset_device(self, display_random, strength, passphrase_protection, pin_protection, label, language):
if self.features.initialized: if self.features.initialized:
raise Exception("Device is initialized already. Call wipe_device() and try again.") raise Exception("Device is initialized already. Call wipe_device() and try again.")
@ -843,6 +836,7 @@ class ProtocolMixin(object):
self.init_device() self.init_device()
return resp return resp
@session
def firmware_update(self, fp): def firmware_update(self, fp):
if self.features.bootloader_mode == False: if self.features.bootloader_mode == False:
raise Exception("Device must be in bootloader mode") raise Exception("Device must be in bootloader mode")

View File

@ -71,9 +71,10 @@ class Transport(object):
def _parse_message(self, data): def _parse_message(self, data):
(session_id, msg_type, data) = data (session_id, msg_type, data) = data
# Raise exception if we get the response with # Raise exception if we get the response with unexpected session ID
# unexpected session ID if session_id != self.session_id:
self._check_session_id(session_id) raise Exception("Session ID mismatch. Have %d, got %d" %
(self.session_id, session_id))
if msg_type == 'protobuf': if msg_type == 'protobuf':
return data return data
@ -82,14 +83,6 @@ class Transport(object):
inst.ParseFromString(bytes(data)) inst.ParseFromString(bytes(data))
return inst return inst
def _check_session_id(self, session_id):
if self.session_id == 0:
# Let the device set the session ID
self.session_id = session_id
elif session_id != self.session_id:
# Session ID has been already set, but it differs from response
raise Exception("Session ID mismatch. Have %d, got %d" % (self.session_id, session_id))
# Functions to be implemented in specific transports: # Functions to be implemented in specific transports:
def _open(self): def _open(self):
raise NotImplementedException("Not implemented") raise NotImplementedException("Not implemented")
@ -237,6 +230,28 @@ class TransportV2(Transport):
data = chunk[1 + headerlen:] data = chunk[1 + headerlen:]
return (session_id, data) return (session_id, data)
def parse_session(self, chunk):
if chunk[0:1] != b"!":
raise Exception("Unexpected magic character")
try:
headerlen = struct.calcsize(">LL")
(null_session_id, new_session_id) = struct.unpack(
">LL", bytes(chunk[1:1 + headerlen]))
except:
raise Exception("Cannot parse header")
if null_session_id != 0:
raise Exception("Session response needs to use session ID 0")
return new_session_id
def _session_begin(self):
self._write_chunk(b'!' + b'\0' * 63)
self.session_id = self.parse_session(self._read_chunk())
def _session_end(self):
pass
''' '''
def read_headers(self, read_f): def read_headers(self, read_f):
c = read_f.read(2) c = read_f.read(2)