Change the framing Encoder/Decoder to take the destination slice.

In theory this is easier on the garbage collector.  Probably could
reuse more of the intermediary buffers by stashing them in the
connection state, but that makes the code kind of messy.  This should
be an improvement.
This commit is contained in:
Yawning Angel 2014-05-14 09:58:53 +00:00
parent 89d5338eed
commit 48c6f06d04
3 changed files with 63 additions and 67 deletions

View File

@ -61,6 +61,7 @@ import (
"errors"
"fmt"
"hash"
"io"
"code.google.com/p/go.crypto/nacl/secretbox"
@ -172,40 +173,41 @@ func NewEncoder(key []byte) *Encoder {
}
// Encode encodes a single frame worth of payload and returns the encoded
// length and the resulting frame. InvalidPayloadLengthError is recoverable,
// all other errors MUST be treated as fatal and the session aborted.
func (encoder *Encoder) Encode(payload []byte) (int, []byte, error) {
// length. InvalidPayloadLengthError is recoverable, all other errors MUST be
// treated as fatal and the session aborted.
func (encoder *Encoder) Encode(frame, payload []byte) (n int, err error) {
payloadLen := len(payload)
if MaximumFramePayloadLength < payloadLen {
return 0, nil, InvalidPayloadLengthError(payloadLen)
return 0, InvalidPayloadLengthError(payloadLen)
}
if len(frame) < payloadLen + FrameOverhead {
return 0, io.ErrShortBuffer
}
// Generate a new nonce.
var nonce [nonceLength]byte
err := encoder.nonce.bytes(&nonce)
err = encoder.nonce.bytes(&nonce)
if err != nil {
return 0, nil, err
return 0, err
}
encoder.nonce.counter++
// Encrypt and MAC payload.
var box []byte
box = secretbox.Seal(nil, payload, &nonce, &encoder.key)
box := secretbox.Seal(frame[:lengthLength], payload, &nonce, &encoder.key)
// Obfuscate the length.
length := uint16(len(box))
length := uint16(len(box)-lengthLength)
encoder.sip.Write(nonce[:])
lengthMask := encoder.sip.Sum(nil)
encoder.sip.Reset()
length ^= binary.BigEndian.Uint16(lengthMask)
var obfsLen [lengthLength]byte
binary.BigEndian.PutUint16(obfsLen[:], length)
binary.BigEndian.PutUint16(frame[:2], length)
// Prepare the next obfsucator.
encoder.sip.Write(box)
encoder.sip.Write(box[lengthLength:])
// Return the frame.
return payloadLen + FrameOverhead, append(obfsLen[:], box...), nil
return len(box), nil
}
// Decoder is a frame decoder instance.
@ -233,23 +235,23 @@ func NewDecoder(key []byte) *Decoder {
return decoder
}
// Decode decodes a stream of data and returns the length and decoded frame if
// any. ErrAgain is a temporary failure, all other errors MUST be treated as
// fatal and the session aborted.
func (decoder *Decoder) Decode(data *bytes.Buffer) (int, []byte, error) {
// Decode decodes a stream of data and returns the length if any. ErrAgain is
// a temporary failure, all other errors MUST be treated as fatal and the
// session aborted.
func (decoder *Decoder) Decode(data []byte, frames *bytes.Buffer) (int, error) {
// A length of 0 indicates that we do not know how big the next frame is
// going to be.
if decoder.nextLength == 0 {
// Attempt to pull out the next frame length.
if lengthLength > data.Len() {
return 0, nil, ErrAgain
if lengthLength > frames.Len() {
return 0, ErrAgain
}
// Remove the length field from the buffer.
var obfsLen [lengthLength]byte
n, err := data.Read(obfsLen[:])
n, err := frames.Read(obfsLen[:])
if err != nil {
return 0, nil, err
return 0, err
} else if n != lengthLength {
// Should *NEVER* happen, since at least 2 bytes exist.
panic(fmt.Sprintf("BUG: Failed to read obfuscated length: %d", n))
@ -258,7 +260,7 @@ func (decoder *Decoder) Decode(data *bytes.Buffer) (int, []byte, error) {
// Derive the nonce the peer used.
err = decoder.nonce.bytes(&decoder.nextNonce)
if err != nil {
return 0, nil, err
return 0, err
}
// Deobfuscate the length field.
@ -268,36 +270,36 @@ func (decoder *Decoder) Decode(data *bytes.Buffer) (int, []byte, error) {
decoder.sip.Reset()
length ^= binary.BigEndian.Uint16(lengthMask)
if maxFrameLength < length || minFrameLength > length {
return 0, nil, InvalidFrameLengthError(length)
return 0, InvalidFrameLengthError(length)
}
decoder.nextLength = length
}
if int(decoder.nextLength) > data.Len() {
return 0, nil, ErrAgain
if int(decoder.nextLength) > frames.Len() {
return 0, ErrAgain
}
// Unseal the frame.
box := make([]byte, decoder.nextLength)
n, err := data.Read(box)
var box [maxFrameLength]byte
n, err := frames.Read(box[:decoder.nextLength])
if err != nil {
return 0, nil, err
return 0, err
} else if n != int(decoder.nextLength) {
// Should *NEVER* happen, since at least 2 bytes exist.
// Should *NEVER* happen, since the length is checked.
panic(fmt.Sprintf("BUG: Failed to read secretbox, got %d, should have %d",
n, decoder.nextLength))
}
out, ok := secretbox.Open(nil, box, &decoder.nextNonce, &decoder.key)
out, ok := secretbox.Open(data[:0], box[:n], &decoder.nextNonce, &decoder.key)
if !ok {
return 0, nil, ErrTagMismatch
return 0, ErrTagMismatch
}
decoder.sip.Write(box)
decoder.sip.Write(box[:n])
// Clean up and prepare for the next frame.
decoder.nextLength = 0
decoder.nonce.counter++
return len(out), out, nil
return len(out), nil
}
/* vim :set ts=4 sw=4 sts=4 noet : */

View File

@ -69,7 +69,8 @@ func TestEncoder_Encode(t *testing.T) {
buf := make([]byte, MaximumFramePayloadLength)
_, _ = rand.Read(buf) // YOLO
for i := 0; i <= MaximumFramePayloadLength; i++ {
n, frame, err := encoder.Encode(buf[0:i])
var frame [MaximumSegmentLength]byte
n, err := encoder.Encode(frame[:], buf[0:i])
if err != nil {
t.Fatalf("Encoder.encode([%d]byte), failed: %s", i, err)
}
@ -77,10 +78,6 @@ func TestEncoder_Encode(t *testing.T) {
t.Fatalf("Unexpected encoded framesize: %d, expecting %d", n, i+
FrameOverhead)
}
if len(frame) != n {
t.Fatalf("Encoded frame length/rval mismatch: %d != %d",
len(frame), n)
}
}
}
@ -88,9 +85,10 @@ func TestEncoder_Encode(t *testing.T) {
func TestEncoder_Encode_Oversize(t *testing.T) {
encoder := newEncoder(t)
buf := make([]byte, MaximumFramePayloadLength+1)
_, _ = rand.Read(buf) // YOLO
_, _, err := encoder.Encode(buf)
var frame [MaximumSegmentLength]byte
var buf [MaximumFramePayloadLength+1]byte
_, _ = rand.Read(buf[:]) // YOLO
_, err := encoder.Encode(frame[:], buf[:])
if _, ok := err.(InvalidPayloadLengthError); !ok {
t.Error("Encoder.encode() returned unexpected error:", err)
}
@ -112,10 +110,11 @@ func TestDecoder_Decode(t *testing.T) {
encoder := NewEncoder(key)
decoder := NewDecoder(key)
buf := make([]byte, MaximumFramePayloadLength)
_, _ = rand.Read(buf) // YOLO
var buf [MaximumFramePayloadLength]byte
_, _ = rand.Read(buf[:]) // YOLO
for i := 0; i <= MaximumFramePayloadLength; i++ {
encLen, frame, err := encoder.Encode(buf[0:i])
var frame [MaximumSegmentLength]byte
encLen, err := encoder.Encode(frame[:], buf[0:i])
if err != nil {
t.Fatalf("Encoder.encode([%d]byte), failed: %s", i, err)
}
@ -123,12 +122,10 @@ func TestDecoder_Decode(t *testing.T) {
t.Fatalf("Unexpected encoded framesize: %d, expecting %d", encLen,
i+FrameOverhead)
}
if len(frame) != encLen {
t.Fatalf("Encoded frame length/rval mismatch: %d != %d",
len(frame), encLen)
}
decLen, decoded, err := decoder.Decode(bytes.NewBuffer(frame))
var decoded [MaximumFramePayloadLength]byte
decLen, err := decoder.Decode(decoded[:], bytes.NewBuffer(frame[:encLen]))
if err != nil {
t.Fatalf("Decoder.decode([%d]byte), failed: %s", i, err)
}
@ -136,13 +133,8 @@ func TestDecoder_Decode(t *testing.T) {
t.Fatalf("Unexpected decoded framesize: %d, expecting %d",
decLen, i)
}
if len(decoded) != i {
t.Fatalf("Encoded frame length/rval mismatch: %d != %d",
len(decoded), i)
}
if 0 != bytes.Compare(decoded, buf[0:i]) {
if 0 != bytes.Compare(decoded[:decLen], buf[:i]) {
t.Fatalf("Frame %d does not match encoder input", i)
}
}
@ -152,6 +144,7 @@ func TestDecoder_Decode(t *testing.T) {
// of payload.
func BenchmarkEncoder_Encode(b *testing.B) {
var chopBuf [MaximumFramePayloadLength]byte
var frame [MaximumSegmentLength]byte
payload := make([]byte, 1024*1024)
encoder := NewEncoder(generateRandomKey())
b.ResetTimer()
@ -165,8 +158,8 @@ func BenchmarkEncoder_Encode(b *testing.B) {
b.Fatal("buffer.Read() failed:", err)
}
n, frame, err := encoder.Encode(chopBuf[:n])
transfered += len(frame) - FrameOverhead
n, err = encoder.Encode(frame[:], chopBuf[:n])
transfered += n - FrameOverhead
}
if transfered != len(payload) {
b.Fatalf("Transfered length mismatch: %d != %d", transfered,

View File

@ -100,18 +100,18 @@ func (c *Obfs4Conn) producePacket(w io.Writer, pktType uint8, data []byte, padLe
pktLen := packetOverhead + len(data) + int(padLen)
// Encode the packet in an AEAD frame.
// TODO: Change Encode to write into frame directly
var frame []byte
_, frame, err = c.encoder.Encode(pkt[:pktLen])
var frame [framing.MaximumSegmentLength]byte
frameLen := 0
frameLen, err = c.encoder.Encode(frame[:], pkt[:pktLen])
if err != nil {
// All encoder errors are fatal.
return
}
var wrLen int
wrLen, err = w.Write(frame)
wrLen, err = w.Write(frame[:frameLen])
if err != nil {
return
} else if wrLen < len(frame) {
} else if wrLen < frameLen {
err = io.ErrShortWrite
return
}
@ -132,22 +132,23 @@ func (c *Obfs4Conn) consumeFramedPackets(w io.Writer) (n int, err error) {
}
c.receiveBuffer.Write(buf[:rdLen])
var decoded [framing.MaximumFramePayloadLength]byte
for c.receiveBuffer.Len() > 0 {
// Decrypt an AEAD frame.
// TODO: Change Decode to write into packet directly
var pkt []byte
_, pkt, err = c.decoder.Decode(&c.receiveBuffer)
decLen := 0
decLen, err = c.decoder.Decode(decoded[:], &c.receiveBuffer)
if err == framing.ErrAgain {
// The accumulated payload does not make up a full frame.
return
} else if err != nil {
break
} else if len(pkt) < packetOverhead {
err = InvalidPacketLengthError(len(pkt))
} else if decLen < packetOverhead {
err = InvalidPacketLengthError(decLen)
break
}
// Decode the packet.
pkt := decoded[0:decLen]
pktType := pkt[0]
payloadLen := binary.BigEndian.Uint16(pkt[1:])
if int(payloadLen) > len(pkt)-packetOverhead {