119 lines
2.6 KiB
Python
119 lines
2.6 KiB
Python
import tempfile
|
|
import socket
|
|
import os
|
|
import io
|
|
import pytest
|
|
|
|
from .. import server
|
|
from .. import protocol
|
|
from .. import util
|
|
|
|
|
|
def test_socket():
    """Check that the Unix-domain socket server leaves no socket path behind.

    The socket is placed inside a private temporary directory instead of at a
    name produced by ``tempfile.mktemp()``, which is deprecated because another
    process can claim the name between its generation and its use.
    """
    tmp_dir = tempfile.mkdtemp()
    path = os.path.join(tmp_dir, 'agent.sock')
    try:
        with server.unix_domain_socket_server(path):
            pass
        # A socket node is not a regular file, so os.path.isfile() is False
        # even when the path was left behind; os.path.exists() is the check
        # that can actually fail.
        assert not os.path.exists(path)
    finally:
        # Best-effort cleanup so a failing assertion does not leak the
        # temporary directory.
        if os.path.exists(path):
            os.remove(path)
        os.rmdir(tmp_dir)
|
|
|
|
|
|
class SocketMock(object):
    """In-memory stand-in for a connected socket.

    Bytes given to the constructor are served back by ``recv``; everything
    handed to ``sendall`` accumulates in ``tx`` for later inspection.
    """

    def __init__(self, data=b''):
        # tx collects outgoing bytes, rx replays the incoming ones.
        self.tx = io.BytesIO()
        self.rx = io.BytesIO(data)

    def sendall(self, data):
        """Append ``data`` to the transmit buffer."""
        self.tx.write(data)

    def recv(self, size):
        """Read up to ``size`` bytes from the receive buffer."""
        chunk = self.rx.read(size)
        return chunk

    def close(self):
        """Do nothing; no underlying resource is held."""
        return None
|
|
|
|
|
|
def test_handle():
    """Drive ``server.handle_connection`` with an empty stream and two requests."""
    agent = protocol.Handler(keys=[], signer=None)

    # A connection with nothing to read is handled without raising.
    server.handle_connection(SocketMock(), agent)

    # (request code, exact bytes expected in the transmit buffer afterwards)
    cases = (
        (protocol.SSH_AGENTC_REQUEST_RSA_IDENTITIES,
         b'\x00\x00\x00\x05\x02\x00\x00\x00\x00'),
        (protocol.SSH2_AGENTC_REQUEST_IDENTITIES,
         b'\x00\x00\x00\x05\x0C\x00\x00\x00\x00'),
    )
    for code, reply in cases:
        sock = SocketMock(util.frame(bytearray([code])))
        server.handle_connection(sock, agent)
        assert sock.tx.getvalue() == reply

    # Passing no connection and no handler surfaces as AttributeError.
    with pytest.raises(AttributeError):
        server.handle_connection(conn=None, handler=None)
|
|
|
|
|
|
class ServerMock(object):
    """Fake listening socket that hands out a fixed list of connections."""

    def __init__(self, connections, name):
        self.name = name
        self.connections = connections

    def getsockname(self):
        """Return the name supplied at construction."""
        return self.name

    def accept(self):
        """Pop the next queued connection, or raise once none remain.

        Returns a ``(connection, 'address')`` pair. Raises ``socket.error``
        with the message ``'stop'`` when the queue is empty.
        """
        if not self.connections:
            raise socket.error('stop')
        pending = self.connections.pop()
        return pending, 'address'
|
|
|
|
|
|
def test_server_thread():
    """Run ``server.server_thread`` against a mock server with one connection."""
    mock_server = ServerMock(connections=[SocketMock()], name='mock')
    handler = protocol.Handler(keys=[], signer=None)
    # ServerMock.accept raises socket.error('stop') once its single
    # connection has been handed out.
    server.server_thread(mock_server, handler)
|
|
|
|
|
|
def test_spawn():
    """Check that ``server.spawn`` has run the target by the time it exits."""
    seen = []

    def thread(x):
        # Record the keyword argument forwarded by spawn().
        seen.append(x)

    spawned = server.spawn(thread, x=1)
    with spawned:
        pass

    assert seen == [1]
|
|
|
|
|
|
def test_run():
    """Check the exit statuses returned by ``server.run_process``."""
    ok_status = server.run_process(['true'], environ={})
    assert ok_status == 0

    fail_status = server.run_process(['false'], environ={})
    assert fail_status == 1

    # With use_shell=True the command is a string and sees the given environ.
    shell_status = server.run_process(
        command='exit $X', environ={'X': '42'}, use_shell=True)
    assert shell_status == 42

    # An empty program name cannot be executed.
    with pytest.raises(OSError):
        server.run_process([''], environ={})
|
|
|
|
|
|
def test_serve_main():
    """Enter and leave ``server.serve`` with no keys, signer or socket path."""
    serving = server.serve(
        public_keys=[],
        signer=None,
        sock_path=None,
    )
    with serving:
        pass
|
|
|
|
|
|
def test_remove():
    """Exercise ``server.remove_file`` with injected remove/exists callables."""
    path = 'foo.bar'

    def check_path(p):
        assert p == path

    def failing_remove(_):
        raise OSError('boom')

    def absent(_):
        return False

    def present(_):
        return True

    server.remove_file(path, remove=check_path)

    # An OSError from remove is tolerated when the path is reported missing...
    server.remove_file(path, remove=failing_remove, exists=absent)

    # ...but propagates when the path is reported as still existing.
    with pytest.raises(OSError):
        server.remove_file(path, remove=failing_remove, exists=present)
|