diff --git a/obfs4.go b/obfs4.go index eadcbef..429ba95 100644 --- a/obfs4.go +++ b/obfs4.go @@ -42,7 +42,6 @@ import ( const ( headerLength = framing.FrameOverhead + packetOverhead - defaultReadSize = framing.MaximumSegmentLength connectionTimeout = time.Duration(15) * time.Second minCloseThreshold = 0 @@ -163,6 +162,8 @@ func (c *Obfs4Conn) clientHandshake(nodeID *ntor.NodeID, publicKey *ntor.PublicK var n int n, err = c.conn.Read(hsBuf[:]) if err != nil { + // Yes, just bail out of handshaking even if the Read could have + // returned data, no point in continuing on EOF/etc. return } c.receiveBuffer.Write(hsBuf[:n]) @@ -215,6 +216,8 @@ func (c *Obfs4Conn) serverHandshake(nodeID *ntor.NodeID, keypair *ntor.Keypair) var n int n, err = c.conn.Read(hsBuf[:]) if err != nil { + // Yes, just bail out of handshaking even if the Read could have + // returned data, no point in continuing on EOF/etc. return } c.receiveBuffer.Write(hsBuf[:n]) @@ -354,7 +357,7 @@ func (c *Obfs4Conn) Write(b []byte) (n int, err error) { } }() - // XXX: Change this to write directly to c.conn skipping frameBuf. + // TODO: Change this to write directly to c.conn skipping frameBuf. chopBuf := bytes.NewBuffer(b) var payload [maxPacketPayloadLength]byte var frameBuf bytes.Buffer diff --git a/packet.go b/packet.go index 6f9eb03..75179cb 100644 --- a/packet.go +++ b/packet.go @@ -125,21 +125,15 @@ func (c *Obfs4Conn) consumeFramedPackets(w io.Writer) (n int, err error) { } var buf [consumeReadSize]byte - var rdLen int - rdLen, err = c.conn.Read(buf[:]) - if err != nil { - return - } + rdLen, rdErr := c.conn.Read(buf[:]) c.receiveBuffer.Write(buf[:rdLen]) - var decoded [framing.MaximumFramePayloadLength]byte for c.receiveBuffer.Len() > 0 { // Decrypt an AEAD frame. decLen := 0 decLen, err = c.decoder.Decode(decoded[:], &c.receiveBuffer) if err == framing.ErrAgain { - // The accumulated payload does not make up a full frame. - return + break } else if err != nil { break } else if decLen < packetOverhead { @@ -187,8 +181,12 @@ func (c *Obfs4Conn) consumeFramedPackets(w io.Writer) (n int, err error) { } } - // All errors that reach this point are fatal. - if err != nil { + // Read errors and non-framing.ErrAgain errors are all fatal. + if (err != nil && err != framing.ErrAgain) || rdErr != nil { + // Propagate read errors correctly. + if err == nil && rdErr != nil { + err = rdErr + } c.setBroken() }