From 2001f0b698183b998dbf8e52f5d40a0d82aeef09 Mon Sep 17 00:00:00 2001 From: Yawning Angel Date: Sun, 1 Jun 2014 04:51:33 +0000 Subject: [PATCH] Generate client keypairs before connecting, instead of after. Part of issue #9. --- handshake_ntor.go | 9 ++------- handshake_ntor_test.go | 24 ++++++++++++++++++------ obfs4.go | 14 +++++++++++++- 3 files changed, 33 insertions(+), 14 deletions(-) diff --git a/handshake_ntor.go b/handshake_ntor.go index fc107c2..92f00dc 100644 --- a/handshake_ntor.go +++ b/handshake_ntor.go @@ -121,14 +121,9 @@ type clientHandshake struct { serverMark []byte } -func newClientHandshake(nodeID *ntor.NodeID, serverIdentity *ntor.PublicKey) (*clientHandshake, error) { - var err error - +func newClientHandshake(nodeID *ntor.NodeID, serverIdentity *ntor.PublicKey, sessionKey *ntor.Keypair) (*clientHandshake, error) { hs := new(clientHandshake) - hs.keypair, err = ntor.NewKeypair(true) - if err != nil { - return nil, err - } + hs.keypair = sessionKey hs.nodeID = nodeID hs.serverIdentity = serverIdentity hs.padLen = csrand.IntRange(clientMinPadLength, clientMaxPadLength) diff --git a/handshake_ntor_test.go b/handshake_ntor_test.go index b3e0a4d..69fb442 100644 --- a/handshake_ntor_test.go +++ b/handshake_ntor_test.go @@ -43,9 +43,13 @@ func TestHandshakeNtor(t *testing.T) { // Test client handshake padding. for l := clientMinPadLength; l <= clientMaxPadLength; l++ { // Generate the client state and override the pad length. - clientHs, err := newClientHandshake(nodeID, idKeypair.Public()) + clientKeypair, err := ntor.NewKeypair(true) if err != nil { - t.Fatalf("[%d:0] newClientHandshake failed:", l, err) + t.Fatalf("[%d:0] ntor.NewKeypair failed: %s", l, err) + } + clientHs, err := newClientHandshake(nodeID, idKeypair.Public(), clientKeypair) + if err != nil { + t.Fatalf("[%d:0] newClientHandshake failed: %s", l, err) } clientHs.padLen = l @@ -99,9 +103,13 @@ func TestHandshakeNtor(t *testing.T) { // Test server handshake padding. for l := serverMinPadLength; l <= serverMaxPadLength+inlineSeedFrameLength; l++ { // Generate the client state and override the pad length. - clientHs, err := newClientHandshake(nodeID, idKeypair.Public()) + clientKeypair, err := ntor.NewKeypair(true) + if err != nil { + t.Fatalf("[%d:0] ntor.NewKeypair failed: %s", l, err) + } + clientHs, err := newClientHandshake(nodeID, idKeypair.Public(), clientKeypair) if err != nil { - t.Fatalf("[%d:0] newClientHandshake failed:", l, err) + t.Fatalf("[%d:0] newClientHandshake failed: %s", l, err) } clientHs.padLen = clientMinPadLength @@ -146,9 +154,13 @@ func TestHandshakeNtor(t *testing.T) { } // Test oversized client padding. - clientHs, err := newClientHandshake(nodeID, idKeypair.Public()) + clientKeypair, err := ntor.NewKeypair(true) + if err != nil { + t.Fatalf("ntor.NewKeypair failed: %s", err) + } + clientHs, err := newClientHandshake(nodeID, idKeypair.Public(), clientKeypair) if err != nil { - t.Fatalf("newClientHandshake failed:", err) + t.Fatalf("newClientHandshake failed: %s", err) } clientHs.padLen = clientMaxPadLength + 1 diff --git a/obfs4.go b/obfs4.go index c780e0c..cc5e3b9 100644 --- a/obfs4.go +++ b/obfs4.go @@ -69,6 +69,8 @@ const ( type Obfs4Conn struct { conn net.Conn + sessionKey *ntor.Keypair + lenProbDist *wDist iatProbDist *wDist @@ -157,6 +159,8 @@ func (c *Obfs4Conn) clientHandshake(nodeID *ntor.NodeID, publicKey *ntor.PublicK } defer func() { + // The session key is not needed past returning from this routine. + c.sessionKey = nil if err != nil { c.setBroken() } @@ -165,7 +169,7 @@ func (c *Obfs4Conn) clientHandshake(nodeID *ntor.NodeID, publicKey *ntor.PublicK // Generate/send the client handshake. var hs *clientHandshake var blob []byte - hs, err = newClientHandshake(nodeID, publicKey) + hs, err = newClientHandshake(nodeID, publicKey, c.sessionKey) if err != nil { return } @@ -576,6 +580,14 @@ func DialObfs4DialFn(dialFn DialFn, network, address, nodeID, publicKey string, } c.iatProbDist = newWDist(iatSeed, 0, maxIatDelay) } + + // Generate the session keypair *before* connecting to the remote peer. + c.sessionKey, err = ntor.NewKeypair(true) + if err != nil { + return nil, err + } + + // Connect to the remote peer. c.conn, err = dialFn(network, address) if err != nil { return nil, err