diff --git a/internal/client/TLS.go b/internal/client/TLS.go index 62d5961..573e5e5 100644 --- a/internal/client/TLS.go +++ b/internal/client/TLS.go @@ -46,17 +46,6 @@ func addExtRec(typ []byte, data []byte) []byte { return ret } -func addRecordLayer(input []byte, typ []byte, ver []byte) []byte { - length := make([]byte, 2) - binary.BigEndian.PutUint16(length, uint16(len(input))) - ret := make([]byte, 5+len(input)) - copy(ret[0:1], typ) - copy(ret[1:3], ver) - copy(ret[3:5], length) - copy(ret[5:], input) - return ret -} - func genStegClientHello(ai authenticationPayload, serverName string) (ret clientHelloFields) { // random is marshalled ephemeral pub key 32 bytes // The authentication ciphertext and its tag are then distributed among SessionId and X25519KeyShare @@ -68,33 +57,31 @@ func genStegClientHello(ai authenticationPayload, serverName string) (ret client } type DirectTLS struct { - browser + *util.TLSConn + browser browser } -func (DirectTLS) HasRecordLayer() bool { return true } -func (DirectTLS) UnitReadFunc() func(net.Conn, []byte) (int, error) { return util.ReadTLS } - -// PrepareConnection handles the TLS handshake for a given conn and returns the sessionKey +// NewClientTransport handles the TLS handshake for a given conn and returns the sessionKey // if the server proceed with Cloak authentication -func (tls DirectTLS) PrepareConnection(authInfo *authInfo, conn net.Conn) (preparedConn net.Conn, sessionKey [32]byte, err error) { - preparedConn = conn +func (tls *DirectTLS) Handshake(rawConn net.Conn, authInfo authInfo) (sessionKey [32]byte, err error) { payload, sharedSecret := makeAuthenticationPayload(authInfo, rand.Reader, time.Now()) chOnly := tls.browser.composeClientHello(genStegClientHello(payload, authInfo.MockDomain)) - chWithRecordLayer := addRecordLayer(chOnly, []byte{0x16}, []byte{0x03, 0x01}) - _, err = preparedConn.Write(chWithRecordLayer) + chWithRecordLayer := util.AddRecordLayer(chOnly, util.Handshake, util.VersionTLS11) + _, err = rawConn.Write(chWithRecordLayer) if err != nil { return } log.Trace("client hello sent successfully") + tls.TLSConn = &util.TLSConn{Conn: rawConn} buf := make([]byte, 1024) log.Trace("waiting for ServerHello") - _, err = util.ReadTLS(preparedConn, buf) + _, err = tls.Read(buf) if err != nil { return } - encrypted := append(buf[11:43], buf[89:121]...) + encrypted := append(buf[6:38], buf[84:116]...) nonce := encrypted[0:12] ciphertextWithTag := encrypted[12:60] sessionKeySlice, err := util.AESGCMDecrypt(nonce, sharedSecret[:], ciphertextWithTag) @@ -105,12 +92,11 @@ func (tls DirectTLS) PrepareConnection(authInfo *authInfo, conn net.Conn) (prepa for i := 0; i < 2; i++ { // ChangeCipherSpec and EncryptedCert (in the format of application data) - _, err = util.ReadTLS(preparedConn, buf) + _, err = tls.Read(buf) if err != nil { return } } - - return preparedConn, sessionKey, nil + return sessionKey, nil } diff --git a/internal/client/auth.go b/internal/client/auth.go index ad75061..d0f1164 100644 --- a/internal/client/auth.go +++ b/internal/client/auth.go @@ -30,7 +30,7 @@ type authInfo struct { // makeAuthenticationPayload generates the ephemeral key pair, calculates the shared secret, and then compose and // encrypt the authenticationPayload -func makeAuthenticationPayload(authInfo *authInfo, randReader io.Reader, time time.Time) (ret authenticationPayload, sharedSecret [32]byte) { +func makeAuthenticationPayload(authInfo authInfo, randReader io.Reader, time time.Time) (ret authenticationPayload, sharedSecret [32]byte) { /* Authentication data: +----------+----------------+---------------------+-------------+--------------+--------+------------+ diff --git a/internal/client/auth_test.go b/internal/client/auth_test.go index 6fc39c7..56d0c53 100644 --- a/internal/client/auth_test.go +++ b/internal/client/auth_test.go @@ -10,14 +10,14 @@ import ( func TestMakeAuthenticationPayload(t *testing.T) { tests := []struct { - authInfo *authInfo + authInfo authInfo seed io.Reader time time.Time expPayload authenticationPayload expSecret [32]byte }{ { - &authInfo{ + authInfo{ Unordered: false, SessionId: 3421516597, UID: []byte{ diff --git a/internal/client/connector.go b/internal/client/connector.go index 2153f04..760e828 100644 --- a/internal/client/connector.go +++ b/internal/client/connector.go @@ -14,14 +14,14 @@ import ( ) type remoteConnConfig struct { - NumConn int - KeepAlive time.Duration - Protector func(string, string, syscall.RawConn) error - RemoteAddr string - Transport Transport + NumConn int + KeepAlive time.Duration + Protector func(string, string, syscall.RawConn) error + RemoteAddr string + TransportMaker func() Transport } -func MakeSession(connConfig *remoteConnConfig, authInfo *authInfo, isAdmin bool) *mux.Session { +func MakeSession(connConfig remoteConnConfig, authInfo authInfo, isAdmin bool) *mux.Session { log.Info("Attempting to start a new session") if !isAdmin { // sessionID is usergenerated. There shouldn't be a security concern because the scope of @@ -48,16 +48,17 @@ func MakeSession(connConfig *remoteConnConfig, authInfo *authInfo, isAdmin bool) time.Sleep(time.Second * 3) goto makeconn } - var sk [32]byte - remoteConn, sk, err = connConfig.Transport.PrepareConnection(authInfo, remoteConn) + + transportConn := connConfig.TransportMaker() + sk, err := transportConn.Handshake(remoteConn, authInfo) if err != nil { - remoteConn.Close() + transportConn.Close() log.Errorf("Failed to prepare connection to remote: %v", err) time.Sleep(time.Second * 3) goto makeconn } _sessionKey.Store(sk) - connsCh <- remoteConn + connsCh <- transportConn wg.Done() }() } @@ -65,7 +66,7 @@ func MakeSession(connConfig *remoteConnConfig, authInfo *authInfo, isAdmin bool) log.Debug("All underlying connections established") sessionKey := _sessionKey.Load().([32]byte) - obfuscator, err := mux.MakeObfuscator(authInfo.EncryptionMethod, sessionKey, connConfig.Transport.HasRecordLayer()) + obfuscator, err := mux.MakeObfuscator(authInfo.EncryptionMethod, sessionKey) if err != nil { log.Fatal(err) } @@ -73,7 +74,6 @@ func MakeSession(connConfig *remoteConnConfig, authInfo *authInfo, isAdmin bool) seshConfig := mux.SessionConfig{ Obfuscator: obfuscator, Valve: nil, - UnitRead: connConfig.Transport.UnitReadFunc(), Unordered: authInfo.Unordered, } sesh := mux.MakeSession(authInfo.SessionId, seshConfig) diff --git a/internal/client/piper.go b/internal/client/piper.go index 02c2973..44dfab7 100644 --- a/internal/client/piper.go +++ b/internal/client/piper.go @@ -11,7 +11,7 @@ import ( log "github.com/sirupsen/logrus" ) -func RouteUDP(localConfig *localConnConfig, newSeshFunc func() *mux.Session) { +func RouteUDP(localConfig localConnConfig, newSeshFunc func() *mux.Session) { var sesh *mux.Session localUDPAddr, err := net.ResolveUDPAddr("udp", localConfig.LocalAddr) if err != nil { @@ -100,7 +100,7 @@ start: } -func RouteTCP(localConfig *localConnConfig, newSeshFunc func() *mux.Session) { +func RouteTCP(localConfig localConnConfig, newSeshFunc func() *mux.Session) { tcpListener, err := net.Listen("tcp", localConfig.LocalAddr) if err != nil { log.Fatal(err) diff --git a/internal/client/state.go b/internal/client/state.go index b725c09..379c4e4 100644 --- a/internal/client/state.go +++ b/internal/client/state.go @@ -100,13 +100,12 @@ func ParseConfig(conf string) (raw *rawConfig, err error) { return } -func (raw *rawConfig) SplitConfigs() (local *localConnConfig, remote *remoteConnConfig, auth *authInfo, err error) { - nullErr := func(field string) (local *localConnConfig, remote *remoteConnConfig, auth *authInfo, err error) { +func (raw *rawConfig) SplitConfigs() (local localConnConfig, remote remoteConnConfig, auth authInfo, err error) { + nullErr := func(field string) (local localConnConfig, remote remoteConnConfig, auth authInfo, err error) { err = fmt.Errorf("%v cannot be empty", field) return } - auth = new(authInfo) auth.UID = raw.UID auth.Unordered = raw.UDP if raw.ServerName == "" { @@ -145,7 +144,6 @@ func (raw *rawConfig) SplitConfigs() (local *localConnConfig, remote *remoteConn return } - remote = new(remoteConnConfig) if raw.RemoteHost == "" { return nullErr("RemoteHost") } @@ -161,7 +159,11 @@ func (raw *rawConfig) SplitConfigs() (local *localConnConfig, remote *remoteConn // Transport and (if TLS mode), browser switch strings.ToLower(raw.Transport) { case "cdn": - remote.Transport = WSOverTLS{remote.RemoteAddr} + remote.TransportMaker = func() Transport { + return &WSOverTLS{ + cdnDomainPort: remote.RemoteAddr, + } + } case "direct": fallthrough default: @@ -174,7 +176,11 @@ func (raw *rawConfig) SplitConfigs() (local *localConnConfig, remote *remoteConn default: browser = &Chrome{} } - remote.Transport = DirectTLS{browser} + remote.TransportMaker = func() Transport { + return &DirectTLS{ + browser: browser, + } + } } // KeepAlive @@ -184,8 +190,6 @@ func (raw *rawConfig) SplitConfigs() (local *localConnConfig, remote *remoteConn remote.KeepAlive = remote.KeepAlive * time.Second } - local = new(localConnConfig) - if raw.LocalHost == "" { return nullErr("LocalHost") } diff --git a/internal/client/transport.go b/internal/client/transport.go index a361011..809abe2 100644 --- a/internal/client/transport.go +++ b/internal/client/transport.go @@ -3,7 +3,6 @@ package client import "net" type Transport interface { - PrepareConnection(*authInfo, net.Conn) (net.Conn, [32]byte, error) - HasRecordLayer() bool - UnitReadFunc() func(net.Conn, []byte) (int, error) + Handshake(rawConn net.Conn, authInfo authInfo) (sessionKey [32]byte, err error) + net.Conn } diff --git a/internal/client/websocket.go b/internal/client/websocket.go index 202b639..fc02c3c 100644 --- a/internal/client/websocket.go +++ b/internal/client/websocket.go @@ -16,47 +16,44 @@ import ( ) type WSOverTLS struct { + *util.WebSocketConn cdnDomainPort string } -func (WSOverTLS) HasRecordLayer() bool { return false } -func (WSOverTLS) UnitReadFunc() func(net.Conn, []byte) (int, error) { return util.ReadWebSocket } - -func (ws WSOverTLS) PrepareConnection(authInfo *authInfo, cdnConn net.Conn) (preparedConn net.Conn, sessionKey [32]byte, err error) { +func (ws *WSOverTLS) Handshake(rawConn net.Conn, authInfo authInfo) (sessionKey [32]byte, err error) { utlsConfig := &utls.Config{ ServerName: authInfo.MockDomain, InsecureSkipVerify: true, } - uconn := utls.UClient(cdnConn, utlsConfig, utls.HelloChrome_Auto) + uconn := utls.UClient(rawConn, utlsConfig, utls.HelloChrome_Auto) err = uconn.Handshake() - preparedConn = uconn if err != nil { return } u, err := url.Parse("ws://" + ws.cdnDomainPort) if err != nil { - return preparedConn, sessionKey, fmt.Errorf("failed to parse ws url: %v", err) + return sessionKey, fmt.Errorf("failed to parse ws url: %v", err) } payload, sharedSecret := makeAuthenticationPayload(authInfo, rand.Reader, time.Now()) header := http.Header{} header.Add("hidden", base64.StdEncoding.EncodeToString(append(payload.randPubKey[:], payload.ciphertextWithTag[:]...))) - c, _, err := websocket.NewClient(preparedConn, u, header, 16480, 16480) + c, _, err := websocket.NewClient(uconn, u, header, 16480, 16480) if err != nil { - return preparedConn, sessionKey, fmt.Errorf("failed to handshake: %v", err) + return sessionKey, fmt.Errorf("failed to handshake: %v", err) } - preparedConn = &util.WebSocketConn{Conn: c} + ws.WebSocketConn = &util.WebSocketConn{Conn: c} buf := make([]byte, 128) - n, err := preparedConn.Read(buf) + n, err := ws.Read(buf) if err != nil { - return preparedConn, sessionKey, fmt.Errorf("failed to read reply: %v", err) + return sessionKey, fmt.Errorf("failed to read reply: %v", err) } if n != 60 { - return preparedConn, sessionKey, errors.New("reply must be 60 bytes") + return sessionKey, errors.New("reply must be 60 bytes") } reply := buf[:60] diff --git a/internal/multiplex/obfs.go b/internal/multiplex/obfs.go index 640e15a..67c464f 100644 --- a/internal/multiplex/obfs.go +++ b/internal/multiplex/obfs.go @@ -37,11 +37,7 @@ type Obfuscator struct { SessionKey [32]byte } -func MakeObfs(salsaKey [32]byte, payloadCipher cipher.AEAD, hasRecordLayer bool) Obfser { - var rlLen int - if hasRecordLayer { - rlLen = 5 - } +func MakeObfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Obfser { obfs := func(f *Frame, buf []byte) (int, error) { // we need the encrypted data to be at least 8 bytes to be used as nonce for salsa20 stream header encryption // this will be the case if the encryption method is an AEAD cipher, however for plain, it's well possible @@ -56,15 +52,15 @@ func MakeObfs(salsaKey [32]byte, payloadCipher cipher.AEAD, hasRecordLayer bool) } // usefulLen is the amount of bytes that will be eventually sent off - usefulLen := rlLen + HEADER_LEN + len(f.Payload) + int(extraLen) + usefulLen := HEADER_LEN + len(f.Payload) + int(extraLen) if len(buf) < usefulLen { return 0, io.ErrShortBuffer } // we do as much in-place as possible to save allocation - useful := buf[:usefulLen] // (tls header) + payload + potential overhead - header := useful[rlLen : rlLen+HEADER_LEN] - encryptedPayloadWithExtra := useful[rlLen+HEADER_LEN:] + useful := buf[:usefulLen] // stream header + payload + potential overhead + header := useful[:HEADER_LEN] + encryptedPayloadWithExtra := useful[HEADER_LEN:] putU32(header[0:4], f.StreamID) putU64(header[4:12], f.Seq) @@ -84,39 +80,23 @@ func MakeObfs(salsaKey [32]byte, payloadCipher cipher.AEAD, hasRecordLayer bool) nonce := encryptedPayloadWithExtra[len(encryptedPayloadWithExtra)-8:] salsa20.XORKeyStream(header, header, nonce, &salsaKey) - if hasRecordLayer { - recordLayer := useful[0:5] - // We don't use util.AddRecordLayer here to avoid unnecessary malloc - recordLayer[0] = 0x17 - recordLayer[1] = 0x03 - recordLayer[2] = 0x03 - binary.BigEndian.PutUint16(recordLayer[3:5], uint16(HEADER_LEN+len(encryptedPayloadWithExtra))) - } - // Composing final obfsed message return usefulLen, nil } return obfs } -func MakeDeobfs(salsaKey [32]byte, payloadCipher cipher.AEAD, hasRecordLayer bool) Deobfser { - var rlLen int - if hasRecordLayer { - rlLen = 5 - } - // record layer length + stream header length + minimum data size (i.e. nonce size of salsa20) - minInputLen := rlLen + HEADER_LEN + 8 +func MakeDeobfs(salsaKey [32]byte, payloadCipher cipher.AEAD) Deobfser { + // stream header length + minimum data size (i.e. nonce size of salsa20) + minInputLen := HEADER_LEN + 8 deobfs := func(in []byte) (*Frame, error) { if len(in) < minInputLen { return nil, fmt.Errorf("input size %v, but it cannot be shorter than %v bytes", len(in), minInputLen) } - peeled := make([]byte, len(in)-rlLen) - copy(peeled, in[rlLen:]) - - header := peeled[:HEADER_LEN] - pldWithOverHead := peeled[HEADER_LEN:] // payload + potential overhead + header := in[:HEADER_LEN] + pldWithOverHead := in[HEADER_LEN:] // payload + potential overhead - nonce := peeled[len(peeled)-8:] + nonce := in[len(in)-8:] salsa20.XORKeyStream(header, header, nonce, &salsaKey) streamID := u32(header[0:4]) @@ -156,7 +136,7 @@ func MakeDeobfs(salsaKey [32]byte, payloadCipher cipher.AEAD, hasRecordLayer boo return deobfs } -func MakeObfuscator(encryptionMethod byte, sessionKey [32]byte, hasRecordLayer bool) (obfuscator *Obfuscator, err error) { +func MakeObfuscator(encryptionMethod byte, sessionKey [32]byte) (obfuscator *Obfuscator, err error) { var payloadCipher cipher.AEAD switch encryptionMethod { case E_METHOD_PLAIN: @@ -181,8 +161,8 @@ func MakeObfuscator(encryptionMethod byte, sessionKey [32]byte, hasRecordLayer b } obfuscator = &Obfuscator{ - MakeObfs(sessionKey, payloadCipher, hasRecordLayer), - MakeDeobfs(sessionKey, payloadCipher, hasRecordLayer), + MakeObfs(sessionKey, payloadCipher), + MakeDeobfs(sessionKey, payloadCipher), sessionKey, } return diff --git a/internal/multiplex/session.go b/internal/multiplex/session.go index 2c8a141..cfc12f6 100644 --- a/internal/multiplex/session.go +++ b/internal/multiplex/session.go @@ -27,9 +27,6 @@ type SessionConfig struct { Valve - // This is supposed to read one TLS message. - UnitRead func(net.Conn, []byte) (int, error) - Unordered bool SendBufferSize int diff --git a/internal/multiplex/switchboard.go b/internal/multiplex/switchboard.go index fcff9f3..f09b123 100644 --- a/internal/multiplex/switchboard.go +++ b/internal/multiplex/switchboard.go @@ -156,7 +156,7 @@ func (sb *switchboard) deplex(connId uint32, conn net.Conn) { defer conn.Close() buf := make([]byte, sb.recvBufferSize) for { - n, err := sb.session.UnitRead(conn, buf) + n, err := conn.Read(buf) sb.valve.rxWait(n) sb.valve.AddRx(int64(n)) if err != nil { diff --git a/internal/util/tls.go b/internal/util/tls.go new file mode 100644 index 0000000..b440941 --- /dev/null +++ b/internal/util/tls.go @@ -0,0 +1,79 @@ +package util + +import ( + "encoding/binary" + "io" + "net" + "time" +) + +const ( + VersionTLS11 = 0x0301 + VersionTLS13 = 0x0303 + + Handshake = 22 + ApplicationData = 23 +) + +func AddRecordLayer(input []byte, typ byte, ver uint16) []byte { + length := make([]byte, 2) + binary.BigEndian.PutUint16(length, uint16(len(input))) + ret := make([]byte, 5+len(input)) + ret[0] = typ + binary.BigEndian.PutUint16(ret[1:3], ver) + copy(ret[3:5], length) + copy(ret[5:], input) + return ret +} + +type TLSConn struct { + net.Conn +} + +func (tls *TLSConn) LocalAddr() net.Addr { + return tls.Conn.LocalAddr() +} + +func (tls *TLSConn) RemoteAddr() net.Addr { + return tls.Conn.RemoteAddr() +} + +func (tls *TLSConn) SetDeadline(t time.Time) error { + return tls.Conn.SetDeadline(t) +} + +func (tls *TLSConn) SetReadDeadline(t time.Time) error { + return tls.Conn.SetReadDeadline(t) +} + +func (tls *TLSConn) SetWriteDeadline(t time.Time) error { + return tls.Conn.SetWriteDeadline(t) +} + +func (tls *TLSConn) Read(buffer []byte) (n int, err error) { + // TCP is a stream. Multiple TLS messages can arrive at the same time, + // a single message can also be segmented due to MTU of the IP layer. + // This function guareentees a single TLS message to be read and everything + // else is left in the buffer. + _, err = io.ReadFull(tls.Conn, buffer[:5]) + if err != nil { + return + } + + dataLength := int(binary.BigEndian.Uint16(buffer[3:5])) + if dataLength > len(buffer) { + err = io.ErrShortBuffer + return + } + return io.ReadFull(tls.Conn, buffer[:dataLength]) +} + +func (tls *TLSConn) Write(in []byte) (n int, err error) { + // TODO: write record layer directly first? + toWrite := AddRecordLayer(in, ApplicationData, VersionTLS13) + return tls.Conn.Write(toWrite) +} + +func (tls *TLSConn) Close() error { + return tls.Conn.Close() +} diff --git a/internal/util/util.go b/internal/util/util.go index 04ff551..1c6c54e 100644 --- a/internal/util/util.go +++ b/internal/util/util.go @@ -4,7 +4,6 @@ import ( "crypto/aes" "crypto/cipher" "crypto/rand" - "encoding/binary" "io" "net" "time" @@ -60,24 +59,8 @@ func CryptoRandRead(buf []byte) { } // ReadTLS reads TLS data according to its record layer -func ReadTLS(conn net.Conn, buffer []byte) (n int, err error) { - // TCP is a stream. Multiple TLS messages can arrive at the same time, - // a single message can also be segmented due to MTU of the IP layer. - // This function guareentees a single TLS message to be read and everything - // else is left in the buffer. - _, err = io.ReadFull(conn, buffer[:5]) - if err != nil { - return - } - - dataLength := int(binary.BigEndian.Uint16(buffer[3:5])) - if dataLength > len(buffer) { - err = io.ErrShortBuffer - return - } - n, err = io.ReadFull(conn, buffer[5:dataLength+5]) - return n + 5, err -} +//func ReadTLS(conn net.Conn, buffer []byte) (n int, err error) { +//} func Pipe(dst net.Conn, src net.Conn, srcReadTimeout time.Duration) { // The maximum size of TLS message will be 16380+14+16. 14 because of the stream header and 16