smallstep-certificates/server/server.go
Mariano Cano 23b8f45b37 Address gosec warnings
Most if not all false positives
2022-08-18 17:46:20 -07:00

159 lines
4.2 KiB
Go

package server
import (
"context"
"crypto/tls"
"log"
"net"
"net/http"
"os"
"time"
"github.com/pkg/errors"
)
// ServerShutdownTimeout bounds how long a graceful shutdown of the underlying
// http.Server (both Shutdown and the shutdown of the old server during Reload)
// waits for active connections before the shutdown call gives up and returns
// the context's error.
const ServerShutdownTimeout = time.Minute
// Server is an incomplete component that implements a basic HTTP/HTTPS
// server. It embeds the *http.Server currently in use; Reload swaps that
// server out while Serve keeps running.
type Server struct {
	*http.Server
	// listener is the TCP listener Serve is currently accepting on. Reload
	// duplicates its file descriptor when the address does not change.
	listener *net.TCPListener
	// reloadCh carries the replacement server's listener from Reload to Serve.
	reloadCh chan net.Listener
	// shutdownCh is closed by Shutdown to make Serve return.
	shutdownCh chan struct{}
}
// New returns a Server wrapping an http.Server (see newHTTPServer) bound to
// addr, dispatching to handler and configured with tlsConfig. The reload and
// shutdown channels used by Serve, Reload and Shutdown are created unbuffered.
func New(addr string, handler http.Handler, tlsConfig *tls.Config) *Server {
	s := &Server{Server: newHTTPServer(addr, handler, tlsConfig)}
	s.reloadCh = make(chan net.Listener)
	s.shutdownCh = make(chan struct{})
	return s
}
// newHTTPServer builds an http.Server for addr, handler and tlsConfig. The
// same 15 second limit is applied to reading request headers, reading the
// whole request, writing the response and keeping idle connections open. The
// server's error log writes to stderr, prefixed with date, time and the full
// source file path.
func newHTTPServer(addr string, handler http.Handler, tlsConfig *tls.Config) *http.Server {
	const timeout = 15 * time.Second
	errorLog := log.New(os.Stderr, "", log.Ldate|log.Ltime|log.Llongfile)

	hs := &http.Server{
		Addr:      addr,
		Handler:   handler,
		TLSConfig: tlsConfig,
		ErrorLog:  errorLog,
	}
	hs.ReadHeaderTimeout = timeout
	hs.ReadTimeout = timeout
	hs.WriteTimeout = timeout
	hs.IdleTimeout = timeout
	return hs
}
// ListenAndServe opens a TCP listener on srv.Addr and hands it to Serve.
// A failure from net.Listen is returned unchanged; otherwise the result of
// Serve is returned.
func (srv *Server) ListenAndServe() error {
	var (
		listener net.Listener
		err      error
	)
	if listener, err = net.Listen("tcp", srv.Addr); err != nil {
		return err
	}
	return srv.Serve(listener)
}
// Serve runs Serve or ServeTLS on the underlying http.Server with ln and then
// listens to channels to reload or shut down the server.
//
// ServeTLS is used when TLSConfig has Certificates or a GetCertificate
// callback; otherwise plain HTTP is served. When the embedded server stops,
// Serve does not return. It waits for either a replacement listener on
// reloadCh (sent by Reload), after which it serves again with the current
// srv.Server. Or it waits for shutdownCh to be closed (by Shutdown), after
// which it returns http.ErrServerClosed. Errors other than
// http.ErrServerClosed from the embedded server are logged, not returned.
//
// ln, and every listener received from reloadCh, must be a *net.TCPListener
// because Reload duplicates its file descriptor. Any other listener type makes
// Serve return an error instead of panicking.
func (srv *Server) Serve(ln net.Listener) error {
	// Store the current listener.
	// In reloads we'll create a copy of the underlying os.File so the close of
	// the server one does not affect the copy.
	tcpLn, err := asTCPListener(ln)
	if err != nil {
		return err
	}
	srv.listener = tcpLn

	for {
		// Start server
		if srv.TLSConfig == nil || (len(srv.TLSConfig.Certificates) == 0 && srv.TLSConfig.GetCertificate == nil) {
			log.Printf("Serving HTTP on %s ...", srv.Addr)
			err = srv.Server.Serve(ln)
		} else {
			log.Printf("Serving HTTPS on %s ...", srv.Addr)
			err = srv.Server.ServeTLS(ln, "", "")
		}

		// http.ErrServerClosed is the expected result of Shutdown and of the
		// shutdown Reload performs; anything else is logged.
		if !errors.Is(err, http.ErrServerClosed) {
			log.Println(errors.Wrap(err, "unexpected error"))
		}

		// NOTE(review): after an unexpected error this still blocks until a
		// reload or shutdown arrives rather than returning the error.
		select {
		case ln = <-srv.reloadCh:
			if tcpLn, err = asTCPListener(ln); err != nil {
				return err
			}
			srv.listener = tcpLn
		case <-srv.shutdownCh:
			return http.ErrServerClosed
		}
	}
}

// asTCPListener returns ln as a *net.TCPListener, or an error naming the
// unsupported listener type.
func asTCPListener(ln net.Listener) (*net.TCPListener, error) {
	tcpLn, ok := ln.(*net.TCPListener)
	if !ok {
		return nil, errors.Errorf("server: unsupported listener type %T, a *net.TCPListener is required", ln)
	}
	return tcpLn, nil
}
// Shutdown gracefully shuts down the underlying http.Server, waiting at most
// ServerShutdownTimeout for active connections, and returns its result.
// Afterwards, whatever the outcome, it closes the shutdown channel so that
// Serve returns http.ErrServerClosed instead of waiting for a reload.
//
// Shutdown must be called at most once: a second call would close the
// already-closed shutdown channel and panic.
func (srv *Server) Shutdown() error {
	// Deferred so Serve is released only after the http.Server has stopped.
	defer close(srv.shutdownCh)
	return srv.reloadShutdown()
}
// reloadShutdown gracefully shuts down the underlying http.Server, waiting at
// most ServerShutdownTimeout, without touching the shutdown channel. Serve
// therefore keeps waiting for a replacement listener instead of returning.
func (srv *Server) reloadShutdown() error {
	shutdownCtx, release := context.WithDeadline(context.Background(), time.Now().Add(ServerShutdownTimeout))
	defer release() // free the context's timer if Shutdown finishes early
	return srv.Server.Shutdown(shutdownCtx)
}
// Reload replaces the running http.Server with ns.Server while Serve keeps
// running.
//
// The listener for the new server is obtained first (see reloadListener).
// The old server is then shut down gracefully without signalling shutdown.
// Finally the new listener is handed to Serve through the reload channel.
// If shutting down the old server fails, the new listener is closed, the old
// server stays in place and the shutdown error is returned.
//
// NOTE(review): the send on reloadCh blocks until Serve receives it, so Reload
// should only be called while Serve is running. After a failed shutdown the
// old server's listener may already be closed.
func (srv *Server) Reload(ns *Server) error {
	ln, err := srv.reloadListener(ns.Addr)
	if err != nil {
		return err
	}

	// Close old server without sending a signal
	if err := srv.reloadShutdown(); err != nil {
		// Release the new listener. Otherwise it stays open and, for a new
		// address, keeps the port bound so a retried reload cannot listen.
		// Best effort: the shutdown error is the one worth reporting.
		_ = ln.Close()
		return err
	}

	// Update old server
	srv.Server = ns.Server
	srv.reloadCh <- ln
	return nil
}

// reloadListener returns the listener a reloaded server should use for addr.
// When addr differs from srv.Addr, this is a new TCP listener on addr.
// Otherwise it is a listener built on a duplicate of the current listener's
// socket, so it survives the old server closing its own listener.
func (srv *Server) reloadListener(addr string) (net.Listener, error) {
	if srv.Addr != addr {
		// Open new address
		ln, err := net.Listen("tcp", addr)
		if err != nil {
			return nil, errors.WithStack(err)
		}
		return ln, nil
	}

	// Get a copy of the underlying os.File
	fd, err := srv.listener.File()
	if err != nil {
		return nil, errors.WithStack(err)
	}
	// net.FileListener works on its own copy of the descriptor, so the
	// intermediate file can always be closed here.
	defer fd.Close()

	// Creates a new listener copying fd
	ln, err := net.FileListener(fd)
	if err != nil {
		return nil, errors.WithStack(err)
	}
	return ln, nil
}
// Forbidden writes a 403 Forbidden response to w with a text/plain body of
// "Forbidden.\n".
func (srv *Server) Forbidden(w http.ResponseWriter) {
	// body is 11 bytes; keep it in sync with the Content-Length header below.
	const body = "Forbidden.\n"

	header := w.Header()
	header.Set("Content-Type", "text/plain; charset=utf-8")
	header.Set("Content-Length", "11")
	w.WriteHeader(http.StatusForbidden)

	// The status and headers are already committed, so a failed body write
	// (e.g. the client went away) cannot be reported to the client. The error
	// is deliberately discarded.
	_, _ = w.Write([]byte(body))
}