// Mirror of https://github.com/cbeuw/Cloak.git, synced 2024-11-09 19:10:44 +00:00.
// Go source: 151 lines, 3.7 KiB.

package multiplex
|
|
|
|
import (
|
|
"bytes"
|
|
"io"
|
|
"math/rand"
|
|
"net"
|
|
"sync"
|
|
"testing"
|
|
|
|
"github.com/cbeuw/Cloak/internal/common"
|
|
"github.com/cbeuw/connutil"
|
|
"github.com/stretchr/testify/assert"
|
|
)
|
|
|
|
// serveEcho accepts connections from l until Accept returns an error, and
// echoes every byte each accepted connection sends straight back to it.
// Each connection is served in its own goroutine and is closed once its echo
// loop ends (peer EOF or a read/write error), so the peer observes the close
// instead of a connection that is never released.
func serveEcho(l net.Listener) {
	for {
		conn, err := l.Accept()
		if err != nil {
			// TODO: pass the error back
			return
		}
		go func(conn net.Conn) {
			defer conn.Close()
			// io.Copy returns on EOF or on the first read/write error; either
			// way the echo for this connection is finished.
			// TODO: pass the error back
			_, _ = io.Copy(conn, conn)
		}(conn)
	}
}
|
|
|
|
// connPair records both ends of one connection built by makeSessionPair:
// clientConn is the end given to the client session and serverConn the end
// given to the server session.
type connPair struct {
	clientConn, serverConn net.Conn
}
|
|
|
|
func makeSessionPair(numConn int) (*Session, *Session, []*connPair) {
|
|
sessionKey := [32]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}
|
|
sessionId := 1
|
|
obfuscator, _ := MakeObfuscator(EncryptionMethodChaha20Poly1305, sessionKey)
|
|
clientConfig := SessionConfig{
|
|
Obfuscator: obfuscator,
|
|
Valve: nil,
|
|
Unordered: false,
|
|
}
|
|
serverConfig := clientConfig
|
|
|
|
clientSession := MakeSession(uint32(sessionId), clientConfig)
|
|
serverSession := MakeSession(uint32(sessionId), serverConfig)
|
|
|
|
paris := make([]*connPair, numConn)
|
|
for i := 0; i < numConn; i++ {
|
|
c, s := connutil.AsyncPipe()
|
|
clientConn := common.NewTLSConn(c)
|
|
serverConn := common.NewTLSConn(s)
|
|
paris[i] = &connPair{
|
|
clientConn: clientConn,
|
|
serverConn: serverConn,
|
|
}
|
|
clientSession.AddConnection(clientConn)
|
|
serverSession.AddConnection(serverConn)
|
|
}
|
|
return clientSession, serverSession, paris
|
|
}
|
|
|
|
// runEchoTest checks every connection in conns concurrently. On each one it
// writes msgLen random bytes, reads msgLen bytes back, and compares them with
// what was sent. Failures are reported through t.Errorf: a short write, a
// failed read-back, or a mismatch. It returns once all connections have been
// checked.
func runEchoTest(t *testing.T, conns []net.Conn, msgLen int) {
	wg := new(sync.WaitGroup)
	wg.Add(len(conns))

	for _, c := range conns {
		go func(c net.Conn) {
			defer wg.Done()

			sent := make([]byte, msgLen)
			rand.Read(sent)

			// Only t.Errorf is allowed here: t.Fatalf must not be called from
			// a goroutine other than the one running the test.
			if written, err := c.Write(sent); written != msgLen {
				t.Errorf("written only %v, err %v", written, err)
				return
			}

			echoed := make([]byte, msgLen)
			if _, err := io.ReadFull(c, echoed); err != nil {
				t.Errorf("failed to read back: %v", err)
				return
			}

			if !bytes.Equal(sent, echoed) {
				t.Errorf("echoed data not correct")
			}
		}(c)
	}

	wg.Wait()
}
|
|
|
|
func TestMultiplex(t *testing.T) {
|
|
const numStreams = 2000 // -race option limits the number of goroutines to 8192
|
|
const numConns = 4
|
|
const msgLen = 16384
|
|
|
|
clientSession, serverSession, _ := makeSessionPair(numConns)
|
|
go serveEcho(serverSession)
|
|
|
|
streams := make([]net.Conn, numStreams)
|
|
for i := 0; i < numStreams; i++ {
|
|
stream, err := clientSession.OpenStream()
|
|
assert.NoError(t, err)
|
|
streams[i] = stream
|
|
}
|
|
|
|
//test echo
|
|
runEchoTest(t, streams, msgLen)
|
|
|
|
assert.EqualValues(t, numStreams, clientSession.streamCount(), "client stream count is wrong")
|
|
assert.EqualValues(t, numStreams, serverSession.streamCount(), "server stream count is wrong")
|
|
|
|
// close one stream
|
|
closing, streams := streams[0], streams[1:]
|
|
err := closing.Close()
|
|
assert.NoError(t, err, "couldn't close a stream")
|
|
_, err = closing.Write([]byte{0})
|
|
assert.Equal(t, ErrBrokenStream, err)
|
|
_, err = closing.Read(make([]byte, 1))
|
|
assert.Equal(t, ErrBrokenStream, err)
|
|
}
|
|
|
|
func TestMux_StreamClosing(t *testing.T) {
|
|
clientSession, serverSession, _ := makeSessionPair(1)
|
|
go serveEcho(serverSession)
|
|
|
|
// read after closing stream
|
|
testData := make([]byte, 128)
|
|
recvBuf := make([]byte, 128)
|
|
toBeClosed, _ := clientSession.OpenStream()
|
|
_, err := toBeClosed.Write(testData) // should be echoed back
|
|
assert.NoError(t, err, "couldn't write to a stream")
|
|
|
|
_, err = io.ReadFull(toBeClosed, recvBuf[:1])
|
|
assert.NoError(t, err, "can't read anything before stream closed")
|
|
|
|
_ = toBeClosed.Close()
|
|
_, err = io.ReadFull(toBeClosed, recvBuf[1:])
|
|
assert.NoError(t, err, "can't read residual data on stream")
|
|
assert.Equal(t, testData, recvBuf, "incorrect data read back")
|
|
}
|