diff --git a/internal/multiplex/bufferedPipe_test.go b/internal/multiplex/bufferedPipe_test.go index 4c94be2..307a7da 100644 --- a/internal/multiplex/bufferedPipe_test.go +++ b/internal/multiplex/bufferedPipe_test.go @@ -2,6 +2,7 @@ package multiplex import ( "bytes" + "math/rand" "testing" "time" ) @@ -164,3 +165,29 @@ func TestReadAfterClose(t *testing.T) { ) } } + +func BenchmarkBufferedPipe_RW(b *testing.B) { + const PAYLOAD_LEN = 1300 + testData := make([]byte, PAYLOAD_LEN) + rand.Read(testData) + + pipe := NewBufferedPipe() + + smallBuf := make([]byte, PAYLOAD_LEN-10) + go func() { + for { + pipe.Read(smallBuf) + } + }() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := pipe.Write(testData) + if err != nil { + b.Error( + "For", "pipe write", + "got", err, + ) + } + b.SetBytes(PAYLOAD_LEN) + } +} diff --git a/internal/multiplex/obfs.go b/internal/multiplex/obfs.go index dd988d6..8682a46 100644 --- a/internal/multiplex/obfs.go +++ b/internal/multiplex/obfs.go @@ -8,9 +8,11 @@ import ( "errors" "golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/salsa20" + + prand "math/rand" ) -type Obfser func(*Frame) ([]byte, error) +type Obfser func(*Frame, []byte) (int, error) type Deobfser func([]byte) (*Frame, error) var u32 = binary.BigEndian.Uint32 @@ -19,27 +21,37 @@ var putU32 = binary.BigEndian.PutUint32 const HEADER_LEN = 12 func MakeObfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Obfser { - var tagLen int - if payloadCipher == nil { - tagLen = 8 //nonce - } else { - tagLen = payloadCipher.Overhead() - } - obfs := func(f *Frame) ([]byte, error) { - ret := make([]byte, 5+HEADER_LEN+len(f.Payload)+tagLen) - recordLayer := ret[0:5] - header := ret[5 : 5+HEADER_LEN] - encryptedPayload := ret[5+HEADER_LEN:] + obfs := func(f *Frame, buf []byte) (int, error) { + var extraLen uint8 + if payloadCipher == nil { + if len(f.Payload) < 8 { + extraLen = uint8(8 - len(f.Payload)) + } + } else { + extraLen = uint8(payloadCipher.Overhead()) + } - // header: [StreamID 4 bytes][Seq 4 bytes][Closing 1 byte][random 3 bytes] + usefulLen := 5 + HEADER_LEN + len(f.Payload) + int(extraLen) + if len(buf) < usefulLen { + return 0, errors.New("buffer is too small") + } + used := buf[:usefulLen] + recordLayer := used[0:5] + header := used[5 : 5+HEADER_LEN] + encryptedPayload := used[5+HEADER_LEN:] + + // header: [StreamID 4 bytes][Seq 4 bytes][Closing 1 byte][extraLen 1 bytes][random 2 bytes] putU32(header[0:4], f.StreamID) putU32(header[4:8], f.Seq) header[8] = f.Closing - rand.Read(header[9:12]) + header[9] = extraLen + prand.Read(header[10:12]) if payloadCipher == nil { copy(encryptedPayload, f.Payload) - rand.Read(encryptedPayload[len(encryptedPayload)-tagLen:]) + if extraLen != 0 { + rand.Read(encryptedPayload[len(encryptedPayload)-int(extraLen):]) + } } else { ciphertext := payloadCipher.Seal(nil, header, f.Payload, nil) copy(encryptedPayload, ciphertext) @@ -54,20 +66,14 @@ func MakeObfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Obfser { recordLayer[1] = 0x03 recordLayer[2] = 0x03 binary.BigEndian.PutUint16(recordLayer[3:5], uint16(HEADER_LEN+len(encryptedPayload))) - return ret, nil + return usefulLen, nil } return obfs } func MakeDeobfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Deobfser { - var tagLen int - if payloadCipher == nil { - tagLen = 8 // nonce - } else { - tagLen = payloadCipher.Overhead() - } deobfs := func(in []byte) (*Frame, error) { - if len(in) < 5+HEADER_LEN+tagLen { + if len(in) < 5+HEADER_LEN+8 { return nil, errors.New("Input cannot be shorter than 33 bytes") } peeled := in[5:] @@ -81,8 +87,9 @@ func MakeDeobfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Deobfser { streamID := u32(header[0:4]) seq := u32(header[4:8]) closing := header[8] + extraLen := header[9] - outputPayload := make([]byte, len(payload)-tagLen) + outputPayload := make([]byte, len(payload)-int(extraLen)) if payloadCipher == nil { copy(outputPayload, payload) diff --git a/internal/multiplex/obfs_test.go b/internal/multiplex/obfs_test.go index ca3344f..ff489ca 100644 --- a/internal/multiplex/obfs_test.go +++ b/internal/multiplex/obfs_test.go @@ -15,21 +15,22 @@ func TestOobfs(t *testing.T) { sessionKey := make([]byte, 32) rand.Read(sessionKey) - run := func(obfuscator *Obfuscator) { + run := func(obfuscator *Obfuscator, ct *testing.T) { + obfsBuf := make([]byte, 512) f := &Frame{} _testFrame, _ := quick.Value(reflect.TypeOf(f), rand.New(rand.NewSource(42))) testFrame := _testFrame.Interface().(*Frame) - obfsed, err := obfuscator.Obfs(testFrame) + i, err := obfuscator.Obfs(testFrame, obfsBuf) if err != nil { - t.Error("failed to obfs ", err) + ct.Error("failed to obfs ", err) } - resultFrame, err := obfuscator.Deobfs(obfsed) + resultFrame, err := obfuscator.Deobfs(obfsBuf[:i]) if err != nil { - t.Error("failed to deobfs ", err) + ct.Error("failed to deobfs ", err) } if !bytes.Equal(testFrame.Payload, resultFrame.Payload) || testFrame.StreamID != resultFrame.StreamID { - t.Error("expecting", testFrame, + ct.Error("expecting", testFrame, "got", resultFrame) } } @@ -39,21 +40,21 @@ func TestOobfs(t *testing.T) { if err != nil { t.Errorf("failed to generate obfuscator %v", err) } - run(obfuscator) + run(obfuscator, t) }) t.Run("aes-gcm", func(t *testing.T) { obfuscator, err := GenerateObfs(0x01, sessionKey) if err != nil { t.Errorf("failed to generate obfuscator %v", err) } - run(obfuscator) + run(obfuscator, t) }) t.Run("chacha20-poly1305", func(t *testing.T) { obfuscator, err := GenerateObfs(0x01, sessionKey) if err != nil { t.Errorf("failed to generate obfuscator %v", err) } - run(obfuscator) + run(obfuscator, t) }) } @@ -68,6 +69,8 @@ func BenchmarkObfs(b *testing.B) { testPayload, } + obfsBuf := make([]byte, 512) + var key [32]byte rand.Read(key[:]) b.Run("AES256GCM", func(b *testing.B) { @@ -77,7 +80,7 @@ func BenchmarkObfs(b *testing.B) { obfs := MakeObfs(key, payloadCipher) b.ResetTimer() for i := 0; i < b.N; i++ { - obfs(testFrame) + obfs(testFrame, obfsBuf) } }) b.Run("AES128GCM", func(b *testing.B) { @@ -87,14 +90,14 @@ func BenchmarkObfs(b *testing.B) { obfs := MakeObfs(key, payloadCipher) b.ResetTimer() for i := 0; i < b.N; i++ { - obfs(testFrame) + obfs(testFrame, obfsBuf) } }) b.Run("plain", func(b *testing.B) { obfs := MakeObfs(key, nil) b.ResetTimer() for i := 0; i < b.N; i++ { - obfs(testFrame) + obfs(testFrame, obfsBuf) } }) b.Run("chacha20Poly1305", func(b *testing.B) { @@ -103,7 +106,7 @@ func BenchmarkObfs(b *testing.B) { obfs := MakeObfs(key, payloadCipher) b.ResetTimer() for i := 0; i < b.N; i++ { - obfs(testFrame) + obfs(testFrame, obfsBuf) } }) } diff --git a/internal/multiplex/stream.go b/internal/multiplex/stream.go index 02c8d73..1d17af7 100644 --- a/internal/multiplex/stream.go +++ b/internal/multiplex/stream.go @@ -37,6 +37,8 @@ type Stream struct { // close(die) is used to notify different goroutines that this stream is closing closed uint32 + + obfsBuf []byte } func makeStream(id uint32, sesh *Session) *Stream { @@ -46,6 +48,7 @@ func makeStream(id uint32, sesh *Session) *Stream { sh: []*frameNode{}, newFrameCh: make(chan *Frame, 1024), sortedBuf: NewBufferedPipe(), + obfsBuf: make([]byte, 17000), } go stream.recvNewFrame() log.Tracef("stream %v opened", id) @@ -93,11 +96,11 @@ func (s *Stream) Write(in []byte) (n int, err error) { Payload: in, } - tlsRecord, err := s.session.Obfs(f) + i, err := s.session.Obfs(f, s.obfsBuf) if err != nil { - return 0, err + return i, err } - n, err = s.session.sb.send(tlsRecord) + n, err = s.session.sb.send(s.obfsBuf[:i]) return } @@ -136,8 +139,14 @@ func (s *Stream) Close() error { Closing: 1, Payload: pad, } - tlsRecord, _ := s.session.Obfs(f) - s.session.sb.send(tlsRecord) + i, err := s.session.Obfs(f, s.obfsBuf) + if err != nil { + return err + } + _, err = s.session.sb.send(s.obfsBuf[:i]) + if err != nil { + return err + } s._close() s.session.delStream(s.id) diff --git a/internal/multiplex/stream_test.go b/internal/multiplex/stream_test.go index 928a9cf..c6a7db0 100644 --- a/internal/multiplex/stream_test.go +++ b/internal/multiplex/stream_test.go @@ -92,16 +92,18 @@ func TestStream_Read(t *testing.T) { var streamID uint32 buf := make([]byte, 10) + + obfsBuf := make([]byte, 512) t.Run("Plain read", func(t *testing.T) { f.StreamID = streamID - obfsed, _ := sesh.Obfs(f) + i, _ := sesh.Obfs(f, obfsBuf) streamID++ - ch <- obfsed + ch <- obfsBuf[:i] stream, err := sesh.Accept() if err != nil { t.Error("failed to accept stream", err) } - i, err := stream.Read(buf) + i, err = stream.Read(buf) if err != nil { t.Error("failed to read", err) } @@ -115,9 +117,9 @@ func TestStream_Read(t *testing.T) { }) t.Run("Nil buf", func(t *testing.T) { f.StreamID = streamID - obfsed, _ := sesh.Obfs(f) + i, _ := sesh.Obfs(f, obfsBuf) streamID++ - ch <- obfsed + ch <- obfsBuf[:i] stream, _ := sesh.Accept() i, err := stream.Read(nil) if i != 0 || err != nil { @@ -135,9 +137,9 @@ func TestStream_Read(t *testing.T) { }) t.Run("Read after stream close", func(t *testing.T) { f.StreamID = streamID - obfsed, _ := sesh.Obfs(f) + i, _ := sesh.Obfs(f, obfsBuf) streamID++ - ch <- obfsed + ch <- obfsBuf[:i] stream, _ := sesh.Accept() stream.Close() i, err := stream.Read(buf) @@ -159,9 +161,9 @@ func TestStream_Read(t *testing.T) { }) t.Run("Read after session close", func(t *testing.T) { f.StreamID = streamID - obfsed, _ := sesh.Obfs(f) + i, _ := sesh.Obfs(f, obfsBuf) streamID++ - ch <- obfsed + ch <- obfsBuf[:i] stream, _ := sesh.Accept() sesh.Close() i, err := stream.Read(buf)