diff --git a/internal/server/dispatcher.go b/internal/server/dispatcher.go index 87b4701..258fa13 100644 --- a/internal/server/dispatcher.go +++ b/internal/server/dispatcher.go @@ -57,6 +57,7 @@ func dispatchConnection(conn net.Conn, sta *State) { _, err = webConn.Write(data) if err != nil { log.Error("Failed to send first packet to redirection server", err) + return } go io.Copy(webConn, conn) go io.Copy(conn, webConn) diff --git a/internal/test/integration_test.go b/internal/test/integration_test.go index 9ef9a90..99877cd 100644 --- a/internal/test/integration_test.go +++ b/internal/test/integration_test.go @@ -7,7 +7,6 @@ import ( "github.com/cbeuw/Cloak/internal/common" mux "github.com/cbeuw/Cloak/internal/multiplex" "github.com/cbeuw/Cloak/internal/server" - "github.com/cbeuw/Cloak/internal/server/usermanager" "github.com/cbeuw/connutil" "io" "io/ioutil" @@ -57,27 +56,31 @@ func basicClientConfigs(state common.WorldState) (client.LocalConnConfig, client LocalHost: "127.0.0.1", LocalPort: "9999", } - lcl, rmt, auth, _ := clientConfig.SplitConfigs(state) + lcl, rmt, auth, err := clientConfig.SplitConfigs(state) + if err != nil { + log.Fatal(err) + } return lcl, rmt, auth } func basicServerState(ws common.WorldState, db *os.File) *server.State { - manager, _ := usermanager.MakeLocalManager(db.Name()) - var pv [32]byte - copy(pv[:], privateKey) - serverState := &server.State{ - ProxyBook: map[string]net.Addr{"test": &net.TCPAddr{}}, - UsedRandom: map[[32]byte]int64{}, - Timeout: 0, - BypassUID: map[[16]byte]struct{}{bypassUID: {}}, - RedirHost: &net.TCPAddr{}, - RedirPort: "9999", - Panel: server.MakeUserPanel(manager), - LocalAPIRouter: nil, - StaticPv: &pv, - WorldState: ws, + var serverConfig = server.RawConfig{ + ProxyBook: map[string][]string{"test": {"tcp", "fake.com:9999"}}, + BindAddr: []string{"fake.com:9999"}, + BypassUID: [][]byte{bypassUID[:]}, + RedirAddr: "fake.com:9999", + PrivateKey: privateKey, + AdminUID: nil, + DatabasePath: db.Name(), + StreamTimeout: 300, + KeepAlive: 15, + CncMode: false, } - return serverState + state, err := server.InitState(serverConfig, ws) + if err != nil { + log.Fatal(err) + } + return state } func establishSession(lcc client.LocalConnConfig, rcc client.RemoteConnConfig, ai client.AuthInfo, serverState *server.State) (common.Dialer, net.Listener, common.Dialer, net.Listener, error) { @@ -103,11 +106,11 @@ func establishSession(lcc client.LocalConnConfig, rcc client.RemoteConnConfig, a } func runEchoTest(t *testing.T, conns []net.Conn) { - const testDataLen = 16384 var wg sync.WaitGroup for _, conn := range conns { wg.Add(1) go func(conn net.Conn) { + testDataLen := rand.Intn(65536) testData := make([]byte, testDataLen) rand.Read(testData) @@ -134,7 +137,7 @@ func runEchoTest(t *testing.T, conns []net.Conn) { func TestTCP(t *testing.T) { var tmpDB, _ = ioutil.TempFile("", "ck_user_info") defer os.Remove(tmpDB.Name()) - log.SetOutput(ioutil.Discard) + log.SetLevel(log.FatalLevel) worldState := common.WorldOfTime(time.Unix(10, 0)) lcc, rcc, ai := basicClientConfigs(worldState) @@ -172,3 +175,38 @@ func TestTCP(t *testing.T) { runEchoTest(t, conns[:]) }) } + +func TestClosingStreamsFromProxy(t *testing.T) { + var tmpDB, _ = ioutil.TempFile("", "ck_user_info") + defer os.Remove(tmpDB.Name()) + log.SetLevel(log.FatalLevel) + worldState := common.WorldOfTime(time.Unix(10, 0)) + lcc, rcc, ai := basicClientConfigs(worldState) + sta := basicServerState(worldState, tmpDB) + pxyClientD, pxyServerL, _, _, err := establishSession(lcc, rcc, ai, sta) + if err != nil { + t.Fatal(err) + } + + // closing stream on server side + clientConn, _ := pxyClientD.Dial("", "") + clientConn.Write(make([]byte, 16)) + serverConn, _ := pxyServerL.Accept() + serverConn.Close() + + time.Sleep(100 * time.Millisecond) + if _, err := clientConn.Read(make([]byte, 16)); err == nil { + t.Errorf("closing stream on server side is not reflected to the client: %v", err) + } + + // closing stream on client side + clientConn, _ = pxyClientD.Dial("", "") + clientConn.Write(make([]byte, 16)) + serverConn, _ = pxyServerL.Accept() + clientConn.Close() + + time.Sleep(100 * time.Millisecond) + if _, err := serverConn.Read(make([]byte, 16)); err == nil { + t.Errorf("closing stream on client side is not reflected to the server: %v", err) + } +} diff --git a/internal/test/test.go b/internal/test/test.go index aa63f31..56e5404 100644 --- a/internal/test/test.go +++ b/internal/test/test.go @@ -1,3 +1 @@ package test - -func blah() {}