mirror of
https://gitlab.com/yawning/obfs4.git
synced 2024-11-15 12:12:53 +00:00
Change the framing Encoder/Decoder to take the destination slice.
In theory this is easier on the garbage collector. Probably could reuse more of the intermediary buffers by stashing them in the connection state, but that makes the code kind of messy. This should be an improvement.
This commit is contained in:
parent
89d5338eed
commit
48c6f06d04
@ -61,6 +61,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash"
|
||||
"io"
|
||||
|
||||
"code.google.com/p/go.crypto/nacl/secretbox"
|
||||
|
||||
@ -172,40 +173,41 @@ func NewEncoder(key []byte) *Encoder {
|
||||
}
|
||||
|
||||
// Encode encodes a single frame worth of payload and returns the encoded
|
||||
// length and the resulting frame. InvalidPayloadLengthError is recoverable,
|
||||
// all other errors MUST be treated as fatal and the session aborted.
|
||||
func (encoder *Encoder) Encode(payload []byte) (int, []byte, error) {
|
||||
// length. InvalidPayloadLengthError is recoverable, all other errors MUST be
|
||||
// treated as fatal and the session aborted.
|
||||
func (encoder *Encoder) Encode(frame, payload []byte) (n int, err error) {
|
||||
payloadLen := len(payload)
|
||||
if MaximumFramePayloadLength < payloadLen {
|
||||
return 0, nil, InvalidPayloadLengthError(payloadLen)
|
||||
return 0, InvalidPayloadLengthError(payloadLen)
|
||||
}
|
||||
if len(frame) < payloadLen + FrameOverhead {
|
||||
return 0, io.ErrShortBuffer
|
||||
}
|
||||
|
||||
// Generate a new nonce.
|
||||
var nonce [nonceLength]byte
|
||||
err := encoder.nonce.bytes(&nonce)
|
||||
err = encoder.nonce.bytes(&nonce)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
return 0, err
|
||||
}
|
||||
encoder.nonce.counter++
|
||||
|
||||
// Encrypt and MAC payload.
|
||||
var box []byte
|
||||
box = secretbox.Seal(nil, payload, &nonce, &encoder.key)
|
||||
box := secretbox.Seal(frame[:lengthLength], payload, &nonce, &encoder.key)
|
||||
|
||||
// Obfuscate the length.
|
||||
length := uint16(len(box))
|
||||
length := uint16(len(box)-lengthLength)
|
||||
encoder.sip.Write(nonce[:])
|
||||
lengthMask := encoder.sip.Sum(nil)
|
||||
encoder.sip.Reset()
|
||||
length ^= binary.BigEndian.Uint16(lengthMask)
|
||||
var obfsLen [lengthLength]byte
|
||||
binary.BigEndian.PutUint16(obfsLen[:], length)
|
||||
binary.BigEndian.PutUint16(frame[:2], length)
|
||||
|
||||
// Prepare the next obfsucator.
|
||||
encoder.sip.Write(box)
|
||||
encoder.sip.Write(box[lengthLength:])
|
||||
|
||||
// Return the frame.
|
||||
return payloadLen + FrameOverhead, append(obfsLen[:], box...), nil
|
||||
return len(box), nil
|
||||
}
|
||||
|
||||
// Decoder is a frame decoder instance.
|
||||
@ -233,23 +235,23 @@ func NewDecoder(key []byte) *Decoder {
|
||||
return decoder
|
||||
}
|
||||
|
||||
// Decode decodes a stream of data and returns the length and decoded frame if
|
||||
// any. ErrAgain is a temporary failure, all other errors MUST be treated as
|
||||
// fatal and the session aborted.
|
||||
func (decoder *Decoder) Decode(data *bytes.Buffer) (int, []byte, error) {
|
||||
// Decode decodes a stream of data and returns the length if any. ErrAgain is
|
||||
// a temporary failure, all other errors MUST be treated as fatal and the
|
||||
// session aborted.
|
||||
func (decoder *Decoder) Decode(data []byte, frames *bytes.Buffer) (int, error) {
|
||||
// A length of 0 indicates that we do not know how big the next frame is
|
||||
// going to be.
|
||||
if decoder.nextLength == 0 {
|
||||
// Attempt to pull out the next frame length.
|
||||
if lengthLength > data.Len() {
|
||||
return 0, nil, ErrAgain
|
||||
if lengthLength > frames.Len() {
|
||||
return 0, ErrAgain
|
||||
}
|
||||
|
||||
// Remove the length field from the buffer.
|
||||
var obfsLen [lengthLength]byte
|
||||
n, err := data.Read(obfsLen[:])
|
||||
n, err := frames.Read(obfsLen[:])
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
return 0, err
|
||||
} else if n != lengthLength {
|
||||
// Should *NEVER* happen, since at least 2 bytes exist.
|
||||
panic(fmt.Sprintf("BUG: Failed to read obfuscated length: %d", n))
|
||||
@ -258,7 +260,7 @@ func (decoder *Decoder) Decode(data *bytes.Buffer) (int, []byte, error) {
|
||||
// Derive the nonce the peer used.
|
||||
err = decoder.nonce.bytes(&decoder.nextNonce)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// Deobfuscate the length field.
|
||||
@ -268,36 +270,36 @@ func (decoder *Decoder) Decode(data *bytes.Buffer) (int, []byte, error) {
|
||||
decoder.sip.Reset()
|
||||
length ^= binary.BigEndian.Uint16(lengthMask)
|
||||
if maxFrameLength < length || minFrameLength > length {
|
||||
return 0, nil, InvalidFrameLengthError(length)
|
||||
return 0, InvalidFrameLengthError(length)
|
||||
}
|
||||
decoder.nextLength = length
|
||||
}
|
||||
|
||||
if int(decoder.nextLength) > data.Len() {
|
||||
return 0, nil, ErrAgain
|
||||
if int(decoder.nextLength) > frames.Len() {
|
||||
return 0, ErrAgain
|
||||
}
|
||||
|
||||
// Unseal the frame.
|
||||
box := make([]byte, decoder.nextLength)
|
||||
n, err := data.Read(box)
|
||||
var box [maxFrameLength]byte
|
||||
n, err := frames.Read(box[:decoder.nextLength])
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
return 0, err
|
||||
} else if n != int(decoder.nextLength) {
|
||||
// Should *NEVER* happen, since at least 2 bytes exist.
|
||||
// Should *NEVER* happen, since the length is checked.
|
||||
panic(fmt.Sprintf("BUG: Failed to read secretbox, got %d, should have %d",
|
||||
n, decoder.nextLength))
|
||||
}
|
||||
out, ok := secretbox.Open(nil, box, &decoder.nextNonce, &decoder.key)
|
||||
out, ok := secretbox.Open(data[:0], box[:n], &decoder.nextNonce, &decoder.key)
|
||||
if !ok {
|
||||
return 0, nil, ErrTagMismatch
|
||||
return 0, ErrTagMismatch
|
||||
}
|
||||
decoder.sip.Write(box)
|
||||
decoder.sip.Write(box[:n])
|
||||
|
||||
// Clean up and prepare for the next frame.
|
||||
decoder.nextLength = 0
|
||||
decoder.nonce.counter++
|
||||
|
||||
return len(out), out, nil
|
||||
return len(out), nil
|
||||
}
|
||||
|
||||
/* vim :set ts=4 sw=4 sts=4 noet : */
|
||||
|
@ -69,7 +69,8 @@ func TestEncoder_Encode(t *testing.T) {
|
||||
buf := make([]byte, MaximumFramePayloadLength)
|
||||
_, _ = rand.Read(buf) // YOLO
|
||||
for i := 0; i <= MaximumFramePayloadLength; i++ {
|
||||
n, frame, err := encoder.Encode(buf[0:i])
|
||||
var frame [MaximumSegmentLength]byte
|
||||
n, err := encoder.Encode(frame[:], buf[0:i])
|
||||
if err != nil {
|
||||
t.Fatalf("Encoder.encode([%d]byte), failed: %s", i, err)
|
||||
}
|
||||
@ -77,10 +78,6 @@ func TestEncoder_Encode(t *testing.T) {
|
||||
t.Fatalf("Unexpected encoded framesize: %d, expecting %d", n, i+
|
||||
FrameOverhead)
|
||||
}
|
||||
if len(frame) != n {
|
||||
t.Fatalf("Encoded frame length/rval mismatch: %d != %d",
|
||||
len(frame), n)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -88,9 +85,10 @@ func TestEncoder_Encode(t *testing.T) {
|
||||
func TestEncoder_Encode_Oversize(t *testing.T) {
|
||||
encoder := newEncoder(t)
|
||||
|
||||
buf := make([]byte, MaximumFramePayloadLength+1)
|
||||
_, _ = rand.Read(buf) // YOLO
|
||||
_, _, err := encoder.Encode(buf)
|
||||
var frame [MaximumSegmentLength]byte
|
||||
var buf [MaximumFramePayloadLength+1]byte
|
||||
_, _ = rand.Read(buf[:]) // YOLO
|
||||
_, err := encoder.Encode(frame[:], buf[:])
|
||||
if _, ok := err.(InvalidPayloadLengthError); !ok {
|
||||
t.Error("Encoder.encode() returned unexpected error:", err)
|
||||
}
|
||||
@ -112,10 +110,11 @@ func TestDecoder_Decode(t *testing.T) {
|
||||
encoder := NewEncoder(key)
|
||||
decoder := NewDecoder(key)
|
||||
|
||||
buf := make([]byte, MaximumFramePayloadLength)
|
||||
_, _ = rand.Read(buf) // YOLO
|
||||
var buf [MaximumFramePayloadLength]byte
|
||||
_, _ = rand.Read(buf[:]) // YOLO
|
||||
for i := 0; i <= MaximumFramePayloadLength; i++ {
|
||||
encLen, frame, err := encoder.Encode(buf[0:i])
|
||||
var frame [MaximumSegmentLength]byte
|
||||
encLen, err := encoder.Encode(frame[:], buf[0:i])
|
||||
if err != nil {
|
||||
t.Fatalf("Encoder.encode([%d]byte), failed: %s", i, err)
|
||||
}
|
||||
@ -123,12 +122,10 @@ func TestDecoder_Decode(t *testing.T) {
|
||||
t.Fatalf("Unexpected encoded framesize: %d, expecting %d", encLen,
|
||||
i+FrameOverhead)
|
||||
}
|
||||
if len(frame) != encLen {
|
||||
t.Fatalf("Encoded frame length/rval mismatch: %d != %d",
|
||||
len(frame), encLen)
|
||||
}
|
||||
|
||||
decLen, decoded, err := decoder.Decode(bytes.NewBuffer(frame))
|
||||
var decoded [MaximumFramePayloadLength]byte
|
||||
|
||||
decLen, err := decoder.Decode(decoded[:], bytes.NewBuffer(frame[:encLen]))
|
||||
if err != nil {
|
||||
t.Fatalf("Decoder.decode([%d]byte), failed: %s", i, err)
|
||||
}
|
||||
@ -136,13 +133,8 @@ func TestDecoder_Decode(t *testing.T) {
|
||||
t.Fatalf("Unexpected decoded framesize: %d, expecting %d",
|
||||
decLen, i)
|
||||
}
|
||||
if len(decoded) != i {
|
||||
t.Fatalf("Encoded frame length/rval mismatch: %d != %d",
|
||||
len(decoded), i)
|
||||
|
||||
}
|
||||
|
||||
if 0 != bytes.Compare(decoded, buf[0:i]) {
|
||||
if 0 != bytes.Compare(decoded[:decLen], buf[:i]) {
|
||||
t.Fatalf("Frame %d does not match encoder input", i)
|
||||
}
|
||||
}
|
||||
@ -152,6 +144,7 @@ func TestDecoder_Decode(t *testing.T) {
|
||||
// of payload.
|
||||
func BenchmarkEncoder_Encode(b *testing.B) {
|
||||
var chopBuf [MaximumFramePayloadLength]byte
|
||||
var frame [MaximumSegmentLength]byte
|
||||
payload := make([]byte, 1024*1024)
|
||||
encoder := NewEncoder(generateRandomKey())
|
||||
b.ResetTimer()
|
||||
@ -165,8 +158,8 @@ func BenchmarkEncoder_Encode(b *testing.B) {
|
||||
b.Fatal("buffer.Read() failed:", err)
|
||||
}
|
||||
|
||||
n, frame, err := encoder.Encode(chopBuf[:n])
|
||||
transfered += len(frame) - FrameOverhead
|
||||
n, err = encoder.Encode(frame[:], chopBuf[:n])
|
||||
transfered += n - FrameOverhead
|
||||
}
|
||||
if transfered != len(payload) {
|
||||
b.Fatalf("Transfered length mismatch: %d != %d", transfered,
|
||||
|
21
packet.go
21
packet.go
@ -100,18 +100,18 @@ func (c *Obfs4Conn) producePacket(w io.Writer, pktType uint8, data []byte, padLe
|
||||
pktLen := packetOverhead + len(data) + int(padLen)
|
||||
|
||||
// Encode the packet in an AEAD frame.
|
||||
// TODO: Change Encode to write into frame directly
|
||||
var frame []byte
|
||||
_, frame, err = c.encoder.Encode(pkt[:pktLen])
|
||||
var frame [framing.MaximumSegmentLength]byte
|
||||
frameLen := 0
|
||||
frameLen, err = c.encoder.Encode(frame[:], pkt[:pktLen])
|
||||
if err != nil {
|
||||
// All encoder errors are fatal.
|
||||
return
|
||||
}
|
||||
var wrLen int
|
||||
wrLen, err = w.Write(frame)
|
||||
wrLen, err = w.Write(frame[:frameLen])
|
||||
if err != nil {
|
||||
return
|
||||
} else if wrLen < len(frame) {
|
||||
} else if wrLen < frameLen {
|
||||
err = io.ErrShortWrite
|
||||
return
|
||||
}
|
||||
@ -132,22 +132,23 @@ func (c *Obfs4Conn) consumeFramedPackets(w io.Writer) (n int, err error) {
|
||||
}
|
||||
c.receiveBuffer.Write(buf[:rdLen])
|
||||
|
||||
var decoded [framing.MaximumFramePayloadLength]byte
|
||||
for c.receiveBuffer.Len() > 0 {
|
||||
// Decrypt an AEAD frame.
|
||||
// TODO: Change Decode to write into packet directly
|
||||
var pkt []byte
|
||||
_, pkt, err = c.decoder.Decode(&c.receiveBuffer)
|
||||
decLen := 0
|
||||
decLen, err = c.decoder.Decode(decoded[:], &c.receiveBuffer)
|
||||
if err == framing.ErrAgain {
|
||||
// The accumulated payload does not make up a full frame.
|
||||
return
|
||||
} else if err != nil {
|
||||
break
|
||||
} else if len(pkt) < packetOverhead {
|
||||
err = InvalidPacketLengthError(len(pkt))
|
||||
} else if decLen < packetOverhead {
|
||||
err = InvalidPacketLengthError(decLen)
|
||||
break
|
||||
}
|
||||
|
||||
// Decode the packet.
|
||||
pkt := decoded[0:decLen]
|
||||
pktType := pkt[0]
|
||||
payloadLen := binary.BigEndian.Uint16(pkt[1:])
|
||||
if int(payloadLen) > len(pkt)-packetOverhead {
|
||||
|
Loading…
Reference in New Issue
Block a user