diff --git a/trezorlib/transport.py b/trezorlib/transport.py index 9ff25df..dc59619 100644 --- a/trezorlib/transport.py +++ b/trezorlib/transport.py @@ -9,34 +9,12 @@ class ConnectionError(Exception): class Transport(object): def __init__(self, device, *args, **kwargs): + print("Transport constructor") self.device = device + self.session_id = 0 self.session_depth = 0 self._open() - def _open(self): - raise NotImplementedException("Not implemented") - - def _close(self): - raise NotImplementedException("Not implemented") - - def _write(self, msg, protobuf_msg): - raise NotImplementedException("Not implemented") - - def _read(self): - raise NotImplementedException("Not implemented") - - def _session_begin(self): - pass - - def _session_end(self): - pass - - def ready_to_read(self): - """ - Returns True if there is data to be read from the transport. Otherwise, False. - """ - raise NotImplementedException("Not implemented") - def session_begin(self): """ Apply a lock to the device in order to preform synchronous multistep "conversations" with the device. For example, before entering the transaction signing workflow, one begins a session. After the transaction is complete, the session may be ended. @@ -64,16 +42,14 @@ class Transport(object): """ Write mesage to tansport. msg should be a member of a valid `protobuf class `_ with a SerializeToString() method. """ - ser = msg.SerializeToString() - header = struct.pack(">HL", mapping.get_type(msg), len(ser)) - self._write(b"##" + header + ser, msg) + raise NotImplementedException("Not implemented") def read(self): """ If there is data available to be read from the transport, reads the data and tries to parse it as a protobuf message. If the parsing succeeds, return a protobuf object. Otherwise, returns None. """ - if not self.ready_to_read(): + if not self._ready_to_read(): return None data = self._read() @@ -84,7 +60,7 @@ class Transport(object): def read_blocking(self): """ - Same as read, except blocks untill data is available to be read. + Same as read, except blocks until data is available to be read. """ while True: data = self._read() @@ -94,39 +70,115 @@ class Transport(object): return self._parse_message(data) def _parse_message(self, data): - (msg_type, data) = data + (session_id, msg_type, data) = data + + # Raise exception if we get the response with + # unexpected session ID + self._check_session_id(session_id) + if msg_type == 'protobuf': return data else: + print mapping.get_class(msg_type) inst = mapping.get_class(msg_type)() - inst.ParseFromString(data) + print inst, data + inst.ParseFromString(bytes(data)) return inst - def _read_headers(self, read_f): - # Try to read headers until some sane value are detected - is_ok = False - while not is_ok: + def _check_session_id(self, session_id): + if self.session_id == 0: + # Let the device set the session ID + self.session_id = session_id + elif session_id != self.session_id: + # Session ID has been already set, but it differs from response + raise Exception("Session ID mismatch. Have %d, got %d" % (self.session_id, session_id)) - # Align cursor to the beginning of the header ("##") - c = read_f.read(1) - i = 0 - while c != b"#": - i += 1 - if i >= 64: - # timeout - raise Exception("Timed out while waiting for the magic character") - c = read_f.read(1) + # Functions to be implemented in specific transports: + def _open(self): + raise NotImplementedException("Not implemented") - if read_f.read(1) != b"#": - # Second character must be # to be valid header - raise Exception("Second magic character is broken") + def _close(self): + raise NotImplementedException("Not implemented") - # Now we're most likely on the beginning of the header - try: - headerlen = struct.calcsize(">HL") - (msg_type, datalen) = struct.unpack(">HL", read_f.read(headerlen)) - break - except: - raise Exception("Cannot parse header length") + def _write_chunk(self, chunk): + raise NotImplementedException("Not implemented") - return (msg_type, datalen) + def _read_chunk(self): + raise NotImplementedException("Not implemented") + + def _ready_to_read(self): + """ + Returns True if there is data to be read from the transport. Otherwise, False. + """ + raise NotImplementedException("Not implemented") + + def _session_begin(self): + pass + + def _session_end(self): + pass + +class TransportV1(Transport): + def write(self, msg): + ser = msg.SerializeToString() + header = struct.pack(">HL", mapping.get_type(msg), len(ser)) + data = bytearray(b"##" + header + ser) + + while len(data): + # Report ID, data padded to 63 bytes + chunk = b'?' + data[:63] + b'\0' * (63 - len(data[:63])) + self._write_chunk(chunk) + data = data[63:] + + def _read(self): + chunk = self._read_chunk() + (msg_type, datalen, data) = self.parse_first(chunk) + + while len(data) < datalen: + chunk = self._read_chunk() + data.extend(self.parse_next(chunk)) + + # Strip padding zeros + data = data[:datalen] + return (0, msg_type, data) + + def parse_first(self, chunk): + if chunk[:3] != b"?##": + raise Exception("Unexpected magic characters") + + try: + headerlen = struct.calcsize(">HL") + (msg_type, datalen) = struct.unpack(">HL", chunk[3:3 + headerlen]) + except: + raise Exception("Cannot parse header length") + + data = chunk[3 + headerlen:] + return (msg_type, datalen, data) + + def parse_next(self, chunk): + if chunk[0:1] != b"?": + raise Exception("Unexpected magic characters") + + return chunk[1:] + +class TransportV2(Transport): + def write(self, msg): + ser = msg.SerializeToString() + raise NotImplemented() + + def _read(self): + pass + + def read_headers(self, read_f): + c = read_f.read(2) + if c != b"?!": + raise Exception("Unexpected magic characters") + + try: + headerlen = struct.calcsize(">HL") + (session_id, msg_type, datalen) = struct.unpack(">LLL", read_f.read(headerlen)) + except: + raise Exception("Cannot parse header length") + + print datalen + return (0, msg_type, datalen) diff --git a/trezorlib/transport_hid.py b/trezorlib/transport_hid.py index c920945..02b55f9 100644 --- a/trezorlib/transport_hid.py +++ b/trezorlib/transport_hid.py @@ -2,56 +2,57 @@ import hid import time -from .transport import Transport, ConnectionError +from .transport import TransportV1, TransportV2, ConnectionError -DEVICE_IDS = [ - (0x534c, 0x0001), # TREZOR -] +def enumerate(): + """ + Return a list of available TREZOR devices. + """ + devices = {} + for d in hid.enumerate(0, 0): + vendor_id = d['vendor_id'] + product_id = d['product_id'] + serial_number = d['serial_number'] + interface_number = d['interface_number'] + path = d['path'] -class FakeRead(object): - # Let's pretend we have a file-like interface - def __init__(self, func): - self.func = func + # HIDAPI on Mac cannot detect correct HID interfaces, so device with + # DebugLink doesn't work on Mac... + if devices.get(serial_number) != None and devices[serial_number][0] == path: + raise Exception("Two devices with the same path and S/N found. This is Mac, right? :-/") - def read(self, size): - return self.func(size) + if (vendor_id, product_id) in [ x[0:2] for x in DEVICE_IDS]: + devices.setdefault(serial_number, [None, None]) + if interface_number == 0 or interface_number == -1: # normal link + devices[serial_number][0] = path + elif interface_number == 1: # debug link + devices[serial_number][1] = path -class HidTransport(Transport): + # List of two-tuples (path_normal, path_debuglink) + return list(devices.values()) + +def path_to_transport(path): + try: + device = [ d for d in hid.enumerate(0, 0) if d['path'] == path ][0] + except IndexError: + raise ConnectionError("Connection failed") + + # VID/PID found, let's find proper transport + vid, pid = device['vendor_id'], device['product_id'] + try: + transport = [ transport for (_vid, _pid, transport) in DEVICE_IDS if _vid == vid and _pid == pid ][0] + except IndexError: + raise Exception("Unknown transport for VID:PID %04x:%04x" % (vid, pid)) + + return transport + +class _HidTransport(object): def __init__(self, device, *args, **kwargs): self.hid = None self.hid_version = None - self.buffer = '' - # self.read_timeout = kwargs.get('read_timeout') + device = device[int(bool(kwargs.get('debug_link')))] - super(HidTransport, self).__init__(device, *args, **kwargs) - - @classmethod - def enumerate(cls): - """ - Return a list of available TREZOR devices. - """ - devices = {} - for d in hid.enumerate(0, 0): - vendor_id = d['vendor_id'] - product_id = d['product_id'] - serial_number = d['serial_number'] - interface_number = d['interface_number'] - path = d['path'] - - # HIDAPI on Mac cannot detect correct HID interfaces, so device with - # DebugLink doesn't work on Mac... - if devices.get(serial_number) != None and devices[serial_number][0] == path: - raise Exception("Two devices with the same path and S/N found. This is Mac, right? :-/") - - if (vendor_id, product_id) in DEVICE_IDS: - devices.setdefault(serial_number, [None, None]) - if interface_number == 0 or interface_number == -1: # normal link - devices[serial_number][0] = path - elif interface_number == 1: # debug link - devices[serial_number][1] = path - - # List of two-tuples (path_normal, path_debuglink) - return list(devices.values()) + super(_HidTransport, self).__init__(device, *args, **kwargs) def is_connected(self): """ @@ -63,10 +64,10 @@ class HidTransport(Transport): return False def _open(self): - self.buffer = bytearray() self.hid = hid.device() self.hid.open_path(self.device) self.hid.set_nonblocking(True) + # determine hid_version r = self.hid.write([0, 63, ] + [0xFF] * 63) if r == 65: @@ -80,28 +81,21 @@ class HidTransport(Transport): def _close(self): self.hid.close() - self.buffer = bytearray() self.hid = None - def ready_to_read(self): - return False + def _write_chunk(self, chunk): + if len(chunk) != 64: + raise Exception("Unexpected data length") - def _write(self, msg, protobuf_msg): - msg = bytearray(msg) - while len(msg): - if self.hid_version == 2: - self.hid.write([0, 63, ] + list(msg[:63]) + [0] * (63 - len(msg[:63]))) - else: - self.hid.write([63, ] + list(msg[:63]) + [0] * (63 - len(msg[:63]))) - msg = msg[63:] + if self.hid_version == 2: + self.hid.write([0,] + chunk) + else: + self.hid.write(chunk) - def _read(self): - (msg_type, datalen) = self._read_headers(FakeRead(self._raw_read)) - return (msg_type, self._raw_read(datalen)) - - def _raw_read(self, length): + def _read_chunk(self): start = time.time() - while len(self.buffer) < length: + + while True: data = self.hid.read(64) if not len(data): if time.time() - start > 10: @@ -109,22 +103,37 @@ class HidTransport(Transport): # device is still alive if not self.is_connected(): raise ConnectionError("Connection failed") - else: - # Restart timer - start = time.time() + + # Restart timer + start = time.time() time.sleep(0.001) continue - report_id = data[0] + break - if report_id > 63: - # Command report - raise Exception("Not implemented") + if len(data) != 64: + raise Exception("Unexpected chunk size: %d" % len(data)) - # Payload received, skip the report ID - self.buffer.extend(bytearray(data[1:])) + return bytearray(data) - ret = self.buffer[:length] - self.buffer = self.buffer[length:] - return bytes(ret) +class HidTransportV1(_HidTransport, TransportV1): + pass + +class HidTransportV2(_HidTransport, TransportV2): + pass + +DEVICE_IDS = [ + (0x534c, 0x0001, HidTransportV1), # TREZOR + (0x1209, 0x53C0, HidTransportV2), # TREZORv2 Bootloader + (0x1209, 0x53C1, HidTransportV2), # TREZORv2 +] + +# Backward compatible wrapper, decides for proper transport +# based on VID/PID of given path +def HidTransport(device, *args, **kwargs): + transport = path_to_transport(device[0]) + return transport(device, *args, **kwargs) + +# Backward compatibility hack; HidTransport is a function, not a class like before +HidTransport.enumerate = enumerate diff --git a/trezorlib/transport_udp.py b/trezorlib/transport_udp.py index eaf4b1b..9c7dd5c 100644 --- a/trezorlib/transport_udp.py +++ b/trezorlib/transport_udp.py @@ -43,16 +43,18 @@ class UdpTransport(Transport): rlist, _, _ = select([self.socket], [], [], 0) return len(rlist) > 0 + def _write_chunk(self, chunk): + if len(chunk) != 64: + raise Exception("Unexpected data length") + + self.socket.sendall(chunk) + def _write(self, msg, protobuf_msg): - msg = bytearray(msg) - while len(msg): - # Report ID, data padded to 63 bytes - self.socket.sendall(chr(63) + msg[:63] + b'\0' * (63 - len(msg[:63]))) - msg = msg[63:] + raise NotImplemented() def _read(self): - (msg_type, datalen) = self._read_headers(FakeRead(self._raw_read)) - return (msg_type, self._raw_read(datalen)) + (session_id, msg_type, datalen) = self._read_headers(FakeRead(self._raw_read)) + return (session_id, msg_type, self._raw_read(datalen)) def _raw_read(self, length): start = time.time() @@ -71,14 +73,7 @@ class UdpTransport(Transport): time.sleep(0.001) continue - report_id = ord(data[0]) - - if report_id > 63: - # Command report - raise Exception("Not implemented") - - # Payload received, skip the report ID - self.buffer += str(bytearray(data[1:])) + self.buffer += data ret = self.buffer[:length] self.buffer = self.buffer[length:]