From e3552b2afc8f50af4df85d945e00a651bfd99f48 Mon Sep 17 00:00:00 2001 From: Andy Wang Date: Sun, 29 Dec 2019 13:28:52 +0000 Subject: [PATCH] Refactor server config parser and add tests --- internal/server/state.go | 173 ++++++++++++++++++++-------------- internal/server/state_test.go | 157 ++++++++++++++++++++++++++++++ 2 files changed, 260 insertions(+), 70 deletions(-) create mode 100644 internal/server/state_test.go diff --git a/internal/server/state.go b/internal/server/state.go index f4bb58c..15abe99 100644 --- a/internal/server/state.go +++ b/internal/server/state.go @@ -2,12 +2,10 @@ package server import ( "crypto" - "encoding/base64" "encoding/json" "errors" "fmt" "github.com/cbeuw/Cloak/internal/server/usermanager" - "github.com/sirupsen/logrus" "io/ioutil" "net" "strings" @@ -22,8 +20,8 @@ type rawConfig struct { BindAddr []string BypassUID [][]byte RedirAddr string - PrivateKey string - AdminUID string + PrivateKey []byte + AdminUID []byte DatabasePath string StreamTimeout int CncMode bool @@ -41,7 +39,8 @@ type State struct { BypassUID map[[16]byte]struct{} staticPv crypto.PrivateKey - RedirAddr net.Addr + RedirHost net.Addr + RedirPort string usedRandomM sync.RWMutex usedRandom map[[32]byte]int64 @@ -61,6 +60,89 @@ func InitState(nowFunc func() time.Time) (*State, error) { return ret, nil } +func parseRedirAddr(redirAddr string) (net.Addr, string, error) { + var host string + var port string + colonSep := strings.Split(redirAddr, ":") + if len(colonSep) > 1 { + if len(colonSep) == 2 { + // domain or ipv4 with port + host = colonSep[0] + port = colonSep[1] + } else { + if strings.Contains(redirAddr, "[") { + // ipv6 with port + port = colonSep[len(colonSep)-1] + host = strings.TrimSuffix(redirAddr, "]:"+port) + host = strings.TrimPrefix(host, "[") + } else { + // ipv6 without port + host = redirAddr + } + } + } else { + // domain or ipv4 without port + host = redirAddr + } + + redirHost, err := net.ResolveIPAddr("ip", host) + if err != nil { + return nil, "", fmt.Errorf("unable to resolve RedirAddr: %v. ", err) + } + return redirHost, port, nil +} + +func parseLocalPanel(databasePath string) (*userPanel, *gmux.Router, error) { + manager, err := usermanager.MakeLocalManager(databasePath) + if err != nil { + return nil, nil, err + } + panel := MakeUserPanel(manager) + router := manager.Router + return panel, router, nil + +} + +func parseBindAddr(bindAddrs []string) ([]net.Addr, error) { + var addrs []net.Addr + for _, addr := range bindAddrs { + bindAddr, err := net.ResolveTCPAddr("tcp", addr) + if err != nil { + return nil, err + } + addrs = append(addrs, bindAddr) + } + return addrs, nil +} + +func parseProxyBook(bookEntries map[string][]string) (map[string]net.Addr, error) { + proxyBook := map[string]net.Addr{} + for name, pair := range bookEntries { + name = strings.ToLower(name) + if len(pair) != 2 { + return nil, fmt.Errorf("invalid proxy endpoint and address pair for %v: %v", name, pair) + } + network := strings.ToLower(pair[0]) + switch network { + case "tcp": + addr, err := net.ResolveTCPAddr("tcp", pair[1]) + if err != nil { + return nil, err + } + proxyBook[name] = addr + continue + case "udp": + addr, err := net.ResolveUDPAddr("udp", pair[1]) + if err != nil { + return nil, err + } + proxyBook[name] = addr + continue + } + } + return proxyBook, nil +} + // ParseConfig parses the config (either a path to json or the json itself as argument) into a State variable func (sta *State) ParseConfig(conf string) (err error) { var content []byte @@ -80,14 +162,9 @@ func (sta *State) ParseConfig(conf string) (err error) { } if preParse.CncMode { - //TODO: implement command & control mode + return errors.New("command & control mode not implemented") } else { - manager, err := usermanager.MakeLocalManager(preParse.DatabasePath) - if err != nil { - return err - } - sta.Panel = MakeUserPanel(manager) - sta.LocalAPIRouter = manager.Router + sta.Panel, sta.LocalAPIRouter, err = parseLocalPanel(preParse.DatabasePath) } if preParse.StreamTimeout == 0 { @@ -96,79 +173,35 @@ func (sta *State) ParseConfig(conf string) (err error) { sta.Timeout = time.Duration(preParse.StreamTimeout) * time.Second } - redirAddr := preParse.RedirAddr - colonSep := strings.Split(redirAddr, ":") - if len(colonSep) != 0 { - if len(colonSep) == 2 { - logrus.Error("If RedirAddr contains a port number, please remove it.") - redirAddr = colonSep[0] - } else { - if strings.Contains(redirAddr, "[") { - logrus.Error("If RedirAddr contains a port number, please remove it.") - redirAddr = strings.TrimRight(redirAddr, "]:"+colonSep[len(colonSep)-1]) - redirAddr = strings.TrimPrefix(redirAddr, "[") - } - } - } - - sta.RedirAddr, err = net.ResolveIPAddr("ip", redirAddr) + sta.RedirHost, sta.RedirPort, err = parseRedirAddr(preParse.RedirAddr) if err != nil { - return fmt.Errorf("unable to resolve RedirAddr: %v. ", err) + return fmt.Errorf("unable to parse RedirAddr: %v", err) } - for _, addr := range preParse.BindAddr { - bindAddr, err := net.ResolveTCPAddr("tcp", addr) - if err != nil { - return err - } - sta.BindAddr = append(sta.BindAddr, bindAddr) - } - - for name, pair := range preParse.ProxyBook { - name = strings.ToLower(name) - if len(pair) != 2 { - return fmt.Errorf("invalid proxy endpoint and address pair for %v: %v", name, pair) - } - network := strings.ToLower(pair[0]) - switch network { - case "tcp": - addr, err := net.ResolveTCPAddr("tcp", pair[1]) - if err != nil { - return err - } - sta.ProxyBook[name] = addr - continue - case "udp": - addr, err := net.ResolveUDPAddr("udp", pair[1]) - if err != nil { - return err - } - sta.ProxyBook[name] = addr - continue - } - } - - pvBytes, err := base64.StdEncoding.DecodeString(preParse.PrivateKey) + sta.BindAddr, err = parseBindAddr(preParse.BindAddr) if err != nil { - return errors.New("Failed to decode private key: " + err.Error()) + return fmt.Errorf("unable to parse BindAddr: %v", err) } + + sta.ProxyBook, err = parseProxyBook(preParse.ProxyBook) + if err != nil { + return fmt.Errorf("unable to parse ProxyBook: %v", err) + } + var pv [32]byte - copy(pv[:], pvBytes) + copy(pv[:], preParse.PrivateKey) sta.staticPv = &pv - adminUID, err := base64.StdEncoding.DecodeString(preParse.AdminUID) - if err != nil { - return errors.New("Failed to decode AdminUID: " + err.Error()) - } - sta.AdminUID = adminUID + sta.AdminUID = preParse.AdminUID var arrUID [16]byte for _, UID := range preParse.BypassUID { copy(arrUID[:], UID) sta.BypassUID[arrUID] = struct{}{} } - copy(arrUID[:], adminUID) + copy(arrUID[:], sta.AdminUID) sta.BypassUID[arrUID] = struct{}{} + return nil } diff --git a/internal/server/state_test.go b/internal/server/state_test.go new file mode 100644 index 0000000..ce9d40d --- /dev/null +++ b/internal/server/state_test.go @@ -0,0 +1,157 @@ +package server + +import ( + "net" + "testing" +) + +func TestParseRedirAddr(t *testing.T) { + t.Run("ipv4 without port", func(t *testing.T) { + ipv4noPort := "1.2.3.4" + host, port, err := parseRedirAddr(ipv4noPort) + if err != nil { + t.Errorf("parsing %v error: %v", ipv4noPort, err) + return + } + if host.String() != "1.2.3.4" { + t.Errorf("expected %v got %v", "1.2.3.4", host.String()) + } + if port != "" { + t.Errorf("port not empty when there is no port") + } + }) + + t.Run("ipv4 with port", func(t *testing.T) { + ipv4wPort := "1.2.3.4:1234" + host, port, err := parseRedirAddr(ipv4wPort) + if err != nil { + t.Errorf("parsing %v error: %v", ipv4wPort, err) + return + } + if host.String() != "1.2.3.4" { + t.Errorf("expected %v got %v", "1.2.3.4", host.String()) + } + if port != "1234" { + t.Errorf("wrong port: expected %v, got %v", "1234", port) + } + }) + + t.Run("domain without port", func(t *testing.T) { + domainNoPort := "example.com" + host, port, err := parseRedirAddr(domainNoPort) + if err != nil { + t.Errorf("parsing %v error: %v", domainNoPort, err) + return + } + expHost, err := net.ResolveIPAddr("ip", "example.com") + if err != nil { + t.Errorf("tester error: cannot resolve example.com: %v", err) + return + } + if host.String() != expHost.String() { + t.Errorf("expected %v got %v", expHost.String(), host.String()) + } + if port != "" { + t.Errorf("port not empty when there is no port") + } + }) + + t.Run("domain with port", func(t *testing.T) { + domainWPort := "example.com:80" + host, port, err := parseRedirAddr(domainWPort) + if err != nil { + t.Errorf("parsing %v error: %v", domainWPort, err) + return + } + expHost, err := net.ResolveIPAddr("ip", "example.com") + if err != nil { + t.Errorf("tester error: cannot resolve example.com: %v", err) + return + } + if host.String() != expHost.String() { + t.Errorf("expected %v got %v", expHost.String(), host.String()) + } + if port != "80" { + t.Errorf("wrong port: expected %v, got %v", "80", port) + } + }) + + t.Run("ipv6 without port", func(t *testing.T) { + ipv6noPort := "a:b:c:d::" + host, port, err := parseRedirAddr(ipv6noPort) + if err != nil { + t.Errorf("parsing %v error: %v", ipv6noPort, err) + return + } + if host.String() != "a:b:c:d::" { + t.Errorf("expected %v got %v", "a:b:c:d::", host.String()) + } + if port != "" { + t.Errorf("port not empty when there is no port") + } + }) + + t.Run("ipv6 with port", func(t *testing.T) { + ipv6wPort := "[a:b:c:d::]:80" + host, port, err := parseRedirAddr(ipv6wPort) + if err != nil { + t.Errorf("parsing %v error: %v", ipv6wPort, err) + return + } + if host.String() != "a:b:c:d::" { + t.Errorf("expected %v got %v", "a:b:c:d::", host.String()) + } + if port != "80" { + t.Errorf("wrong port: expected %v, got %v", "80", port) + } + }) +} + +func TestParseBindAddr(t *testing.T) { + t.Run("port only", func(t *testing.T) { + addrs, err := parseBindAddr([]string{":443"}) + if err != nil { + t.Error(err) + return + } + if addrs[0].String() != ":443" { + t.Errorf("expected %v got %v", ":443", addrs[0].String()) + } + }) + + t.Run("specific address", func(t *testing.T) { + addrs, err := parseBindAddr([]string{"192.168.1.123:443"}) + if err != nil { + t.Error(err) + return + } + if addrs[0].String() != "192.168.1.123:443" { + t.Errorf("expected %v got %v", "192.168.1.123:443", addrs[0].String()) + } + }) + + t.Run("ipv6", func(t *testing.T) { + addrs, err := parseBindAddr([]string{"[::]:443"}) + if err != nil { + t.Error(err) + return + } + if addrs[0].String() != "[::]:443" { + t.Errorf("expected %v got %v", "[::]:443", addrs[0].String()) + } + }) + + t.Run("mixed", func(t *testing.T) { + addrs, err := parseBindAddr([]string{":80", "[::]:443"}) + if err != nil { + t.Error(err) + return + } + if addrs[0].String() != ":80" { + t.Errorf("expected %v got %v", ":80", addrs[0].String()) + } + if addrs[1].String() != "[::]:443" { + t.Errorf("expected %v got %v", "[::]:443", addrs[1].String()) + } + }) +}