diff --git a/server/server.go b/server/server.go index e12c792c..b1f90c55 100644 --- a/server/server.go +++ b/server/server.go @@ -7,6 +7,7 @@ import ( "net" "net/http" "os" + "strings" "time" "github.com/pkg/errors" @@ -65,16 +66,24 @@ func (srv *Server) ListenAndServe() error { // channels to reload or shutdown the server. func (srv *Server) Serve(ln net.Listener) error { var err error - // Store the current listener. - // In reloads we'll create a copy of the underlying os.File so the close of the server one does not affect the copy. - srv.listener = ln.(*net.TCPListener) + + switch l := ln.(type) { + case *net.TCPListener: + // Store the current listener. + // In reloads we'll create a copy of the underlying os.File so the close of the server one does not affect the copy. + srv.listener = l + } for { - // Start server - if srv.TLSConfig == nil || (len(srv.TLSConfig.Certificates) == 0 && srv.TLSConfig.GetCertificate == nil) { + wl, isWrapped := ln.(*WrappedListener) + switch { + case srv.TLSConfig == nil || (len(srv.TLSConfig.Certificates) == 0 && srv.TLSConfig.GetCertificate == nil): log.Printf("Serving HTTP on %s ...", srv.Addr) err = srv.Server.Serve(ln) - } else { + case isWrapped: + log.Printf("Serving %s on %s ...", wl.proto, wl.Addr()) + err = srv.Server.Serve(wl) + default: log.Printf("Serving HTTPS on %s ...", srv.Addr) err = srv.Server.ServeTLS(ln, "", "") } @@ -93,6 +102,36 @@ func (srv *Server) Serve(ln net.Listener) error { } } +// NewWrappedListener wraps the inner [net.Listener]. +func NewWrappedListener(inner net.Listener, proto string) *WrappedListener { + return &WrappedListener{ + inner: inner, + proto: strings.ToUpper(proto), + } +} + +// WrappedListener wraps a [net.Listener]. +type WrappedListener struct { + inner net.Listener + proto string +} + +// Accept waits for and returns the next connection to the listener. +func (w *WrappedListener) Accept() (net.Conn, error) { + return w.inner.Accept() +} + +// Close closes the listener. +// Any blocked Accept operations will be unblocked and return errors. +func (w *WrappedListener) Close() error { + return w.inner.Close() +} + +// Addr returns the listener's network address. +func (w *WrappedListener) Addr() net.Addr { + return w.inner.Addr() +} + // Shutdown gracefully shuts down the server without interrupting any active // connections. func (srv *Server) Shutdown() error {