From e254565901bd443b25ad81d16a752f9594eb1cea Mon Sep 17 00:00:00 2001 From: Daniel Pavel Date: Mon, 10 Jun 2019 20:19:27 +0300 Subject: [PATCH] support binding the http server to a unix socket file instead of TCP socket --- cps.py | 5 +- cps/__init__.py | 6 +- cps/about.py | 5 +- cps/admin.py | 12 +-- cps/logger.py | 61 ++++++++------- cps/server.py | 201 +++++++++++++++++++++++++----------------------- cps/ub.py | 34 ++++---- cps/updater.py | 5 +- 8 files changed, 165 insertions(+), 164 deletions(-) diff --git a/cps.py b/cps.py index c5ac1e7d..16d84957 100755 --- a/cps.py +++ b/cps.py @@ -28,8 +28,8 @@ sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'vendor from cps import create_app +from cps import web_server from cps.opds import opds -from cps import Server from cps.web import web from cps.jinjia import jinjia from cps.about import about @@ -56,7 +56,8 @@ def main(): app.register_blueprint(editbook) if oauth_available: app.register_blueprint(oauth) - Server.startServer() + success = web_server.start() + sys.exit(0 if success else 1) if __name__ == '__main__': diff --git a/cps/__init__.py b/cps/__init__.py index 1d3d598d..50bd1781 100755 --- a/cps/__init__.py +++ b/cps/__init__.py @@ -84,8 +84,8 @@ searched_ids = {} from .worker import WorkerThread global_WorkerThread = WorkerThread() -from .server import server -Server = server() +from .server import WebServer +web_server = WebServer() from .ldap import Ldap ldap = Ldap() @@ -103,7 +103,7 @@ def create_app(): Principal(app) lm.init_app(app) app.secret_key = os.getenv('SECRET_KEY', 'A0Zr98j/3yX R~XHH!jmN]LWX/,?RT') - Server.init_app(app) + web_server.init_app(app, config) db.setup_db() babel.init_app(app) ldap.init_app(app) diff --git a/cps/about.py b/cps/about.py index bc7b0e8a..162a4191 100644 --- a/cps/about.py +++ b/cps/about.py @@ -41,8 +41,9 @@ from jinja2 import __version__ as jinja2Version from pytz import __version__ as pytzVersion from sqlalchemy import __version__ as sqlalchemyVersion -from . import db, converter, Server, uploader +from . import db, converter, uploader from .isoLanguages import __version__ as iso639Version +from .server import VERSION as serverVersion from .web import render_title_template @@ -71,7 +72,7 @@ def stats(): versions['pySqlite'] = 'v' + db.engine.dialect.dbapi.version versions['Sqlite'] = 'v' + db.engine.dialect.dbapi.sqlite_version versions.update(converter.versioncheck()) - versions.update(Server.getNameVersion()) + versions.update(serverVersion) versions['Python'] = sys.version return render_title_template('stats.html', bookcounter=counter, authorcounter=authors, versions=versions, categorycounter=categorys, seriecounter=series, title=_(u"Statistics"), page="stat") diff --git a/cps/admin.py b/cps/admin.py index f6fd838f..a4473fdd 100644 --- a/cps/admin.py +++ b/cps/admin.py @@ -41,7 +41,7 @@ from sqlalchemy.exc import IntegrityError from werkzeug.security import generate_password_hash from . import constants, logger, ldap -from . import db, ub, Server, get_locale, config, updater_thread, babel, gdriveutils +from . import db, ub, web_server, get_locale, config, updater_thread, babel, gdriveutils from .helper import speaking_language, check_valid_domain, check_unrar, send_test_mail, generate_random_password, \ send_registration_mail from .gdriveutils import is_gdrive_ready, gdrive_support, downloadFile, deleteDatabaseOnChange, listRootFolders @@ -102,12 +102,10 @@ def shutdown(): showtext = {} if task == 0: showtext['text'] = _(u'Server restarted, please reload page') - Server.setRestartTyp(True) else: showtext['text'] = _(u'Performing shutdown of server, please close window') - Server.setRestartTyp(False) # stop gevent/tornado server - Server.stopServer() + web_server.stop(task == 0) return json.dumps(showtext) else: if task == 2: @@ -220,8 +218,7 @@ def view_configuration(): # ub.session.close() # ub.engine.dispose() # stop Server - Server.setRestartTyp(True) - Server.stopServer() + web_server.stop(True) log.info('Reboot required, restarting') readColumn = db.session.query(db.Custom_Columns)\ .filter(and_(db.Custom_Columns.datatype == 'bool',db.Custom_Columns.mark_for_delete == 0)).all() @@ -554,8 +551,7 @@ def configuration_helper(origin): title=_(u"Basic Configuration"), page="config") if reboot_required: # stop Server - Server.setRestartTyp(True) - Server.stopServer() + web_server.stop(True) log.info('Reboot required, restarting') if origin: success = True diff --git a/cps/logger.py b/cps/logger.py index c249cad7..408a9c02 100644 --- a/cps/logger.py +++ b/cps/logger.py @@ -25,6 +25,7 @@ from logging.handlers import RotatingFileHandler from .constants import BASE_DIR as _BASE_DIR + ACCESS_FORMATTER_GEVENT = Formatter("%(message)s") ACCESS_FORMATTER_TORNADO = Formatter("[%(asctime)s] %(message)s") @@ -33,7 +34,6 @@ DEFAULT_LOG_LEVEL = logging.INFO DEFAULT_LOG_FILE = os.path.join(_BASE_DIR, "calibre-web.log") DEFAULT_ACCESS_LOG = os.path.join(_BASE_DIR, "access.log") LOG_TO_STDERR = '/dev/stderr' -DEFAULT_ACCESS_LEVEL= logging.INFO logging.addLevelName(logging.WARNING, "WARN") logging.addLevelName(logging.CRITICAL, "CRIT") @@ -73,35 +73,26 @@ def is_valid_logfile(file_path): return (not log_dir) or os.path.isdir(log_dir) -def setup(log_file, log_level=None, logger=None): - if logger != "access" and logger != "tornado.access": - formatter = FORMATTER - default_file = DEFAULT_LOG_FILE - else: - if logger == "tornado.access": - formatter = ACCESS_FORMATTER_TORNADO - else: - formatter = ACCESS_FORMATTER_GEVENT - default_file = DEFAULT_ACCESS_LOG +def _absolute_log_file(log_file, default_log_file): if log_file: if not os.path.dirname(log_file): log_file = os.path.join(_BASE_DIR, log_file) - log_file = os.path.abspath(log_file) - else: - log_file = LOG_TO_STDERR - # log_file = default_file + return os.path.abspath(log_file) - # print ('%r -- %r' % (log_level, log_file)) - if logger != "access" and logger != "tornado.access": - r = logging.root - else: - r = logging.getLogger(logger) - r.propagate = False + return default_log_file + + +def setup(log_file, log_level=None): + ''' + Configure the logging output. + May be called multiple times. + ''' + log_file = _absolute_log_file(log_file, DEFAULT_LOG_FILE) + + r = logging.root r.setLevel(log_level or DEFAULT_LOG_LEVEL) previous_handler = r.handlers[0] if r.handlers else None - # print ('previous %r' % previous_handler) - if previous_handler: # if the log_file has not changed, don't create a new handler if getattr(previous_handler, 'baseFilename', None) == log_file: @@ -115,16 +106,32 @@ def setup(log_file, log_level=None, logger=None): try: file_handler = RotatingFileHandler(log_file, maxBytes=50000, backupCount=2) except IOError: - if log_file == default_file: + if log_file == DEFAULT_LOG_FILE: raise - file_handler = RotatingFileHandler(default_file, maxBytes=50000, backupCount=2) - file_handler.setFormatter(formatter) + file_handler = RotatingFileHandler(DEFAULT_LOG_FILE, maxBytes=50000, backupCount=2) + file_handler.setFormatter(FORMATTER) for h in r.handlers: r.removeHandler(h) h.close() r.addHandler(file_handler) - # print ('new handler %r' % file_handler) + + +def create_access_log(log_file, log_name, formatter): + ''' + One-time configuration for the web server's access log. + ''' + log_file = _absolute_log_file(log_file, DEFAULT_ACCESS_LOG) + logging.debug("access log: %s", log_file) + + access_log = logging.getLogger(log_name) + access_log.propagate = False + access_log.setLevel(logging.INFO) + + file_handler = RotatingFileHandler(log_file, maxBytes=50000, backupCount=2) + file_handler.setFormatter(formatter) + access_log.addHandler(file_handler) + return access_log # Enable logging of smtp lib debug output diff --git a/cps/server.py b/cps/server.py index 88c54c78..16b20549 100644 --- a/cps/server.py +++ b/cps/server.py @@ -20,54 +20,55 @@ from __future__ import division, print_function, unicode_literals import sys import os +import errno import signal import socket -import logging try: from gevent.pywsgi import WSGIServer from gevent.pool import Pool - from gevent import __version__ as geventVersion - gevent_present = True + from gevent import __version__ as _version + VERSION = {'Gevent': 'v' + _version} + _GEVENT = True except ImportError: from tornado.wsgi import WSGIContainer from tornado.httpserver import HTTPServer from tornado.ioloop import IOLoop - from tornado import version as tornadoVersion - from tornado import log as tornadoLog - from tornado import options as tornadoOptions - gevent_present = False + from tornado import version as _version + VERSION = {'Tornado': 'v' + _version} + _GEVENT = False -from . import logger, config, global_WorkerThread +from . import logger, global_WorkerThread log = logger.create() -class server: - - wsgiserver = None - restart = False - app = None - access_logger = None +class WebServer: def __init__(self): - signal.signal(signal.SIGINT, self.killServer) - signal.signal(signal.SIGTERM, self.killServer) + signal.signal(signal.SIGINT, self._killServer) + signal.signal(signal.SIGTERM, self._killServer) - def init_app(self, application): - self.app = application - self.port = config.config_port - self.listening = config.get_config_ipaddress(readable=True) + ":" + str(self.port) + self.wsgiserver = None self.access_logger = None + self.restart = False + self.app = None + self.listen_address = None + self.listen_port = None + self.unix_socket_file = None + self.ssl_args = None + + def init_app(self, application, config): + self.app = application + self.listen_address = config.get_config_ipaddress() + self.listen_port = config.config_port + if config.config_access_log: - if gevent_present: - logger.setup(config.config_access_logfile, logger.DEFAULT_ACCESS_LEVEL, "access") - self.access_logger = logging.getLogger("access") - else: - logger.setup(config.config_access_logfile, logger.DEFAULT_ACCESS_LEVEL, "tornado.access") + log_name = "gevent.access" if _GEVENT else "tornado.access" + formatter = logger.ACCESS_FORMATTER_GEVENT if _GEVENT else logger.ACCESS_FORMATTER_TORNADO + self.access_logger = logger.create_access_log(config.config_access_logfile, log_name, formatter) - self.ssl_args = None certfile_path = config.get_config_certfile() keyfile_path = config.get_config_keyfile() if certfile_path and keyfile_path: @@ -79,22 +80,46 @@ class server: log.warning('Cert path: %s', certfile_path) log.warning('Key path: %s', keyfile_path) + def _make_gevent_unix_socket(self, socket_file): + # the socket file must not exist prior to bind() + if os.path.exists(socket_file): + # avoid nuking regular files and symbolic links (could be a mistype or security issue) + if os.path.isfile(socket_file) or os.path.islink(socket_file): + raise OSError(errno.EEXIST, os.strerror(errno.EEXIST), socket_file) + os.remove(socket_file) + + unix_sock = WSGIServer.get_listener(socket_file, family=socket.AF_UNIX) + self.unix_socket_file = socket_file + + # ensure current user and group have r/w permissions, no permissions for other users + # this way the socket can be shared in a semi-secure manner + # between the user running calibre-web and the user running the fronting webserver + os.chmod(socket_file, 0o660); + + return unix_sock + def _make_gevent_socket(self): - if config.get_config_ipaddress(): - return (config.get_config_ipaddress(), self.port) + if os.name != 'nt': + unix_socket_file = os.environ.get("CALIBRE_UNIX_SOCKET") + if unix_socket_file: + return self._make_gevent_unix_socket(unix_socket_file) + + if self.listen_address: + return (self.listen_address, self.listen_port) + if os.name == 'nt': - return ('0.0.0.0', self.port) + return ('0.0.0.0', self.listen_port) + address = ('', self.listen_port) try: - s = WSGIServer.get_listener(('', self.port), family=socket.AF_INET6) + sock = WSGIServer.get_listener(address, family=socket.AF_INET6) except socket.error as ex: log.error('%s', ex) log.warning('Unable to listen on \'\', trying on IPv4 only...') - s = WSGIServer.get_listener(('', self.port), family=socket.AF_INET) - log.debug("%r %r", s._sock, s._sock.getsockname()) - return s + sock = WSGIServer.get_listener(address, family=socket.AF_INET) + return sock - def start_gevent(self): + def _start_gevent(self): ssl_args = self.ssl_args or {} log.info('Starting Gevent server') @@ -102,78 +127,58 @@ class server: sock = self._make_gevent_socket() self.wsgiserver = WSGIServer(sock, self.app, log=self.access_logger, spawn=Pool(), **ssl_args) self.wsgiserver.serve_forever() - except socket.error: - try: - log.info('Unable to listen on "", trying on "0.0.0.0" only...') - self.wsgiserver = WSGIServer(('0.0.0.0', config.config_port), self.app, spawn=Pool(), **ssl_args) - self.wsgiserver.serve_forever() - except (OSError, socket.error) as e: - log.info("Error starting server: %s", e.strerror) - print("Error starting server: %s" % e.strerror) - global_WorkerThread.stop() - sys.exit(1) - except Exception: - log.exception("Unknown error while starting gevent") - sys.exit(0) - - def start_tornado(self): - log.info('Starting Tornado server on %s', self.listening) - + finally: + if self.unix_socket_file: + os.remove(self.unix_socket_file) + self.unix_socket_file = None + + def _start_tornado(self): + log.info('Starting Tornado server on %s', self.listen_address) + + # Max Buffersize set to 200MB ) + http_server = HTTPServer(WSGIContainer(self.app), + max_buffer_size = 209700000, + ssl_options=self.ssl_args) + http_server.listen(self.listen_port, self.listen_address) + self.wsgiserver=IOLoop.instance() + self.wsgiserver.start() + # wait for stop signal + self.wsgiserver.close(True) + + def start(self): try: - # Max Buffersize set to 200MB ) - http_server = HTTPServer(WSGIContainer(self.app), - max_buffer_size = 209700000, - ssl_options=self.ssl_args) - address = config.get_config_ipaddress() - http_server.listen(self.port, address) - self.wsgiserver=IOLoop.instance() - self.wsgiserver.start() - # wait for stop signal - self.wsgiserver.close(True) - except socket.error as err: - log.exception("Error starting tornado server") - print("Error starting server: %s" % err.strerror) + if _GEVENT: + # leave subprocess out to allow forking for fetchers and processors + self._start_gevent() + else: + self._start_tornado() + except Exception as ex: + log.error("Error starting server: %s", ex) + print("Error starting server: %s" % ex) + return False + finally: + self.wsgiserver = None global_WorkerThread.stop() - sys.exit(1) - - def startServer(self): - if gevent_present: - # leave subprocess out to allow forking for fetchers and processors - self.start_gevent() - else: - self.start_tornado() - if self.restart is True: - log.info("Performing restart of Calibre-Web") - global_WorkerThread.stop() - if os.name == 'nt': - arguments = ["\"" + sys.executable + "\""] - for e in sys.argv: - arguments.append("\"" + e + "\"") - os.execv(sys.executable, arguments) - else: - os.execl(sys.executable, sys.executable, *sys.argv) - else: + if not self.restart: log.info("Performing shutdown of Calibre-Web") - global_WorkerThread.stop() - sys.exit(0) + return True - def setRestartTyp(self,starttyp): - self.restart = starttyp + log.info("Performing restart of Calibre-Web") + arguments = list(sys.argv) + arguments.insert(0, sys.executable) + if os.name == 'nt': + arguments = ["\"%s\"" % a for a in arguments] + os.execv(sys.executable, arguments) + return True - def killServer(self, signum, frame): - self.stopServer() + def _killServer(self, signum, frame): + self.stop() - def stopServer(self): + def stop(self, restart=False): + self.restart = restart if self.wsgiserver: - if gevent_present: + if _GEVENT: self.wsgiserver.close() else: self.wsgiserver.add_callback(self.wsgiserver.stop) - - @staticmethod - def getNameVersion(): - if gevent_present: - return {'Gevent': 'v' + geventVersion} - else: - return {'Tornado': 'v' + tornadoVersion} diff --git a/cps/ub.py b/cps/ub.py index 70d97a20..52e8423c 100644 --- a/cps/ub.py +++ b/cps/ub.py @@ -476,35 +476,27 @@ class Config: def get_config_certfile(self): if cli.certfilepath: return cli.certfilepath - else: - if cli.certfilepath is "": - return None - else: - return self.config_certfile + if cli.certfilepath is "": + return None + return self.config_certfile def get_config_keyfile(self): if cli.keyfilepath: return cli.keyfilepath - else: - if cli.certfilepath is "": - return None - else: - return self.config_keyfile + if cli.certfilepath is "": + return None + return self.config_keyfile def get_config_ipaddress(self, readable=False): if not readable: - if cli.ipadress: - return cli.ipadress + return cli.ipadress or "" + answer="0.0.0.0" + if cli.ipadress: + if cli.ipv6: + answer = "["+cli.ipadress+"]" else: - return "" - else: - answer="0.0.0.0" - if cli.ipadress: - if cli.ipv6: - answer = "["+cli.ipadress+"]" - else: - answer = cli.ipadress - return answer + answer = cli.ipadress + return answer def _has_role(self, role_flag): return constants.has_flag(self.config_default_role, role_flag) diff --git a/cps/updater.py b/cps/updater.py index 28be3b3a..33012db1 100644 --- a/cps/updater.py +++ b/cps/updater.py @@ -33,7 +33,7 @@ from tempfile import gettempdir from babel.dates import format_datetime from flask_babel import gettext as _ -from . import constants, logger, config, get_locale, Server +from . import constants, logger, config, get_locale, web_server log = logger.create() @@ -95,8 +95,7 @@ class Updater(threading.Thread): self.status = 6 log.debug(u'Preparing restart of server') time.sleep(2) - Server.setRestartTyp(True) - Server.stopServer() + web_server.stop(True) self.status = 7 time.sleep(2) except requests.exceptions.HTTPError as ex: