Move the server keypair generation to right after Accept().

Instead of threading the code, move the keypair generation to right
after Accept() is called.  This should mask the timing differential due
to the rejection sampling with the noise from the variablity in how
long it takes for the server to get around to pulling a connection out
of the backlog, and the time taken for the client to send it's portion
of the handshake.

The downside is that anyone connecting to the obfs4 port does force us
to do a bunch of math, but the obfs4 math is relatively cheap compared
to it's precursors.

Fixes #9.
merge-requests/3/head
Yawning Angel 10 years ago
parent 2001f0b698
commit 36228437c4

@ -121,7 +121,7 @@ type clientHandshake struct {
serverMark []byte serverMark []byte
} }
func newClientHandshake(nodeID *ntor.NodeID, serverIdentity *ntor.PublicKey, sessionKey *ntor.Keypair) (*clientHandshake, error) { func newClientHandshake(nodeID *ntor.NodeID, serverIdentity *ntor.PublicKey, sessionKey *ntor.Keypair) *clientHandshake {
hs := new(clientHandshake) hs := new(clientHandshake)
hs.keypair = sessionKey hs.keypair = sessionKey
hs.nodeID = nodeID hs.nodeID = nodeID
@ -129,7 +129,7 @@ func newClientHandshake(nodeID *ntor.NodeID, serverIdentity *ntor.PublicKey, ses
hs.padLen = csrand.IntRange(clientMinPadLength, clientMaxPadLength) hs.padLen = csrand.IntRange(clientMinPadLength, clientMaxPadLength)
hs.mac = hmac.New(sha256.New, append(hs.serverIdentity.Bytes()[:], hs.nodeID.Bytes()[:]...)) hs.mac = hmac.New(sha256.New, append(hs.serverIdentity.Bytes()[:], hs.nodeID.Bytes()[:]...))
return hs, nil return hs
} }
func (hs *clientHandshake) generateHandshake() ([]byte, error) { func (hs *clientHandshake) generateHandshake() ([]byte, error) {
@ -236,8 +236,9 @@ type serverHandshake struct {
clientMark []byte clientMark []byte
} }
func newServerHandshake(nodeID *ntor.NodeID, serverIdentity *ntor.Keypair) *serverHandshake { func newServerHandshake(nodeID *ntor.NodeID, serverIdentity *ntor.Keypair, sessionKey *ntor.Keypair) *serverHandshake {
hs := new(serverHandshake) hs := new(serverHandshake)
hs.keypair = sessionKey
hs.nodeID = nodeID hs.nodeID = nodeID
hs.serverIdentity = serverIdentity hs.serverIdentity = serverIdentity
hs.padLen = csrand.IntRange(serverMinPadLength, serverMaxPadLength) hs.padLen = csrand.IntRange(serverMinPadLength, serverMaxPadLength)
@ -312,14 +313,6 @@ func (hs *serverHandshake) parseClientHandshake(filter *replayFilter, resp []byt
return nil, ErrInvalidHandshake return nil, ErrInvalidHandshake
} }
// At this point the client knows that we exist, so do the keypair
// generation and complete our side of the handshake.
var err error
hs.keypair, err = ntor.NewKeypair(true)
if err != nil {
return nil, err
}
clientPublic := hs.clientRepresentative.ToPublic() clientPublic := hs.clientRepresentative.ToPublic()
ok, seed, auth := ntor.ServerHandshake(clientPublic, hs.keypair, ok, seed, auth := ntor.ServerHandshake(clientPublic, hs.keypair,
hs.serverIdentity, hs.nodeID) hs.serverIdentity, hs.nodeID)

@ -47,10 +47,7 @@ func TestHandshakeNtor(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("[%d:0] ntor.NewKeypair failed: %s", l, err) t.Fatalf("[%d:0] ntor.NewKeypair failed: %s", l, err)
} }
clientHs, err := newClientHandshake(nodeID, idKeypair.Public(), clientKeypair) clientHs := newClientHandshake(nodeID, idKeypair.Public(), clientKeypair)
if err != nil {
t.Fatalf("[%d:0] newClientHandshake failed: %s", l, err)
}
clientHs.padLen = l clientHs.padLen = l
// Generate what the client will send to the server. // Generate what the client will send to the server.
@ -69,7 +66,11 @@ func TestHandshakeNtor(t *testing.T) {
} }
// Generate the server state and override the pad length. // Generate the server state and override the pad length.
serverHs := newServerHandshake(nodeID, idKeypair) serverKeypair, err := ntor.NewKeypair(true)
if err != nil {
t.Fatalf("[%d:0] ntor.NewKeypair failed: %s", l, err)
}
serverHs := newServerHandshake(nodeID, idKeypair, serverKeypair)
serverHs.padLen = serverMinPadLength serverHs.padLen = serverMinPadLength
// Parse the client handshake message. // Parse the client handshake message.
@ -107,10 +108,7 @@ func TestHandshakeNtor(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("[%d:0] ntor.NewKeypair failed: %s", l, err) t.Fatalf("[%d:0] ntor.NewKeypair failed: %s", l, err)
} }
clientHs, err := newClientHandshake(nodeID, idKeypair.Public(), clientKeypair) clientHs := newClientHandshake(nodeID, idKeypair.Public(), clientKeypair)
if err != nil {
t.Fatalf("[%d:0] newClientHandshake failed: %s", l, err)
}
clientHs.padLen = clientMinPadLength clientHs.padLen = clientMinPadLength
// Generate what the client will send to the server. // Generate what the client will send to the server.
@ -123,7 +121,11 @@ func TestHandshakeNtor(t *testing.T) {
} }
// Generate the server state and override the pad length. // Generate the server state and override the pad length.
serverHs := newServerHandshake(nodeID, idKeypair) serverKeypair, err := ntor.NewKeypair(true)
if err != nil {
t.Fatalf("[%d:0] ntor.NewKeypair failed: %s", l, err)
}
serverHs := newServerHandshake(nodeID, idKeypair, serverKeypair)
serverHs.padLen = l serverHs.padLen = l
// Parse the client handshake message. // Parse the client handshake message.
@ -158,7 +160,7 @@ func TestHandshakeNtor(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("ntor.NewKeypair failed: %s", err) t.Fatalf("ntor.NewKeypair failed: %s", err)
} }
clientHs, err := newClientHandshake(nodeID, idKeypair.Public(), clientKeypair) clientHs := newClientHandshake(nodeID, idKeypair.Public(), clientKeypair)
if err != nil { if err != nil {
t.Fatalf("newClientHandshake failed: %s", err) t.Fatalf("newClientHandshake failed: %s", err)
} }
@ -168,7 +170,11 @@ func TestHandshakeNtor(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("clientHandshake.generateHandshake() (forced oversize) failed: %s", err) t.Fatalf("clientHandshake.generateHandshake() (forced oversize) failed: %s", err)
} }
serverHs := newServerHandshake(nodeID, idKeypair) serverKeypair, err := ntor.NewKeypair(true)
if err != nil {
t.Fatalf("ntor.NewKeypair failed: %s", err)
}
serverHs := newServerHandshake(nodeID, idKeypair, serverKeypair)
_, err = serverHs.parseClientHandshake(serverFilter, clientBlob) _, err = serverHs.parseClientHandshake(serverFilter, clientBlob)
if err == nil { if err == nil {
t.Fatalf("serverHandshake.parseClientHandshake() succeded (oversized)") t.Fatalf("serverHandshake.parseClientHandshake() succeded (oversized)")
@ -180,7 +186,7 @@ func TestHandshakeNtor(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("clientHandshake.generateHandshake() (forced undersize) failed: %s", err) t.Fatalf("clientHandshake.generateHandshake() (forced undersize) failed: %s", err)
} }
serverHs = newServerHandshake(nodeID, idKeypair) serverHs = newServerHandshake(nodeID, idKeypair, serverKeypair)
_, err = serverHs.parseClientHandshake(serverFilter, clientBlob) _, err = serverHs.parseClientHandshake(serverFilter, clientBlob)
if err == nil { if err == nil {
t.Fatalf("serverHandshake.parseClientHandshake() succeded (undersized)") t.Fatalf("serverHandshake.parseClientHandshake() succeded (undersized)")
@ -198,7 +204,7 @@ func TestHandshakeNtor(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("clientHandshake.generateHandshake() failed: %s", err) t.Fatalf("clientHandshake.generateHandshake() failed: %s", err)
} }
serverHs = newServerHandshake(nodeID, idKeypair) serverHs = newServerHandshake(nodeID, idKeypair, serverKeypair)
serverHs.padLen = serverMaxPadLength + inlineSeedFrameLength + 1 serverHs.padLen = serverMaxPadLength + inlineSeedFrameLength + 1
_, err = serverHs.parseClientHandshake(serverFilter, clientBlob) _, err = serverHs.parseClientHandshake(serverFilter, clientBlob)
if err != nil { if err != nil {

@ -159,7 +159,6 @@ func (c *Obfs4Conn) clientHandshake(nodeID *ntor.NodeID, publicKey *ntor.PublicK
} }
defer func() { defer func() {
// The session key is not needed past returning from this routine.
c.sessionKey = nil c.sessionKey = nil
if err != nil { if err != nil {
c.setBroken() c.setBroken()
@ -169,10 +168,7 @@ func (c *Obfs4Conn) clientHandshake(nodeID *ntor.NodeID, publicKey *ntor.PublicK
// Generate/send the client handshake. // Generate/send the client handshake.
var hs *clientHandshake var hs *clientHandshake
var blob []byte var blob []byte
hs, err = newClientHandshake(nodeID, publicKey, c.sessionKey) hs = newClientHandshake(nodeID, publicKey, c.sessionKey)
if err != nil {
return
}
blob, err = hs.generateHandshake() blob, err = hs.generateHandshake()
if err != nil { if err != nil {
return return
@ -231,12 +227,13 @@ func (c *Obfs4Conn) serverHandshake(nodeID *ntor.NodeID, keypair *ntor.Keypair)
} }
defer func() { defer func() {
c.sessionKey = nil
if err != nil { if err != nil {
c.setBroken() c.setBroken()
} }
}() }()
hs := newServerHandshake(nodeID, keypair) hs := newServerHandshake(nodeID, keypair, c.sessionKey)
err = c.conn.SetDeadline(time.Now().Add(connectionTimeout)) err = c.conn.SetDeadline(time.Now().Add(connectionTimeout))
if err != nil { if err != nil {
return return
@ -645,6 +642,17 @@ func (l *Obfs4Listener) AcceptObfs4() (*Obfs4Conn, error) {
// Allocate the obfs4 connection state. // Allocate the obfs4 connection state.
cObfs := new(Obfs4Conn) cObfs := new(Obfs4Conn)
// Generate the session keypair *before* consuming data from the peer, to
// add more noise to the keypair generation time. The idea is that jitter
// here is masked by network latency (the time it takes for a server to
// accept a socket out of the backlog should not be fixed, and the client
// needs to send the public key).
cObfs.sessionKey, err = ntor.NewKeypair(true)
if err != nil {
return nil, err
}
cObfs.conn = c cObfs.conn = c
cObfs.isServer = true cObfs.isServer = true
cObfs.listener = l cObfs.listener = l

Loading…
Cancel
Save