diff --git a/trezorlib/client.py b/trezorlib/client.py index d17be03..552d73e 100644 --- a/trezorlib/client.py +++ b/trezorlib/client.py @@ -1,4 +1,5 @@ import os +import time import binascii import hashlib import unicodedata @@ -433,15 +434,106 @@ class ProtocolMixin(object): return msg @field('serialized_tx') - @expect(proto.TxRequest) + @expect(types.TxRequestSerializedType) def simple_sign_tx(self, coin_name, inputs, outputs): # TODO Deserialize tx and check if inputs/outputs fits msg = self._prepare_simple_sign_tx(coin_name, inputs, outputs) - return self.call(msg) + return self.call(msg).serialized + + def _prepare_sign_tx(self, coin_name, inputs, outputs): + tx = types.TransactionType() + tx.inputs.extend(inputs) + tx.outputs.extend (outputs) + + txes = {} + txes[''] = tx + + known_hashes = [] + for inp in inputs: + if inp.prev_hash in known_hashes: + continue + + txes[inp.prev_hash] = self.tx_api.get_tx(binascii.hexlify(inp.prev_hash)) + known_hashes.append(inp.prev_hash) + + return txes def sign_tx(self, coin_name, inputs, outputs): # Temporary solution, until streaming is implemented in the firmware - return self.simple_sign_tx(coin_name, inputs, outputs) + # return self.simple_sign_tx(coin_name, inputs, outputs) + + start = time.time() + txes = self._prepare_sign_tx(coin_name, inputs, outputs) + + try: + self.transport.session_begin() + + # Prepare and send initial message + tx = proto.SignTx() + tx.inputs_count = len(inputs) + tx.outputs_count = len(outputs) + tx.coin_name = coin_name + res = self.call(tx) + + # Prepare structure for signatures + signatures = [None] * len(inputs) + serialized_tx = '' + + counter = 0 + while True: + counter += 1 + + if isinstance(res, proto.Failure): + raise CallException("Signing failed") + + if not isinstance(res, proto.TxRequest): + raise CallException("Unexpected message") + + # If there's some part of signed transaction, let's add it + if res.HasField('serialized_tx'): + print "!!! RECEIVED PART OF SERIALIED TX (%d BYTES)" % len(res.serialized_tx) + serialized_tx += res.serialized_tx + + if res.HasField('signature_index') and res.HasField('signature'): + print "!!! SIGNED INPUT", res.signature_index + signatures[res.signature_index] = res.signature + + if res.request_type == types.TXFINISHED: + # Device didn't ask for more information, finish workflow + break + + # Device asked for one more information, let's process it. + current_tx = txes[res.tx_hash] + + if res.request_type == types.TXMETA: + print "REQUESTING META OF", binascii.hexlify(res.tx_hash) + msg = types.TransactionType() + msg.version = current_tx.version + msg.lock_time = current_tx.lock_time + res = self.call(proto.TxAck(tx=msg)) + continue + + elif res.request_type == types.TXINPUT: + print "REQUESTING INPUT", res.request_index, "OF", binascii.hexlify(res.tx_hash) + msg = types.TransactionType() + msg.inputs.extend([current_tx.inputs[res.request_index], ]) + res = self.call(proto.TxAck(tx=msg)) + continue + + elif res.request_type == types.TXOUTPUT: + print "REQUESTING OUTOUT", res.request_index, "OF", binascii.hexlify(res.tx_hash) + msg = types.TransactionType() + msg.outputs.extend([current_tx.outputs[res.request_index], ]) + res = self.call(proto.TxAck(tx=msg)) + continue + + finally: + self.transport.session_end() + + print "SIGNED IN %.03f SECONDS, CALLED %d MESSAGES, %d BYTES" % \ + (time.time() - start, counter, len(serialized_tx)) + + return (signatures, serialized_tx) @field('message') @expect(proto.Success)