trezor-agent/trezor_agent/server.py

155 lines
4.1 KiB
Python
Raw Normal View History

2016-02-19 09:19:01 +00:00
"""UNIX-domain socket server for ssh-agent implementation."""
2016-01-09 14:06:47 +00:00
import contextlib
import functools
2016-01-09 14:06:47 +00:00
import logging
2015-06-06 14:52:10 +00:00
import os
2016-01-09 14:06:47 +00:00
import socket
2015-06-06 14:52:10 +00:00
import subprocess
import tempfile
import threading
from . import util
2015-06-06 14:52:10 +00:00
2015-06-16 07:20:11 +00:00
# Module-level logger for the agent server.
log = logging.getLogger(__name__)
# Default timeout (seconds) applied to the listening socket by serve().
# server_thread() retries accept() on socket.timeout and checks its quit
# event between attempts, so this bounds how quickly the server notices
# that it has been asked to stop.
UNIX_SOCKET_TIMEOUT = 0.1
2015-06-16 07:03:48 +00:00
2015-08-13 15:03:25 +00:00
def remove_file(path, remove=os.remove, exists=os.path.exists):
    """
    Delete `path`, tolerating failures that leave no file behind.

    `remove` and `exists` default to os.remove / os.path.exists and are
    injectable. If `remove(path)` raises OSError, the error is swallowed
    when `exists(path)` reports the path is gone (e.g. it never existed),
    and re-raised when the path is still present.
    """
    try:
        remove(path)
    except OSError:
        if not exists(path):
            return
        raise
@contextlib.contextmanager
def unix_domain_socket_server(sock_path):
    """
    Create UNIX-domain socket on specified path.

    Listen on it, and delete it after the generated context is over.
    The listening socket is always closed on exit, including when bind()
    or listen() fails or when the body of the context raises. The socket
    file is removed only once bind() has succeeded, so a failed bind never
    deletes a path this function did not create.

    :param sock_path: filesystem path for the AF_UNIX socket; any existing
        file there is removed first.
    :yields: the bound, listening socket object.
    """
    log.debug('serving on SSH_AUTH_SOCK=%s', sock_path)
    remove_file(sock_path)

    server = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
    # closing() releases the file descriptor even on bind/listen failure.
    with contextlib.closing(server):
        server.bind(sock_path)
        try:
            server.listen(1)
            yield server
        finally:
            # bind() created the file, so clean it up on every exit path.
            remove_file(sock_path)
def handle_connection(conn, handler):
    """
    Serve one client connection until it ends.

    Repeatedly reads a frame with util.read_frame(conn), passes it to
    handler.handle(msg=...), and writes the reply back with util.send().
    An EOFError raised anywhere in that loop ends the session quietly;
    any other exception is logged as a warning (with traceback) and also
    ends the session. Nothing is re-raised.
    """
    try:
        log.debug('welcome agent')
        while True:
            request = util.read_frame(conn)
            util.send(conn, handler.handle(msg=request))
    except EOFError:
        log.debug('goodbye agent')
    except Exception as e:  # pylint: disable=broad-except
        log.warning('error: %s', e, exc_info=True)
def retry(func, exception_type, quit_event):
    """
    Call `func` until it returns without raising `exception_type`.

    `quit_event` is checked before every attempt, including the first.
    Once it is set, StopIteration is raised instead of calling `func`
    again. Exceptions other than `exception_type` propagate unchanged.

    :returns: the value of the first call to `func` that did not raise.
    """
    while not quit_event.is_set():
        try:
            return func()
        except exception_type:
            continue
    raise StopIteration
def server_thread(sock, handle_conn, quit_event):
    """
    Accept and serve connections on `sock`, one at a time.

    accept() is retried on socket.timeout via retry(), so `quit_event` is
    polled between attempts. When it is set, the loop ends. Each accepted
    connection is switched to blocking mode, passed to `handle_conn`, and
    closed once `handle_conn` returns.
    """
    log.debug('server thread started')

    def _accept_blocking():
        # The listening socket may carry a timeout (serve() sets one);
        # the per-client socket should block without a timeout.
        client, _peer = sock.accept()
        client.settimeout(None)
        return client

    while True:
        log.debug('waiting for connection on %s', sock.getsockname())
        try:
            client = retry(_accept_blocking, socket.timeout, quit_event)
        except StopIteration:
            log.debug('server stopped')
            break
        with contextlib.closing(client):
            handle_conn(client)

    log.debug('server thread stopped')
@contextlib.contextmanager
def spawn(func, kwargs):
    """
    Spawn a thread, and join it after the context is over.

    The thread runs func(**kwargs). It is joined even when the body of the
    `with` block raises, so no thread is left unjoined on error paths. The
    caller must ensure `func` will return (e.g. by signalling it to stop)
    before the context exits, otherwise the join blocks.
    """
    t = threading.Thread(target=func, kwargs=kwargs)
    t.start()
    try:
        yield
    finally:
        # Without try/finally an exception thrown into the generator at
        # `yield` would skip the join entirely.
        t.join()
@contextlib.contextmanager
def serve(handler, sock_path=None, timeout=UNIX_SOCKET_TIMEOUT):
    """
    Start the ssh-agent server on a UNIX-domain socket.

    If no connection is made during the specified timeout,
    retry until the context is over.

    :param handler: object whose ``handle(msg=...)`` method produces the
        reply for each request (see handle_connection).
    :param sock_path: socket path. When None, the socket is created inside
        a freshly made private temporary directory, which is removed again
        when the context exits.
    :param timeout: accept() timeout in seconds, i.e. how often the server
        thread re-checks whether it should stop.
    :yields: dict with SSH_AUTH_SOCK and SSH_AGENT_PID for child processes.
    """
    temp_dir = None
    if sock_path is None:
        # tempfile.mktemp() only picks an unused name (racy, deprecated).
        # mkdtemp() creates the directory atomically with mode 0700, so
        # other users cannot reach the agent socket placed inside it.
        temp_dir = tempfile.mkdtemp(prefix='ssh-agent-')
        sock_path = os.path.join(temp_dir, 'agent.sock')
    environ = {'SSH_AUTH_SOCK': sock_path, 'SSH_AGENT_PID': str(os.getpid())}
    try:
        with unix_domain_socket_server(sock_path) as sock:
            sock.settimeout(timeout)
            quit_event = threading.Event()
            handle_conn = functools.partial(handle_connection, handler=handler)
            kwargs = dict(sock=sock,
                          handle_conn=handle_conn,
                          quit_event=quit_event)
            with spawn(server_thread, kwargs):
                try:
                    yield environ
                finally:
                    log.debug('closing server')
                    quit_event.set()
    finally:
        if temp_dir is not None:
            # The socket file itself was removed by unix_domain_socket_server.
            try:
                os.rmdir(temp_dir)
            except OSError as e:
                log.warning('cannot remove %s: %s', temp_dir, e)
def run_process(command, environ):
    """
    Run `command` as a child process and block until it exits.

    The child inherits a copy of the current environment, overlaid with
    the entries of `environ`.

    :returns: the child's exit status as reported by Popen.wait().
    :raises OSError: if the process cannot be started; the message names
        the command.
    """
    log.info('running %r with %r', command, environ)
    child_env = os.environ.copy()
    child_env.update(environ)
    try:
        proc = subprocess.Popen(args=command, env=child_env)
    except OSError as e:
        raise OSError('cannot run %r: %s' % (command, e))
    log.debug('subprocess %d is running', proc.pid)
    exit_code = proc.wait()
    log.debug('subprocess %d exited: %d', proc.pid, exit_code)
    return exit_code