Generate client keypairs before connecting, instead of after.

Part of issue #9.
merge-requests/3/head
Yawning Angel 10 years ago
parent 697b51b4bd
commit 2001f0b698

@ -121,14 +121,9 @@ type clientHandshake struct {
serverMark []byte
}
func newClientHandshake(nodeID *ntor.NodeID, serverIdentity *ntor.PublicKey) (*clientHandshake, error) {
var err error
func newClientHandshake(nodeID *ntor.NodeID, serverIdentity *ntor.PublicKey, sessionKey *ntor.Keypair) (*clientHandshake, error) {
hs := new(clientHandshake)
hs.keypair, err = ntor.NewKeypair(true)
if err != nil {
return nil, err
}
hs.keypair = sessionKey
hs.nodeID = nodeID
hs.serverIdentity = serverIdentity
hs.padLen = csrand.IntRange(clientMinPadLength, clientMaxPadLength)

@ -43,9 +43,13 @@ func TestHandshakeNtor(t *testing.T) {
// Test client handshake padding.
for l := clientMinPadLength; l <= clientMaxPadLength; l++ {
// Generate the client state and override the pad length.
clientHs, err := newClientHandshake(nodeID, idKeypair.Public())
clientKeypair, err := ntor.NewKeypair(true)
if err != nil {
t.Fatalf("[%d:0] newClientHandshake failed:", l, err)
t.Fatalf("[%d:0] ntor.NewKeypair failed: %s", l, err)
}
clientHs, err := newClientHandshake(nodeID, idKeypair.Public(), clientKeypair)
if err != nil {
t.Fatalf("[%d:0] newClientHandshake failed: %s", l, err)
}
clientHs.padLen = l
@ -99,9 +103,13 @@ func TestHandshakeNtor(t *testing.T) {
// Test server handshake padding.
for l := serverMinPadLength; l <= serverMaxPadLength+inlineSeedFrameLength; l++ {
// Generate the client state and override the pad length.
clientHs, err := newClientHandshake(nodeID, idKeypair.Public())
clientKeypair, err := ntor.NewKeypair(true)
if err != nil {
t.Fatalf("[%d:0] ntor.NewKeypair failed: %s", l, err)
}
clientHs, err := newClientHandshake(nodeID, idKeypair.Public(), clientKeypair)
if err != nil {
t.Fatalf("[%d:0] newClientHandshake failed:", l, err)
t.Fatalf("[%d:0] newClientHandshake failed: %s", l, err)
}
clientHs.padLen = clientMinPadLength
@ -146,9 +154,13 @@ func TestHandshakeNtor(t *testing.T) {
}
// Test oversized client padding.
clientHs, err := newClientHandshake(nodeID, idKeypair.Public())
clientKeypair, err := ntor.NewKeypair(true)
if err != nil {
t.Fatalf("ntor.NewKeypair failed: %s", err)
}
clientHs, err := newClientHandshake(nodeID, idKeypair.Public(), clientKeypair)
if err != nil {
t.Fatalf("newClientHandshake failed:", err)
t.Fatalf("newClientHandshake failed: %s", err)
}
clientHs.padLen = clientMaxPadLength + 1

@ -69,6 +69,8 @@ const (
type Obfs4Conn struct {
conn net.Conn
sessionKey *ntor.Keypair
lenProbDist *wDist
iatProbDist *wDist
@ -157,6 +159,8 @@ func (c *Obfs4Conn) clientHandshake(nodeID *ntor.NodeID, publicKey *ntor.PublicK
}
defer func() {
// The session key is not needed past returning from this routine.
c.sessionKey = nil
if err != nil {
c.setBroken()
}
@ -165,7 +169,7 @@ func (c *Obfs4Conn) clientHandshake(nodeID *ntor.NodeID, publicKey *ntor.PublicK
// Generate/send the client handshake.
var hs *clientHandshake
var blob []byte
hs, err = newClientHandshake(nodeID, publicKey)
hs, err = newClientHandshake(nodeID, publicKey, c.sessionKey)
if err != nil {
return
}
@ -576,6 +580,14 @@ func DialObfs4DialFn(dialFn DialFn, network, address, nodeID, publicKey string,
}
c.iatProbDist = newWDist(iatSeed, 0, maxIatDelay)
}
// Generate the session keypair *before* connecting to the remote peer.
c.sessionKey, err = ntor.NewKeypair(true)
if err != nil {
return nil, err
}
// Connect to the remote peer.
c.conn, err = dialFn(network, address)
if err != nil {
return nil, err

Loading…
Cancel
Save