diff --git a/internal/server/dispatcher.go b/internal/server/dispatcher.go index 80c5ef9..1d0ca3c 100644 --- a/internal/server/dispatcher.go +++ b/internal/server/dispatcher.go @@ -141,8 +141,8 @@ func dispatchConnection(conn net.Conn, sta *State) { log.Error("Failed to send first packet to redirection server", err) return } - go io.Copy(webConn, conn) - go io.Copy(conn, webConn) + go common.Copy(webConn, conn) + go common.Copy(conn, webConn) } if err != nil { diff --git a/internal/server/dispatcher_test.go b/internal/server/dispatcher_test.go index a4b9c7f..c98a81b 100644 --- a/internal/server/dispatcher_test.go +++ b/internal/server/dispatcher_test.go @@ -115,6 +115,24 @@ func TestReadFirstPacket(t *testing.T) { assert.NoError(t, ret.err) }) + t.Run("TLS bad recordlayer length", func(t *testing.T) { + local, remote := connutil.AsyncPipe() + buf := make([]byte, 1500) + retChan := make(chan rfpReturnValue) + go rfp(remote, buf, retChan) + + first, _ := hex.DecodeString("160301ffff") + local.Write(first) + + ret := <-retChan + + assert.Equal(t, len(first), ret.n) + assert.Equal(t, first, buf[:ret.n]) + assert.IsType(t, TLS{}, ret.transport) + assert.Equal(t, io.ErrShortBuffer, ret.err) + assert.True(t, ret.redirOnErr) + }) + t.Run("Good WebSocket", func(t *testing.T) { local, remote := connutil.AsyncPipe() buf := make([]byte, 1500)