trezor-agent/libagent/util.py

216 lines
5.3 KiB
Python
Raw Normal View History

2016-02-19 09:34:20 +00:00
"""Various I/O and serialization utilities."""
import binascii
import contextlib
2017-02-21 11:05:48 +00:00
import functools
2015-06-15 15:13:10 +00:00
import io
import logging
2016-01-09 14:06:47 +00:00
import struct
2015-06-15 15:13:10 +00:00
# Module-wide logger, named after this module.
log = logging.getLogger(name=__name__)


def send(conn, data):
    """Write the whole `data` blob to `conn` using its `sendall` method."""
    sendall = conn.sendall
    sendall(data)


def recv(conn, size):
    """
    Receive bytes from connection socket or stream.

    If size is struct.calcsize()-compatible format, use it to unpack the data.
    Otherwise, return the plain blob as bytes.

    :param conn: object with a ``recv(n)`` method (e.g. a socket) or, if it
        has none, a ``read(n)`` method (e.g. a file-like stream).
    :param size: a byte count, or a struct format string.
    :returns: the tuple from ``struct.unpack`` when `size` is a format,
        otherwise exactly `size` bytes.
    :raises EOFError: if `conn` runs out of data before enough bytes arrive.
    """
    try:
        fmt = size
        size = struct.calcsize(fmt)
    except TypeError:
        # `size` is not a format string (e.g. an int byte count).
        fmt = None
    try:
        _read = conn.recv
    except AttributeError:
        _read = conn.read

    chunks = []
    while size > 0:
        # recv()/read() may return fewer bytes than requested, so keep going.
        buf = _read(size)
        if not buf:
            raise EOFError
        size -= len(buf)
        chunks.append(buf)
    res = b''.join(chunks)

    # Compare with None explicitly: an empty format ('') is falsy but is
    # still a format, so it must unpack to an empty tuple, not return b''.
    if fmt is not None:
        return struct.unpack(fmt, res)
    return res


def read_frame(conn):
    """Read one frame: a 4-byte big-endian length, then that many bytes."""
    (length,) = recv(conn, '>L')
    payload = recv(conn, length)
    return payload


def bytes2num(s):
    """Interpret `s` as a big-endian (MSB-first) unsigned integer."""
    value = 0
    for octet in bytearray(s):
        # Shift what we have so far up one byte and append the next octet.
        value = (value << 8) | octet
    return value


def num2bytes(value, size):
    """
    Convert an unsigned integer to MSB-first bytes with specified size.

    :param value: non-negative integer to encode.
    :param size: number of output bytes.
    :returns: `size` bytes, most significant byte first.
    :raises ValueError: if `value` is negative or does not fit in `size` bytes.
    """
    octets = []
    remaining = value
    for _ in range(size):
        octets.append(remaining & 0xFF)
        remaining >>= 8
    # Explicit check instead of `assert`, which is stripped under `python -O`
    # and would then let oversized values be silently truncated. A negative
    # value never shifts down to 0, so it is rejected here as well.
    if remaining != 0:
        raise ValueError(
            '{!r} does not fit in {} unsigned bytes'.format(value, size))
    return bytes(bytearray(reversed(octets)))


def pack(fmt, *args):
    """Serialize `args` with struct format `fmt` in big-endian byte order."""
    byte_order = '>'  # big-endian, standard sizes, no alignment padding
    return struct.pack(byte_order + fmt, *args)


def frame(*msgs):
    """Concatenate `msgs` and prefix the result with its 4-byte big-endian length."""
    payload = b''.join(msgs)
    header = pack('L', len(payload))
    return header + payload


def crc24(blob):
    """
    Compute the CRC-24 checksum of `blob` as 3 big-endian bytes.

    See https://tools.ietf.org/html/rfc4880#section-6.1 for details.
    """
    initial = 0x0B704CE
    generator = 0x1864CFB  # polynomial including the x^24 term

    checksum = initial
    for octet in bytearray(blob):
        checksum ^= octet << 16
        for _ in range(8):
            # Shifting left sets bit 24 exactly when bit 23 is currently set;
            # in that case reduce by the generator polynomial.
            checksum = (checksum << 1) ^ (generator if checksum & 0x800000 else 0)
    assert 0 <= checksum < 0x1000000
    packed = struct.pack('>L', checksum)
    assert packed[:1] == b'\x00'
    return packed[1:]


def bit(value, i):
    """Return bit number `i` (0 = least significant) of `value`, as 0 or 1."""
    return 1 if (value >> i) & 1 else 0


def low_bits(value, n):
    """Return `value` masked down to its `n` least significant bits."""
    mask = ~(-1 << n)  # n one-bits, equal to (1 << n) - 1
    return value & mask


def split_bits(value, *bits):
    """
    Split integer value into list of ints, according to `bits` list.

    For example, split_bits(0x1234, 4, 8, 4) == [0x1, 0x23, 0x4]

    The first width in `bits` describes the most significant field and the
    last width the least significant one.

    :raises ValueError: if `value` does not fit in ``sum(bits)`` bits
        (negative values never fit).
    """
    result = []
    remaining = value
    for width in reversed(bits):
        mask = (1 << width) - 1
        result.append(remaining & mask)
        remaining >>= width
    # Explicit check instead of `assert`, which `python -O` strips. Without
    # the check, leftover high bits would be silently dropped.
    if remaining != 0:
        raise ValueError(
            '{!r} does not fit in {} bits'.format(value, sum(bits)))
    result.reverse()
    return result


def readfmt(stream, fmt):
    """Read ``struct.calcsize(fmt)`` bytes from `stream` and unpack them with `fmt`."""
    blob = stream.read(struct.calcsize(fmt))
    return struct.unpack(fmt, blob)


def prefix_len(fmt, blob):
    """Return `blob` preceded by its length, packed with struct format `fmt`."""
    header = struct.pack(fmt, len(blob))
    return header + blob


def hexlify(blob):
    """Return `blob` as an upper-case hexadecimal string with no separators."""
    hex_text = binascii.hexlify(blob).decode('ascii')
    return hex_text.upper()


class Reader(object):
    """Read basic type objects out of given stream."""

    def __init__(self, stream):
        """Create a non-capturing reader."""
        self.s = stream
        # Streams of all currently active capture() contexts, outermost first.
        self._captures = []

    def readfmt(self, fmt):
        """Read a single value, using a struct format string."""
        size = struct.calcsize(fmt)
        blob = self.read(size)
        obj, = struct.unpack(fmt, blob)
        return obj

    def read(self, size=None):
        """
        Read `size` bytes from stream (all remaining bytes if `size` is None).

        Every byte read is also written to each active capture stream.

        :raises EOFError: if fewer than `size` bytes are available.
        """
        blob = self.s.read(size)
        if size is not None and len(blob) < size:
            raise EOFError
        for captured in self._captures:
            captured.write(blob)
        return blob

    @contextlib.contextmanager
    def capture(self, stream):
        """
        Capture all data read during this context.

        Contexts may be nested: an outer capture keeps receiving data while an
        inner one is active, and stays active after the inner one exits.
        """
        self._captures.append(stream)
        try:
            yield
        finally:
            self._captures.remove(stream)


def setup_logging(verbosity, filename=None):
    """
    Configure logging for this tool.

    :param verbosity: 0 (or less) selects WARNING, 1 selects INFO and
        2 or more selects DEBUG.
    :param filename: if given, records are also appended to this file.

    Each call adds new handlers to the root logger (stderr, plus the file
    if requested).
    """
    levels = [logging.WARNING, logging.INFO, logging.DEBUG]
    # Clamp at both ends: without the lower bound a negative verbosity
    # would index from the end of `levels` and silently enable DEBUG.
    index = max(0, min(verbosity, len(levels) - 1))
    level = levels[index]

    logging.root.setLevel(level)
    fmt = logging.Formatter('%(asctime)s %(levelname)-12s %(message)-100s '
                            '[%(filename)s:%(lineno)d]')
    hdlr = logging.StreamHandler()  # stderr
    hdlr.setFormatter(fmt)
    logging.root.addHandler(hdlr)
    if filename:
        hdlr = logging.FileHandler(filename, 'a')
        hdlr.setFormatter(fmt)
        logging.root.addHandler(hdlr)


def memoize(func):
    """
    Cache `func` results, keyed by its positional and keyword arguments.

    All arguments must be hashable. The cache is never evicted, so it
    grows with every distinct argument combination.
    """
    cache = {}

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        """Caching wrapper."""
        # Sort kwargs so f(a=1, b=2) and f(b=2, a=1) share one entry.
        key = (args, tuple(sorted(kwargs.items())))
        try:
            return cache[key]
        except KeyError:
            pass  # first call with these arguments
        result = cache[key] = func(*args, **kwargs)
        return result

    return wrapper