add session request to TransportV2, add @session helper
This commit is contained in:
parent
5fc6dc3155
commit
4d3e4574ef
@ -88,6 +88,18 @@ class expect(object):
|
|||||||
return ret
|
return ret
|
||||||
return wrapped_f
|
return wrapped_f
|
||||||
|
|
||||||
|
def session(f):
|
||||||
|
# Decorator wraps a BaseClient method
|
||||||
|
# with session activation / deactivation
|
||||||
|
def wrapped_f(*args, **kwargs):
|
||||||
|
client = args[0]
|
||||||
|
try:
|
||||||
|
client.transport.session_begin()
|
||||||
|
return f(*args, **kwargs)
|
||||||
|
finally:
|
||||||
|
client.transport.session_end()
|
||||||
|
return wrapped_f
|
||||||
|
|
||||||
def normalize_nfc(txt):
|
def normalize_nfc(txt):
|
||||||
if sys.version_info[0] < 3:
|
if sys.version_info[0] < 3:
|
||||||
if isinstance(txt, unicode):
|
if isinstance(txt, unicode):
|
||||||
@ -112,20 +124,13 @@ class BaseClient(object):
|
|||||||
def cancel(self):
|
def cancel(self):
|
||||||
self.transport.write(proto.Cancel())
|
self.transport.write(proto.Cancel())
|
||||||
|
|
||||||
|
@session
|
||||||
def call_raw(self, msg):
|
def call_raw(self, msg):
|
||||||
try:
|
|
||||||
self.transport.session_begin()
|
|
||||||
self.transport.write(msg)
|
self.transport.write(msg)
|
||||||
resp = self.transport.read_blocking()
|
return self.transport.read_blocking()
|
||||||
finally:
|
|
||||||
self.transport.session_end()
|
|
||||||
|
|
||||||
return resp
|
|
||||||
|
|
||||||
|
@session
|
||||||
def call(self, msg):
|
def call(self, msg):
|
||||||
try:
|
|
||||||
self.transport.session_begin()
|
|
||||||
|
|
||||||
resp = self.call_raw(msg)
|
resp = self.call_raw(msg)
|
||||||
handler_name = "callback_%s" % resp.__class__.__name__
|
handler_name = "callback_%s" % resp.__class__.__name__
|
||||||
handler = getattr(self, handler_name, None)
|
handler = getattr(self, handler_name, None)
|
||||||
@ -134,12 +139,8 @@ class BaseClient(object):
|
|||||||
msg = handler(resp)
|
msg = handler(resp)
|
||||||
if msg == None:
|
if msg == None:
|
||||||
raise Exception("Callback %s must return protobuf message, not None" % handler)
|
raise Exception("Callback %s must return protobuf message, not None" % handler)
|
||||||
|
|
||||||
resp = self.call(msg)
|
resp = self.call(msg)
|
||||||
|
|
||||||
finally:
|
|
||||||
self.transport.session_end()
|
|
||||||
|
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
def callback_Failure(self, msg):
|
def callback_Failure(self, msg):
|
||||||
@ -423,6 +424,7 @@ class ProtocolMixin(object):
|
|||||||
n = self._convert_prime(n)
|
n = self._convert_prime(n)
|
||||||
return self.call(proto.EthereumGetAddress(address_n=n, show_display=show_display))
|
return self.call(proto.EthereumGetAddress(address_n=n, show_display=show_display))
|
||||||
|
|
||||||
|
@session
|
||||||
def ethereum_sign_tx(self, n, nonce, gas_price, gas_limit, to, value, data=None):
|
def ethereum_sign_tx(self, n, nonce, gas_price, gas_limit, to, value, data=None):
|
||||||
def int_to_big_endian(value):
|
def int_to_big_endian(value):
|
||||||
import rlp.utils
|
import rlp.utils
|
||||||
@ -432,9 +434,6 @@ class ProtocolMixin(object):
|
|||||||
|
|
||||||
n = self._convert_prime(n)
|
n = self._convert_prime(n)
|
||||||
|
|
||||||
try:
|
|
||||||
self.transport.session_begin()
|
|
||||||
|
|
||||||
msg = proto.EthereumSignTx(
|
msg = proto.EthereumSignTx(
|
||||||
address_n=n,
|
address_n=n,
|
||||||
nonce=int_to_big_endian(nonce),
|
nonce=int_to_big_endian(nonce),
|
||||||
@ -459,8 +458,6 @@ class ProtocolMixin(object):
|
|||||||
|
|
||||||
return response.signature_v, response.signature_r, response.signature_s
|
return response.signature_v, response.signature_r, response.signature_s
|
||||||
|
|
||||||
finally:
|
|
||||||
self.transport.session_end()
|
|
||||||
|
|
||||||
@field('entropy')
|
@field('entropy')
|
||||||
@expect(proto.Entropy)
|
@expect(proto.Entropy)
|
||||||
@ -634,14 +631,12 @@ class ProtocolMixin(object):
|
|||||||
|
|
||||||
return txes
|
return txes
|
||||||
|
|
||||||
|
@session
|
||||||
def sign_tx(self, coin_name, inputs, outputs, debug_processor=None):
|
def sign_tx(self, coin_name, inputs, outputs, debug_processor=None):
|
||||||
|
|
||||||
start = time.time()
|
start = time.time()
|
||||||
txes = self._prepare_sign_tx(coin_name, inputs, outputs)
|
txes = self._prepare_sign_tx(coin_name, inputs, outputs)
|
||||||
|
|
||||||
try:
|
|
||||||
self.transport.session_begin()
|
|
||||||
|
|
||||||
# Prepare and send initial message
|
# Prepare and send initial message
|
||||||
tx = proto.SignTx()
|
tx = proto.SignTx()
|
||||||
tx.inputs_count = len(inputs)
|
tx.inputs_count = len(inputs)
|
||||||
@ -714,9 +709,6 @@ class ProtocolMixin(object):
|
|||||||
res = self.call(proto.TxAck(tx=msg))
|
res = self.call(proto.TxAck(tx=msg))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
finally:
|
|
||||||
self.transport.session_end()
|
|
||||||
|
|
||||||
if None in signatures:
|
if None in signatures:
|
||||||
raise Exception("Some signatures are missing!")
|
raise Exception("Some signatures are missing!")
|
||||||
|
|
||||||
@ -753,6 +745,7 @@ class ProtocolMixin(object):
|
|||||||
|
|
||||||
@field('message')
|
@field('message')
|
||||||
@expect(proto.Success)
|
@expect(proto.Success)
|
||||||
|
@session
|
||||||
def reset_device(self, display_random, strength, passphrase_protection, pin_protection, label, language):
|
def reset_device(self, display_random, strength, passphrase_protection, pin_protection, label, language):
|
||||||
if self.features.initialized:
|
if self.features.initialized:
|
||||||
raise Exception("Device is initialized already. Call wipe_device() and try again.")
|
raise Exception("Device is initialized already. Call wipe_device() and try again.")
|
||||||
@ -843,6 +836,7 @@ class ProtocolMixin(object):
|
|||||||
self.init_device()
|
self.init_device()
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
|
@session
|
||||||
def firmware_update(self, fp):
|
def firmware_update(self, fp):
|
||||||
if self.features.bootloader_mode == False:
|
if self.features.bootloader_mode == False:
|
||||||
raise Exception("Device must be in bootloader mode")
|
raise Exception("Device must be in bootloader mode")
|
||||||
|
@ -71,9 +71,10 @@ class Transport(object):
|
|||||||
def _parse_message(self, data):
|
def _parse_message(self, data):
|
||||||
(session_id, msg_type, data) = data
|
(session_id, msg_type, data) = data
|
||||||
|
|
||||||
# Raise exception if we get the response with
|
# Raise exception if we get the response with unexpected session ID
|
||||||
# unexpected session ID
|
if session_id != self.session_id:
|
||||||
self._check_session_id(session_id)
|
raise Exception("Session ID mismatch. Have %d, got %d" %
|
||||||
|
(self.session_id, session_id))
|
||||||
|
|
||||||
if msg_type == 'protobuf':
|
if msg_type == 'protobuf':
|
||||||
return data
|
return data
|
||||||
@ -82,14 +83,6 @@ class Transport(object):
|
|||||||
inst.ParseFromString(bytes(data))
|
inst.ParseFromString(bytes(data))
|
||||||
return inst
|
return inst
|
||||||
|
|
||||||
def _check_session_id(self, session_id):
|
|
||||||
if self.session_id == 0:
|
|
||||||
# Let the device set the session ID
|
|
||||||
self.session_id = session_id
|
|
||||||
elif session_id != self.session_id:
|
|
||||||
# Session ID has been already set, but it differs from response
|
|
||||||
raise Exception("Session ID mismatch. Have %d, got %d" % (self.session_id, session_id))
|
|
||||||
|
|
||||||
# Functions to be implemented in specific transports:
|
# Functions to be implemented in specific transports:
|
||||||
def _open(self):
|
def _open(self):
|
||||||
raise NotImplementedException("Not implemented")
|
raise NotImplementedException("Not implemented")
|
||||||
@ -237,6 +230,28 @@ class TransportV2(Transport):
|
|||||||
data = chunk[1 + headerlen:]
|
data = chunk[1 + headerlen:]
|
||||||
return (session_id, data)
|
return (session_id, data)
|
||||||
|
|
||||||
|
def parse_session(self, chunk):
|
||||||
|
if chunk[0:1] != b"!":
|
||||||
|
raise Exception("Unexpected magic character")
|
||||||
|
|
||||||
|
try:
|
||||||
|
headerlen = struct.calcsize(">LL")
|
||||||
|
(null_session_id, new_session_id) = struct.unpack(
|
||||||
|
">LL", bytes(chunk[1:1 + headerlen]))
|
||||||
|
except:
|
||||||
|
raise Exception("Cannot parse header")
|
||||||
|
|
||||||
|
if null_session_id != 0:
|
||||||
|
raise Exception("Session response needs to use session ID 0")
|
||||||
|
return new_session_id
|
||||||
|
|
||||||
|
def _session_begin(self):
|
||||||
|
self._write_chunk(b'!' + b'\0' * 63)
|
||||||
|
self.session_id = self.parse_session(self._read_chunk())
|
||||||
|
|
||||||
|
def _session_end(self):
|
||||||
|
pass
|
||||||
|
|
||||||
'''
|
'''
|
||||||
def read_headers(self, read_f):
|
def read_headers(self, read_f):
|
||||||
c = read_f.read(2)
|
c = read_f.read(2)
|
||||||
|
Loading…
Reference in New Issue
Block a user