Refactor udp piping and add tests

pull/110/head
Andy Wang 4 years ago
parent 9f413ff23a
commit 5d4e8b8d8d

@ -166,7 +166,7 @@ func main() {
}
if authInfo.Unordered {
client.RouteUDP(localConfig, seshMaker)
client.RouteUDP(net.ListenPacket, localConfig, seshMaker)
} else {
listener, err := net.Listen("tcp", localConfig.LocalAddr)
if err != nil {

@ -11,20 +11,16 @@ import (
log "github.com/sirupsen/logrus"
)
func RouteUDP(localConfig LocalConnConfig, newSeshFunc func() *mux.Session) {
func RouteUDP(listen func(string, string) (net.PacketConn, error), localConfig LocalConnConfig, newSeshFunc func() *mux.Session) {
var sesh *mux.Session
localUDPAddr, err := net.ResolveUDPAddr("udp", localConfig.LocalAddr)
if err != nil {
log.Fatal(err)
}
start:
localConn, err := net.ListenUDP("udp", localUDPAddr)
localConn, err := listen("udp", localConfig.LocalAddr)
if err != nil {
log.Fatal(err)
}
var otherEnd atomic.Value
data := make([]byte, 10240)
i, oe, err := localConn.ReadFromUDP(data)
i, oe, err := localConn.ReadFrom(data)
if err != nil {
log.Errorf("Failed to read first packet from proxy client: %v", err)
localConn.Close()
@ -35,7 +31,7 @@ start:
if sesh == nil || sesh.IsClosed() {
sesh = newSeshFunc()
}
log.Debugf("proxy local address %v", otherEnd.Load().(*net.UDPAddr).String())
log.Debugf("proxy local address %v", otherEnd.Load().(net.Addr).String())
stream, err := sesh.OpenStream()
if err != nil {
log.Errorf("Failed to open stream: %v", err)
@ -63,7 +59,7 @@ start:
stream.Close()
break
}
_, err = localConn.WriteToUDP(buf[:i], otherEnd.Load().(*net.UDPAddr))
_, err = localConn.WriteTo(buf[:i], otherEnd.Load().(net.Addr))
if err != nil {
log.Print(err)
localConn.Close()
@ -82,7 +78,7 @@ start:
if localConfig.Timeout != 0 {
localConn.SetReadDeadline(time.Now().Add(localConfig.Timeout))
}
i, oe, err := localConn.ReadFromUDP(buf)
i, oe, err := localConn.ReadFrom(buf)
if err != nil {
localConn.Close()
stream.Close()

@ -101,6 +101,10 @@ func (s *Stream) Write(in []byte) (n int, err error) {
if len(in)-n <= s.session.maxStreamUnitWrite {
framePayload = in[n:]
} else {
if s.session.Unordered { // no splitting
err = io.ErrShortBuffer
return
}
framePayload = in[n : s.session.maxStreamUnitWrite+n]
}

@ -21,23 +21,60 @@ import (
log "github.com/sirupsen/logrus"
)
func serveEcho(l net.Listener) {
func serveTCPEcho(l net.Listener) {
for {
conn, err := l.Accept()
if err != nil {
// TODO: pass the error back
log.Error(err)
return
}
go func() {
conn := conn
_, err := io.Copy(conn, conn)
if err != nil {
// TODO: pass the error back
conn.Close()
log.Error(err)
return
}
}()
}
}
/*
func serveUDPEcho(listener *connutil.PipeListener) {
for {
conn, err := listener.ListenPacket("udp", "")
if err != nil {
log.Error(err)
return
}
const bufSize = 32 * 1024
go func() {
conn := conn
defer conn.Close()
buf := make([]byte, bufSize)
for {
r,_, err := conn.ReadFrom(buf)
if err != nil {
log.Error(err)
return
}
w, err := conn.WriteTo(buf[:r], nil)
if err != nil {
log.Error(err)
return
}
if r != w {
log.Error("written not eqal to read")
return
}
}
}()
}
}
*/
var bypassUID = [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
var publicKey, _ = base64.StdEncoding.DecodeString("7f7TuKrs264VNSgMno8PkDlyhGhVuOSR8JHLE6H4Ljc=")
var privateKey, _ = base64.StdEncoding.DecodeString("SMWeC6VuZF8S/id65VuFQFlfa7hTEJBpL6wWhqPP100=")
@ -45,7 +82,7 @@ var privateKey, _ = base64.StdEncoding.DecodeString("SMWeC6VuZF8S/id65VuFQFlfa7h
func basicClientConfigs(state common.WorldState) (client.LocalConnConfig, client.RemoteConnConfig, client.AuthInfo) {
var clientConfig = client.RawConfig{
ServerName: "www.example.com",
ProxyMethod: "test",
ProxyMethod: "tcp",
EncryptionMethod: "plain",
UID: bypassUID[:],
PublicKey: publicKey,
@ -66,7 +103,7 @@ func basicClientConfigs(state common.WorldState) (client.LocalConnConfig, client
func basicServerState(ws common.WorldState, db *os.File) *server.State {
var serverConfig = server.RawConfig{
ProxyBook: map[string][]string{"test": {"tcp", "fake.com:9999"}},
ProxyBook: map[string][]string{"tcp": {"tcp", "fake.com:9999"}, "udp": {"udp", "fake.com:9999"}},
BindAddr: []string{"fake.com:9999"},
BypassUID: [][]byte{bypassUID[:]},
RedirAddr: "fake.com:9999",
@ -84,16 +121,19 @@ func basicServerState(ws common.WorldState, db *os.File) *server.State {
return state
}
func establishSession(lcc client.LocalConnConfig, rcc client.RemoteConnConfig, ai client.AuthInfo, serverState *server.State) (common.Dialer, net.Listener, common.Dialer, net.Listener, error) {
func establishSession(lcc client.LocalConnConfig, rcc client.RemoteConnConfig, ai client.AuthInfo, serverState *server.State) (common.Dialer, *connutil.PipeListener, common.Dialer, net.Listener, error) {
// transport
ckClientDialer, ckServerListener := connutil.DialerListener(10 * 1024)
clientSeshMaker := func() *mux.Session {
return client.MakeSession(rcc, ai, ckClientDialer, false)
}
proxyToCkClientD, proxyToCkClientL := connutil.DialerListener(10 * 1024)
go client.RouteTCP(proxyToCkClientL, lcc.Timeout, clientSeshMaker)
if ai.Unordered {
go client.RouteUDP(proxyToCkClientL.ListenPacket, lcc, clientSeshMaker)
} else {
go client.RouteTCP(proxyToCkClientL, lcc.Timeout, clientSeshMaker)
}
// set up server
ckServerToProxyD, ckServerToProxyL := connutil.DialerListener(10 * 1024)
@ -106,12 +146,12 @@ func establishSession(lcc client.LocalConnConfig, rcc client.RemoteConnConfig, a
return proxyToCkClientD, ckServerToProxyL, ckClientDialer, ckServerToWebL, nil
}
func runEchoTest(t *testing.T, conns []net.Conn) {
func runEchoTest(t *testing.T, conns []net.Conn, maxMsgLen int) {
var wg sync.WaitGroup
for _, conn := range conns {
wg.Add(1)
go func(conn net.Conn) {
testDataLen := rand.Intn(65536)
testDataLen := rand.Intn(maxMsgLen)
testData := make([]byte, testDataLen)
rand.Read(testData)
@ -135,10 +175,59 @@ func runEchoTest(t *testing.T, conns []net.Conn) {
wg.Wait()
}
func TestUDP(t *testing.T) {
var tmpDB, _ = ioutil.TempFile("", "ck_user_info")
defer os.Remove(tmpDB.Name())
log.SetLevel(log.TraceLevel)
worldState := common.WorldOfTime(time.Unix(10, 0))
lcc, rcc, ai := basicClientConfigs(worldState)
ai.ProxyMethod = "udp"
ai.Unordered = true
sta := basicServerState(worldState, tmpDB)
pxyClientD, pxyServerL, _, _, err := establishSession(lcc, rcc, ai, sta)
if err != nil {
t.Fatal(err)
}
t.Run("simple send", func(t *testing.T) {
pxyClientConn, err := pxyClientD.Dial("udp", "")
if err != nil {
t.Error(err)
}
const testDataLen = 1500
testData := make([]byte, testDataLen)
rand.Read(testData)
n, err := pxyClientConn.Write(testData)
if n != testDataLen {
t.Errorf("wrong length sent: %v", n)
}
if err != nil {
t.Error(err)
}
pxyServerConn, err := pxyServerL.ListenPacket("", "")
if err != nil {
t.Error(err)
}
recvBuf := make([]byte, testDataLen+100)
r, _, err := pxyServerConn.ReadFrom(recvBuf)
if err != nil {
t.Error(err)
}
if !bytes.Equal(testData, recvBuf[:r]) {
t.Error("read wrong data")
}
})
}
func TestTCP(t *testing.T) {
var tmpDB, _ = ioutil.TempFile("", "ck_user_info")
defer os.Remove(tmpDB.Name())
log.SetLevel(log.FatalLevel)
log.SetLevel(log.ErrorLevel)
worldState := common.WorldOfTime(time.Unix(10, 0))
lcc, rcc, ai := basicClientConfigs(worldState)
@ -155,7 +244,7 @@ func TestTCP(t *testing.T) {
writeData := make([]byte, dataLen)
rand.Read(writeData)
t.Run(fmt.Sprintf("data length %v", dataLen), func(t *testing.T) {
go serveEcho(pxyServerL)
go serveTCPEcho(pxyServerL)
conn, err := pxyClientD.Dial("", "")
if err != nil {
t.Error(err)
@ -182,7 +271,7 @@ func TestTCP(t *testing.T) {
})
t.Run("user echo", func(t *testing.T) {
go serveEcho(pxyServerL)
go serveTCPEcho(pxyServerL)
const numConns = 2000 // -race option limits the number of goroutines to 8192
var conns [numConns]net.Conn
for i := 0; i < numConns; i++ {
@ -192,11 +281,11 @@ func TestTCP(t *testing.T) {
}
}
runEchoTest(t, conns[:])
runEchoTest(t, conns[:], 65536)
})
t.Run("redir echo", func(t *testing.T) {
go serveEcho(rdirServerL)
go serveTCPEcho(rdirServerL)
const numConns = 2000 // -race option limits the number of goroutines to 8192
var conns [numConns]net.Conn
for i := 0; i < numConns; i++ {
@ -205,14 +294,14 @@ func TestTCP(t *testing.T) {
t.Error(err)
}
}
runEchoTest(t, conns[:])
runEchoTest(t, conns[:], 65536)
})
}
func TestClosingStreamsFromProxy(t *testing.T) {
var tmpDB, _ = ioutil.TempFile("", "ck_user_info")
defer os.Remove(tmpDB.Name())
log.SetLevel(log.FatalLevel)
log.SetLevel(log.ErrorLevel)
worldState := common.WorldOfTime(time.Unix(10, 0))
lcc, rcc, ai := basicClientConfigs(worldState)
sta := basicServerState(worldState, tmpDB)
@ -247,7 +336,7 @@ func TestClosingStreamsFromProxy(t *testing.T) {
func BenchmarkThroughput(b *testing.B) {
var tmpDB, _ = ioutil.TempFile("", "ck_user_info")
defer os.Remove(tmpDB.Name())
log.SetLevel(log.FatalLevel)
log.SetLevel(log.ErrorLevel)
worldState := common.WorldOfTime(time.Unix(10, 0))
lcc, rcc, ai := basicClientConfigs(worldState)
sta := basicServerState(worldState, tmpDB)

Loading…
Cancel
Save