trezor-agent/libagent/tests/test_server.py

136 lines
3.5 KiB
Python
Raw Normal View History

import functools
2016-01-09 14:06:47 +00:00
import io
import os
2015-08-12 17:53:11 +00:00
import socket
2016-01-09 14:06:47 +00:00
import tempfile
import threading
2016-01-09 14:06:47 +00:00
import mock
2016-01-09 14:06:47 +00:00
import pytest
2015-08-12 17:53:11 +00:00
from .. import server, util
from ..ssh import protocol
2015-07-21 11:16:13 +00:00
def test_socket():
    """Check that ``server.unix_domain_socket_server`` leaves no socket behind.

    The socket path must exist while the context is open and must be gone
    once the context exits.
    """
    # mkdtemp() creates a fresh private directory, so the socket path inside
    # it cannot be raced the way a bare tempfile.mktemp() name can.
    tmpdir = tempfile.mkdtemp()
    path = os.path.join(tmpdir, 'agent.sock')
    try:
        with server.unix_domain_socket_server(path):
            assert os.path.exists(path)
        # os.path.isfile() is always False for a socket (it is not a regular
        # file), so only os.path.exists() actually detects a leftover socket.
        assert not os.path.exists(path)
    finally:
        if os.path.exists(path):
            os.remove(path)
        os.rmdir(tmpdir)
class FakeSocket(object):
    """In-memory stand-in for a connected socket.

    Bytes passed to the constructor are what ``recv`` hands back; everything
    given to ``sendall`` accumulates in the ``tx`` buffer for inspection.
    """

    def __init__(self, data=b''):
        self.rx, self.tx = io.BytesIO(data), io.BytesIO()

    def sendall(self, data):
        """Append *data* to the outgoing buffer."""
        self.tx.write(data)

    def recv(self, size):
        """Return up to *size* bytes of canned input (``b''`` once exhausted)."""
        chunk = self.rx.read(size)
        return chunk

    def close(self):
        """No-op: there is nothing to release."""

    def settimeout(self, value):
        """No-op: timeouts have no meaning for in-memory buffers."""
def empty_device():
    """Return a mock device whose only attribute is ``parse_public_keys``.

    Calling ``parse_public_keys`` on the mock yields an empty list, i.e. a
    device with no public keys.
    """
    # Dotted keyword names are forwarded to configure_mock(), which sets the
    # child mock's return value in one step.
    return mock.Mock(spec=['parse_public_keys'],
                     **{'parse_public_keys.return_value': []})
def test_handle():
    """Drive ``server.handle_connection`` with canned SSH-agent requests."""
    mutex = threading.Lock()
    handler = protocol.Handler(conn=empty_device())

    def serve(data=b''):
        # Run one fake connection whose peer sends *data*, and return it so
        # the caller can inspect what the server wrote back.
        sock = FakeSocket(data)
        server.handle_connection(sock, handler, mutex)
        return sock

    def request(name):
        # Send a single framed message consisting only of the named code.
        return serve(util.frame(bytearray([protocol.msg_code(name)])))

    # A peer that sends nothing (recv returns b'' immediately) must not raise.
    serve()

    # Each reply is a 4-byte length (5), the answer code, then four zero
    # bytes -- presumably an empty identity count, since the device has no keys.
    sock = request('SSH_AGENTC_REQUEST_RSA_IDENTITIES')
    assert sock.tx.getvalue() == b'\x00\x00\x00\x05\x02\x00\x00\x00\x00'

    sock = request('SSH2_AGENTC_REQUEST_IDENTITIES')
    assert sock.tx.getvalue() == b'\x00\x00\x00\x05\x0C\x00\x00\x00\x00'

    # An add-identity request is answered with SSH_AGENT_FAILURE.
    sock = request('SSH2_AGENTC_ADD_IDENTITY')
    sock.tx.seek(0)
    assert util.read_frame(sock.tx) == util.pack(
        'B', protocol.msg_code('SSH_AGENT_FAILURE'))

    # recv() raises a generic Exception first (then EOFError if called again).
    # handle_connection must not propagate the error to the caller.
    broken = mock.Mock(spec=FakeSocket)
    broken.recv.side_effect = [Exception, EOFError]
    server.handle_connection(conn=broken, handler=None, mutex=mutex)
def test_server_thread():
    """Run ``server.server_thread`` against a fake listening socket.

    The fake ``accept`` hands out one queued connection. After that it sets
    the quit event and raises ``socket.timeout``, which is presumably what
    lets ``server_thread`` return.
    """
    connections = [FakeSocket()]
    quit_event = threading.Event()

    class FakeServer(object):
        def accept(self):  # pylint: disable=no-self-use
            if connections:
                return connections.pop(), 'address'
            quit_event.set()
            raise socket.timeout()

        def getsockname(self):  # pylint: disable=no-self-use
            return 'fake_server'

    # No trailing comma here: it would silently make ``handler`` a 1-tuple
    # instead of a protocol.Handler.
    handler = protocol.Handler(conn=empty_device())
    handle_conn = functools.partial(server.handle_connection,
                                    handler=handler,
                                    mutex=None)

    server.server_thread(sock=FakeServer(),
                         handle_conn=handle_conn,
                         quit_event=quit_event)
    # The queued connection was accepted and the stop path was reached.
    assert not connections
    assert quit_event.is_set()
def test_spawn():
    """Check that ``server.spawn`` calls the target with the given kwargs.

    The call must have happened by the time the context exits.
    """
    seen = []

    def record(x):
        seen.append(x)

    with server.spawn(record, {'x': 1}):
        pass
    assert seen == [1]
def test_run():
    """Check ``server.run_process`` exit codes and its failure mode.

    It must return the child's exit status, and raise OSError when the
    command cannot be executed.
    """
    for argv, status in ((['true'], 0), (['false'], 1)):
        assert server.run_process(argv, environ={}) == status

    # The child sees *environ*: $X must expand to 42 inside the shell.
    exit_code = server.run_process(command=['bash', '-c', 'exit $X'],
                                   environ={'X': '42'})
    assert exit_code == 42

    # An empty program name cannot be executed at all.
    with pytest.raises(OSError):
        server.run_process([''], environ={})
def test_remove():
    """Exercise ``server.remove_file`` error handling via injected callables."""
    target = 'foo.bar'

    def fake_remove(p):
        assert p == target

    server.remove_file(target, remove=fake_remove)

    def failing_remove(_):
        raise OSError('boom')

    # A removal error is tolerated when the file no longer exists...
    server.remove_file(target, remove=failing_remove, exists=lambda _: False)
    # ...but propagated when the file is still there.
    with pytest.raises(OSError):
        server.remove_file(target, remove=failing_remove,
                           exists=lambda _: True)