obfs4/handshake_ntor.go

387 lines
11 KiB
Go
Raw Normal View History

2014-05-09 10:23:58 +00:00
/*
* Copyright (c) 2014, Yawning Angel <yawning at torproject dot org>
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* * Redistributions of source code must retain the above copyright notice,
* this list of conditions and the following disclaimer.
*
* * Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
* POSSIBILITY OF SUCH DAMAGE.
*/
package obfs4
import (
"bytes"
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"hash"
"math/big"
"strconv"
"time"
"github.com/yawning/obfs4/framing"
"github.com/yawning/obfs4/ntor"
)
// Size limits for the client and server handshake messages. Everything is
// derived from the framing layer's maximum segment length and the sizes of
// the fixed fields that make up each message.
const (
	// The mark and the MAC are both HMAC-SHA256 outputs.
	markLength = sha256.Size
	macLength  = sha256.Size

	// Client handshake: representative, padding, mark, MAC.
	clientMinHandshakeLength = ntor.RepresentativeLength + markLength + macLength
	clientMaxHandshakeLength = framing.MaximumSegmentLength

	// Server handshake: representative, AUTH, padding, mark, MAC.
	serverMinHandshakeLength = ntor.RepresentativeLength + ntor.AuthLength +
		markLength + macLength
	serverMaxHandshakeLength = framing.MaximumSegmentLength

	// Client padding bounds. The minimum makes a client handshake at least as
	// long as the shortest possible server handshake.
	clientMinPadLength = serverMinHandshakeLength - clientMinHandshakeLength
	clientMaxPadLength = framing.MaximumSegmentLength - clientMinHandshakeLength

	// Server padding bounds.
	serverMinPadLength = 0
	serverMaxPadLength = framing.MaximumSegmentLength - serverMinHandshakeLength
)
// Sentinel errors returned by the handshake parsing routines.
var (
	// ErrMarkNotFoundYet is returned when the data received so far does not
	// contain the peer's mark and is still shorter than the maximum
	// handshake length.
	ErrMarkNotFoundYet = errors.New("handshake: M_[C,S] not found yet")

	// ErrInvalidHandshake is returned when the received data cannot be a
	// valid handshake, e.g. the mark is absent from a maximum-length buffer.
	ErrInvalidHandshake = errors.New("handshake: Failed to find M_[C,S]")

	// ErrNtorFailed is returned when the underlying ntor handshake reports
	// failure.
	ErrNtorFailed = errors.New("handshake: ntor handshake failure")
)
// InvalidMacError is returned when the MAC carried in a handshake message
// does not match the MAC derived locally over the same data.
type InvalidMacError struct {
	// Derived is the MAC computed locally.
	Derived []byte
	// Received is the MAC extracted from the peer's handshake message.
	Received []byte
}

// Error implements the error interface, rendering both MACs as hex.
func (e *InvalidMacError) Error() string {
	return fmt.Sprintf("handshake: MAC mismatch: Derived: %s Received: %s.",
		hex.EncodeToString(e.Derived), hex.EncodeToString(e.Received))
}
// InvalidAuthError is returned when the ntor AUTH value received from the
// peer does not match the AUTH value derived locally.
type InvalidAuthError struct {
	// Derived is the AUTH value computed locally.
	Derived *ntor.Auth
	// Received is the AUTH value extracted from the peer's handshake message.
	Received *ntor.Auth
}

// Error implements the error interface, rendering both AUTH values as hex.
// The layout matches InvalidMacError ("Derived: <hex> Received: <hex>.").
func (e *InvalidAuthError) Error() string {
	return fmt.Sprintf("handshake: ntor AUTH mismatch: Derived: %s Received: %s.",
		hex.EncodeToString(e.Derived.Bytes()[:]),
		hex.EncodeToString(e.Received.Bytes()[:]))
}
// clientHandshake holds the client-side state for a single handshake
// attempt: the inputs needed to build the request and the values extracted
// from the server's response while it is being parsed.
type clientHandshake struct {
// keypair is generated by newClientHandshake; its representative is sent as
// the first field of the client handshake.
keypair *ntor.Keypair
// nodeID and serverIdentity are supplied by the caller and passed on to
// ntor.ClientHandshake.
nodeID *ntor.NodeID
serverIdentity *ntor.PublicKey
// epochHour is the decimal string of the epoch hour captured when the client
// handshake was generated; it is mixed into the MACs.
epochHour []byte
// mac is an HMAC-SHA256 instance keyed with the bytes of serverIdentity.
mac hash.Hash
// serverRepresentative and serverAuth are copied from the start of the
// server's response the first time it is long enough to parse.
serverRepresentative *ntor.Representative
serverAuth *ntor.Auth
// serverMark is the HMAC of serverRepresentative, searched for in the
// response to locate the end of the padding.
serverMark []byte
}
// newClientHandshake creates the client-side handshake state for a server
// identified by nodeID and serverIdentity. A fresh keypair is generated for
// the attempt, and the MAC is keyed with the server's identity public key.
// An error is returned only if keypair generation fails.
func newClientHandshake(nodeID *ntor.NodeID, serverIdentity *ntor.PublicKey) (*clientHandshake, error) {
	// NOTE(review): the `true` argument presumably requests a keypair that
	// has a representative; confirm against ntor.NewKeypair.
	kp, err := ntor.NewKeypair(true)
	if err != nil {
		return nil, err
	}

	return &clientHandshake{
		keypair:        kp,
		nodeID:         nodeID,
		serverIdentity: serverIdentity,
		mac:            hmac.New(sha256.New, serverIdentity.Bytes()[:]),
	}, nil
}
// generateHandshake builds the client handshake message:
//
//	X | P_C | M_C | MAC(X | P_C | M_C | E)
//
// where:
//   - X is the client's ephemeral public key representative.
//   - P_C is [clientMinPadLength, clientMaxPadLength] bytes of random padding.
//   - M_C is HMAC-SHA256(serverIdentity, X).
//   - MAC is HMAC-SHA256(serverIdentity, X .... E).
//   - E is the decimal string of the number of hours since the UNIX epoch.
//
// The value of E used is remembered in hs.epochHour. An error is returned
// only if the padding cannot be generated.
func (hs *clientHandshake) generateHandshake() ([]byte, error) {
	// Derive the mark M_C from the representative.
	hs.mac.Reset()
	hs.mac.Write(hs.keypair.Representative().Bytes()[:])
	clientMark := hs.mac.Sum(nil)

	padding, err := makePad(clientMinPadLength, clientMaxPadLength)
	if err != nil {
		return nil, err
	}

	// Assemble X | P_C | M_C.
	var msg bytes.Buffer
	fields := [][]byte{
		hs.keypair.Representative().Bytes()[:],
		padding,
		clientMark,
	}
	for _, field := range fields {
		msg.Write(field)
	}

	// Append the MAC, computed over everything written so far followed by E.
	hs.epochHour = strconv.AppendInt(nil, getEpochHour(), 10)
	hs.mac.Reset()
	hs.mac.Write(msg.Bytes())
	hs.mac.Write(hs.epochHour)
	tag := hs.mac.Sum(nil)
	msg.Write(tag)

	return msg.Bytes(), nil
}
// parseServerHandshake attempts to parse and validate the server's handshake
// response. resp holds every byte received from the server so far; the
// representative, AUTH value and mark are cached on hs after the first call
// that sees enough data, so the method can be called again as resp grows.
//
// On success it returns the number of bytes of resp occupied by the handshake
// (up to and including the MAC) and the seed bytes produced by
// ntor.ClientHandshake.
//
// Errors:
//   - ErrMarkNotFoundYet: the mark and its trailing MAC have not been fully
//     received yet and resp is still shorter than serverMaxHandshakeLength.
//   - ErrInvalidHandshake: no mark followed by a complete MAC exists within
//     the first serverMaxHandshakeLength bytes.
//   - *InvalidMacError: the MAC does not match.
//   - ErrNtorFailed / *InvalidAuthError: the ntor handshake failed or the
//     AUTH value does not match.
func (hs *clientHandshake) parseServerHandshake(resp []byte) (int, []byte, error) {
	// No point in examining the data unless the minimum plausible response
	// has been received.
	if len(resp) < serverMinHandshakeLength {
		return 0, nil, ErrMarkNotFoundYet
	}

	if hs.serverRepresentative == nil || hs.serverAuth == nil {
		// Pull out the representative/AUTH. (XXX: Add ctors to ntor)
		hs.serverRepresentative = new(ntor.Representative)
		copy(hs.serverRepresentative.Bytes()[:], resp[0:ntor.RepresentativeLength])
		hs.serverAuth = new(ntor.Auth)
		// copy stops at the shorter of the two slices, so only the AUTH
		// bytes are taken from the remainder of resp.
		copy(hs.serverAuth.Bytes()[:], resp[ntor.RepresentativeLength:])

		// Derive the mark.
		hs.mac.Reset()
		hs.mac.Write(hs.serverRepresentative.Bytes()[:])
		hs.serverMark = hs.mac.Sum(nil)
	}

	// The mark and the MAC that follows it must both lie within the data
	// received so far and within the maximum handshake length.
	limit := len(resp)
	if limit > serverMaxHandshakeLength {
		limit = serverMaxHandshakeLength
	}

	// Attempt to find the mark + MAC. A mark whose MAC has not completely
	// arrived is treated the same as a mark that has not been found; slicing
	// the MAC out of resp at that point would run past the end of the buffer.
	pos := findMark(hs.serverMark, resp,
		ntor.RepresentativeLength+ntor.AuthLength, serverMaxHandshakeLength)
	if pos == -1 || pos+markLength+macLength > limit {
		if len(resp) >= serverMaxHandshakeLength {
			return 0, nil, ErrInvalidHandshake
		}
		return 0, nil, ErrMarkNotFoundYet
	}
	macStart := pos + markLength
	macEnd := macStart + macLength

	// Validate the MAC, which covers everything up to and including the mark
	// followed by the epoch hour used when the client handshake was built.
	hs.mac.Reset()
	hs.mac.Write(resp[:macStart])
	hs.mac.Write(hs.epochHour)
	macCmp := hs.mac.Sum(nil)
	macRx := resp[macStart:macEnd]
	if !hmac.Equal(macCmp, macRx) {
		return 0, nil, &InvalidMacError{macCmp, macRx}
	}

	// Complete the handshake.
	serverPublic := hs.serverRepresentative.ToPublic()
	ok, seed, auth := ntor.ClientHandshake(hs.keypair, serverPublic,
		hs.serverIdentity, hs.nodeID)
	if !ok {
		return 0, nil, ErrNtorFailed
	}
	if !ntor.CompareAuth(auth, hs.serverAuth.Bytes()[:]) {
		return 0, nil, &InvalidAuthError{auth, hs.serverAuth}
	}

	return macEnd, seed.Bytes()[:], nil
}
// serverHandshake holds the server-side state for a single handshake
// attempt: the server's identity, the values extracted from the client's
// request, and the values needed to build the response.
type serverHandshake struct {
// keypair is generated in parseClientHandshake once the client's MAC has
// been validated; its representative is sent in the server handshake.
keypair *ntor.Keypair
// nodeID and serverIdentity are supplied by the caller and passed on to
// ntor.ServerHandshake.
nodeID *ntor.NodeID
serverIdentity *ntor.Keypair
// epochHour is the epoch-hour string mixed into the handshake MACs;
// parseClientHandshake stores the value that matched the client's MAC.
epochHour []byte
// serverAuth is the AUTH value returned by ntor.ServerHandshake.
serverAuth *ntor.Auth
// mac is an HMAC-SHA256 instance keyed with the bytes of serverIdentity's
// public key.
mac hash.Hash
// clientRepresentative is copied from the start of the client's request the
// first time it is long enough to parse.
clientRepresentative *ntor.Representative
// clientMark is the HMAC of clientRepresentative, searched for in the
// request to locate the end of the padding.
clientMark []byte
}
// newServerHandshake creates the server-side handshake state for the given
// node ID and identity keypair. The MAC is keyed with the identity's public
// key; the session keypair is not generated here.
func newServerHandshake(nodeID *ntor.NodeID, serverIdentity *ntor.Keypair) *serverHandshake {
	macKey := serverIdentity.Public().Bytes()[:]
	return &serverHandshake{
		nodeID:         nodeID,
		serverIdentity: serverIdentity,
		mac:            hmac.New(sha256.New, macKey),
	}
}
// parseClientHandshake attempts to parse and validate a client handshake
// request. resp holds every byte received from the client so far; the
// representative and mark are cached on hs after the first call that sees
// enough data, so the method can be called again as resp grows.
//
// On success it generates the server's session keypair, stores the AUTH value
// in hs.serverAuth and the matching epoch hour in hs.epochHour, and returns
// the seed bytes produced by ntor.ServerHandshake.
//
// Errors:
//   - ErrMarkNotFoundYet: the mark and its trailing MAC have not been fully
//     received yet and resp is still shorter than clientMaxHandshakeLength.
//   - ErrInvalidHandshake: no mark followed by a complete MAC exists within
//     the first clientMaxHandshakeLength bytes, the MAC matches none of the
//     accepted epoch hours, or data follows the MAC.
//   - ErrNtorFailed: the ntor handshake failed.
//   - any error returned by ntor.NewKeypair.
func (hs *serverHandshake) parseClientHandshake(resp []byte) ([]byte, error) {
	// No point in examining the data unless the minimum plausible request
	// has been received.
	if len(resp) < clientMinHandshakeLength {
		return nil, ErrMarkNotFoundYet
	}

	if hs.clientRepresentative == nil {
		// Pull out the representative. (XXX: Add ctors to ntor)
		hs.clientRepresentative = new(ntor.Representative)
		copy(hs.clientRepresentative.Bytes()[:], resp[0:ntor.RepresentativeLength])

		// Derive the mark.
		hs.mac.Reset()
		hs.mac.Write(hs.clientRepresentative.Bytes()[:])
		hs.clientMark = hs.mac.Sum(nil)
	}

	// The mark and the MAC that follows it must both lie within the data
	// received so far and within the maximum handshake length.
	limit := len(resp)
	if limit > clientMaxHandshakeLength {
		limit = clientMaxHandshakeLength
	}

	// Attempt to find the mark + MAC. A mark whose MAC has not completely
	// arrived is treated the same as a mark that has not been found; slicing
	// the MAC out of resp at that point would run past the end of the buffer.
	pos := findMark(hs.clientMark, resp, ntor.RepresentativeLength,
		clientMaxHandshakeLength)
	if pos == -1 || pos+markLength+macLength > limit {
		if len(resp) >= clientMaxHandshakeLength {
			return nil, ErrInvalidHandshake
		}
		return nil, ErrMarkNotFoundYet
	}
	macStart := pos + markLength
	macEnd := macStart + macLength
	macRx := resp[macStart:macEnd]

	// Validate the MAC. The current hour is sampled once so that all three
	// candidates are relative to the same instant.
	macFound := false
	currentHour := getEpochHour()
	for _, off := range []int64{0, -1, 1} {
		// Allow epoch to be off by up to an hour in either direction.
		epochHour := []byte(strconv.FormatInt(currentHour+off, 10))
		hs.mac.Reset()
		hs.mac.Write(resp[:macStart])
		hs.mac.Write(epochHour)
		macCmp := hs.mac.Sum(nil)
		if hmac.Equal(macCmp, macRx) {
			macFound = true
			hs.epochHour = epochHour

			// In theory, we should always evaluate all 3 MACs, but at this
			// point we are reasonably confident that the client knows the
			// correct NodeID/Public key, and if this fails, we just ignore the
			// client for a random interval and drop the connection anyway.
			break
		}
	}
	if !macFound {
		// This probably should be an InvalidMacError, but conveying the 3 MACs
		// that would be accepted is annoying so just return a generic fatal
		// failure.
		return nil, ErrInvalidHandshake
	}

	// Client should never send trailing garbage.
	if len(resp) != macEnd {
		return nil, ErrInvalidHandshake
	}

	// At this point the client knows that we exist, so do the keypair
	// generation and complete our side of the handshake.
	var err error
	hs.keypair, err = ntor.NewKeypair(true)
	if err != nil {
		return nil, err
	}

	clientPublic := hs.clientRepresentative.ToPublic()
	ok, seed, auth := ntor.ServerHandshake(clientPublic, hs.keypair,
		hs.serverIdentity, hs.nodeID)
	if !ok {
		return nil, ErrNtorFailed
	}
	hs.serverAuth = auth

	return seed.Bytes()[:], nil
}
// generateHandshake builds the server handshake message:
//
//	Y | AUTH | P_S | M_S | MAC(Y | AUTH | P_S | M_S | E)
//
// where:
//   - Y is the server's ephemeral public key representative.
//   - AUTH is the ntor handshake AUTH value.
//   - P_S is [serverMinPadLength, serverMaxPadLength] bytes of random padding.
//   - M_S is HMAC-SHA256(serverIdentity, Y).
//   - MAC is HMAC-SHA256(serverIdentity, Y .... E).
//   - E is the epoch-hour string that matched the client's MAC.
//
// It must be called after a successful parseClientHandshake, which sets
// hs.keypair, hs.serverAuth and hs.epochHour. An error is returned only if
// the padding cannot be generated.
func (hs *serverHandshake) generateHandshake() ([]byte, error) {
	var buf bytes.Buffer

	// Derive the mark M_S from the representative.
	hs.mac.Reset()
	hs.mac.Write(hs.keypair.Representative().Bytes()[:])
	mark := hs.mac.Sum(nil)

	// Generate the padding.
	pad, err := makePad(serverMinPadLength, serverMaxPadLength)
	if err != nil {
		return nil, err
	}

	// Write Y, AUTH, P_S, M_S.
	buf.Write(hs.keypair.Representative().Bytes()[:])
	buf.Write(hs.serverAuth.Bytes()[:])
	buf.Write(pad)
	buf.Write(mark)

	// The client validates this MAC with the epoch hour it used in its own
	// handshake, and parseClientHandshake accepts client values up to an hour
	// either side of the server's clock. Reuse the hour that matched rather
	// than the server's current hour, otherwise a client in an adjacent hour
	// would always see a MAC mismatch. Only fall back to the local clock if
	// no value has been recorded.
	if len(hs.epochHour) == 0 {
		hs.epochHour = []byte(strconv.FormatInt(getEpochHour(), 10))
	}

	// Calculate and write the MAC.
	hs.mac.Reset()
	hs.mac.Write(buf.Bytes())
	hs.mac.Write(hs.epochHour)
	buf.Write(hs.mac.Sum(nil))

	return buf.Bytes(), nil
}
// getEpochHour reports how many whole hours have elapsed since the UNIX
// epoch according to the local clock.
func getEpochHour() int64 {
	const secondsPerHour = 60 * 60

	now := time.Now()
	return now.Unix() / secondsPerHour
}
// findMark searches buf[startPos:maxPos] (clamped to the length of buf) for
// the first occurrence of mark.
//
// It returns the index of the match relative to the start of buf, or -1 if
// mark does not occur in that window. A window that is empty or inverted
// (startPos negative, or beyond the clamped end) yields -1 rather than a
// slice-bounds panic.
func findMark(mark, buf []byte, startPos, maxPos int) int {
	endPos := len(buf)
	if endPos > maxPos {
		endPos = maxPos
	}
	if startPos < 0 || startPos > endPos {
		return -1
	}

	pos := bytes.Index(buf[startPos:endPos], mark)
	if pos == -1 {
		return -1
	}

	// Return the index relative to the start of the slice.
	return pos + startPos
}
// makePad returns a slice of cryptographically random bytes whose length is
// chosen uniformly at random from the inclusive range [min, max].
//
// It panics if max is smaller than min, and returns an error if the system
// random source fails.
func makePad(min, max int64) ([]byte, error) {
	if min > max {
		panic(fmt.Sprintf("makePad: min > max (%d, %d)", min, max))
	}

	// rand.Int draws from [0, span), so span is one more than the width of
	// the range to make max itself reachable.
	span := big.NewInt(max - min + 1)
	extra, err := rand.Int(rand.Reader, span)
	if err != nil {
		return nil, err
	}

	padding := make([]byte, min+extra.Int64())
	if _, err = rand.Read(padding); err != nil {
		return nil, err
	}
	return padding, nil
}
/* vim :set ts=4 sw=4 sts=4 noet : */