Fix transport null pointer

pull/97/head
Andy Wang 5 years ago
parent 99fa812594
commit 39e54bae6c

@ -69,18 +69,19 @@ var ErrBadProxyMethod = errors.New("invalid proxy method")
// is authorised. It also returns a finisher callback function to be called when the caller wishes to proceed with // is authorised. It also returns a finisher callback function to be called when the caller wishes to proceed with
// the handshake // the handshake
func PrepareConnection(firstPacket []byte, sta *State, conn net.Conn) (info ClientInfo, finisher func([]byte) (net.Conn, error), err error) { func PrepareConnection(firstPacket []byte, sta *State, conn net.Conn) (info ClientInfo, finisher func([]byte) (net.Conn, error), err error) {
var transport Transport
switch firstPacket[0] { switch firstPacket[0] {
case 0x47: case 0x47:
info.Transport = WebSocket{} transport = WebSocket{}
case 0x16: case 0x16:
info.Transport = TLS{} transport = TLS{}
default: default:
err = ErrUnreconisedProtocol err = ErrUnreconisedProtocol
return return
} }
var ai authenticationInfo var ai authenticationInfo
ai, finisher, err = info.Transport.handshake(firstPacket, sta.staticPv, conn) ai, finisher, err = transport.handshake(firstPacket, sta.staticPv, conn)
if err != nil { if err != nil {
return return
@ -101,6 +102,6 @@ func PrepareConnection(firstPacket []byte, sta *State, conn net.Conn) (info Clie
err = ErrBadProxyMethod err = ErrBadProxyMethod
return return
} }
info.Transport = transport
return return
} }

@ -3,6 +3,7 @@ package server
import ( import (
"crypto" "crypto"
"encoding/hex" "encoding/hex"
"fmt"
"github.com/cbeuw/Cloak/internal/ecdh" "github.com/cbeuw/Cloak/internal/ecdh"
"testing" "testing"
"time" "time"
@ -123,6 +124,10 @@ func TestPrepareConnection(t *testing.T) {
t.Error("failed to get correct session id") t.Error("failed to get correct session id")
return return
} }
if info.Transport.(fmt.Stringer).String() != "TLS" {
t.Errorf("wrong transport: %v", info.Transport)
return
}
}) })
t.Run("TLS correct but replay", func(t *testing.T) { t.Run("TLS correct but replay", func(t *testing.T) {
sta := getNewState() sta := getNewState()

Loading…
Cancel
Save