diff --git a/obfs4.go b/obfs4.go index 2823b75..afe8967 100644 --- a/obfs4.go +++ b/obfs4.go @@ -101,9 +101,9 @@ func (c *Obfs4Conn) closeAfterDelay() { // Consume and discard data on this connection until either the specified // interval passes or a certain size has been reached. discarded := 0 - buf := make([]byte, defaultReadSize) + var buf [framing.MaximumSegmentLength]byte for discarded < int(toDiscard) { - n, err := c.conn.Read(buf) + n, err := c.conn.Read(buf[:]) if err != nil { return } @@ -281,12 +281,12 @@ func (c *Obfs4Conn) WriteTo(w io.Writer) (n int64, err error) { // If there is buffered payload from earlier Read() calls, write. if c.receiveDecodedBuffer.Len() > 0 { wrLen, err = w.Write(c.receiveDecodedBuffer.Bytes()) - if wrLen < int(c.receiveDecodedBuffer.Len()) { - c.isOk = false - return int64(wrLen), io.ErrShortWrite - } else if err != nil { + if err != nil { c.isOk = false return int64(wrLen), err + } else if wrLen < int(c.receiveDecodedBuffer.Len()) { + c.isOk = false + return int64(wrLen), io.ErrShortWrite } c.receiveDecodedBuffer.Reset() } @@ -308,66 +308,58 @@ func (c *Obfs4Conn) WriteTo(w io.Writer) (n int64, err error) { return } -func (c *Obfs4Conn) Write(b []byte) (int, error) { +func (c *Obfs4Conn) Write(b []byte) (n int, err error) { chopBuf := bytes.NewBuffer(b) - buf := make([]byte, maxPacketPayloadLength) - nSent := 0 + var payload [maxPacketPayloadLength]byte var frameBuf bytes.Buffer for chopBuf.Len() > 0 { // Send maximum sized frames. - n, err := chopBuf.Read(buf) + rdLen := 0 + rdLen, err = chopBuf.Read(payload[:]) if err != nil { c.isOk = false return 0, err - } else if n == 0 { + } else if rdLen == 0 { panic(fmt.Sprintf("BUG: Write(), chopping length was 0")) } - nSent += n + n += rdLen - _, frame, err := c.makeAndEncryptPacket(packetTypePayload, buf[:n], 0) + err = c.producePacket(&frameBuf, packetTypePayload, payload[:rdLen], 0) if err != nil { c.isOk = false return 0, err } - - frameBuf.Write(frame) } // Insert random padding. In theory it's possible to inline padding for // certain framesizes into the last AEAD packet, but always sending 1 or 2 // padding frames is considerably easier. padLen := c.calcPadLen(frameBuf.Len()) - if padLen > 0 { - if padLen > headerLength { - _, frame, err := c.makeAndEncryptPacket(packetTypePayload, []byte{}, - uint16(padLen-headerLength)) - if err != nil { - c.isOk = false - return 0, err - } - frameBuf.Write(frame) - } else { - _, frame, err := c.makeAndEncryptPacket(packetTypePayload, []byte{}, - maxPacketPayloadLength) - if err != nil { - c.isOk = false - return 0, err - } - frameBuf.Write(frame) - - _, frame, err = c.makeAndEncryptPacket(packetTypePayload, []byte{}, - uint16(padLen)) - if err != nil { - c.isOk = false - return 0, err - } - frameBuf.Write(frame) + if padLen > headerLength { + err = c.producePacket(&frameBuf, packetTypePayload, []byte{}, + uint16(padLen-headerLength)) + if err != nil { + c.isOk = false + return 0, err + } + } else if padLen > 0 { + err = c.producePacket(&frameBuf, packetTypePayload, []byte{}, + maxPacketPayloadLength) + if err != nil { + c.isOk = false + return 0, err + } + err = c.producePacket(&frameBuf, packetTypePayload, []byte{}, + uint16(padLen)) + if err != nil { + c.isOk = false + return 0, err } } // Send the frame(s). - _, err := c.conn.Write(frameBuf.Bytes()) + _, err = c.conn.Write(frameBuf.Bytes()) if err != nil { // Partial writes are fatal because the frame encoder state is advanced // at this point. It's possible to keep frameBuf around, but fuck it. @@ -376,7 +368,7 @@ func (c *Obfs4Conn) Write(b []byte) (int, error) { return 0, err } - return nSent, nil + return } func (c *Obfs4Conn) Close() error { diff --git a/packet.go b/packet.go index 8fb53d0..7b69517 100644 --- a/packet.go +++ b/packet.go @@ -67,8 +67,8 @@ func (e InvalidPayloadLengthError) Error() string { var zeroPadBytes [maxPacketPaddingLength]byte -func makePacket(pkt []byte, pktType uint8, data []byte, padLen uint16) int { - pktLen := packetOverhead + len(data) + int(padLen) +func (c *Obfs4Conn) producePacket(w io.Writer, pktType uint8, data []byte, padLen uint16) error { + var pkt [framing.MaximumFramePayloadLength]byte if len(data)+int(padLen) > maxPacketPayloadLength { panic(fmt.Sprintf("BUG: makePacket() len(data) + padLen > maxPacketPayloadLength: %d + %d > %d", @@ -80,7 +80,6 @@ func makePacket(pkt []byte, pktType uint8, data []byte, padLen uint16) int { // uint16_t length Length of the payload (Big Endian). // uint8_t[] payload Data payload. // uint8_t[] padding Padding. - pkt[0] = pktType binary.BigEndian.PutUint16(pkt[1:], uint16(len(data))) if len(data) > 0 { @@ -88,18 +87,26 @@ func makePacket(pkt []byte, pktType uint8, data []byte, padLen uint16) int { } copy(pkt[3+len(data):], zeroPadBytes[:padLen]) - return pktLen -} - -func (c *Obfs4Conn) makeAndEncryptPacket(pktType uint8, data []byte, padLen uint16) (int, []byte, error) { - var pkt [framing.MaximumFramePayloadLength]byte - - // Wrap the payload in a packet. - n := makePacket(pkt[:], pktType, data[:], padLen) + pktLen := packetOverhead + len(data) + int(padLen) // Encode the packet in an AEAD frame. - n, frame, err := c.encoder.Encode(pkt[:n]) - return n, frame, err + // TODO: Change Encode to write into frame directly + _, frame, err := c.encoder.Encode(pkt[:pktLen]) + if err != nil { + // All encoder errors are fatal. + c.isOk = false + return err + } + wrLen, err := w.Write(frame) + if err != nil { + c.isOk = false + return err + } else if wrLen < len(frame) { + c.isOk = false + return io.ErrShortWrite + } + + return nil } func (c *Obfs4Conn) consumeFramedPackets(w io.Writer) (n int, err error) { @@ -116,7 +123,7 @@ func (c *Obfs4Conn) consumeFramedPackets(w io.Writer) (n int, err error) { for c.receiveBuffer.Len() > 0 { // Decrypt an AEAD frame. - // TODO: Change decode to write into packet directly + // TODO: Change Decode to write into packet directly var pkt []byte _, pkt, err = c.decoder.Decode(&c.receiveBuffer) if err == framing.ErrAgain { @@ -145,10 +152,10 @@ func (c *Obfs4Conn) consumeFramedPackets(w io.Writer) (n int, err error) { // c.WriteTo() skips buffering in c.receiveDecodedBuffer wrLen, err := w.Write(payload) n += wrLen - if wrLen < int(payloadLen) { - err = io.ErrShortWrite + if err != nil { break - } else if err != nil { + } else if wrLen < int(payloadLen) { + err = io.ErrShortWrite break } } else {