Fix up how Read() errors were processed where appropriate.

This commit is contained in:
Yawning Angel 2014-05-15 00:52:53 +00:00
parent 79a7ad7f2b
commit 013c3c7c4d
2 changed files with 13 additions and 12 deletions

View File

@ -42,7 +42,6 @@ import (
const (
headerLength = framing.FrameOverhead + packetOverhead
defaultReadSize = framing.MaximumSegmentLength
connectionTimeout = time.Duration(15) * time.Second
minCloseThreshold = 0
@ -163,6 +162,8 @@ func (c *Obfs4Conn) clientHandshake(nodeID *ntor.NodeID, publicKey *ntor.PublicK
var n int
n, err = c.conn.Read(hsBuf[:])
if err != nil {
// Yes, just bail out of handshaking even if the Read could have
// returned data, no point in continuing on EOF/etc.
return
}
c.receiveBuffer.Write(hsBuf[:n])
@ -215,6 +216,8 @@ func (c *Obfs4Conn) serverHandshake(nodeID *ntor.NodeID, keypair *ntor.Keypair)
var n int
n, err = c.conn.Read(hsBuf[:])
if err != nil {
// Yes, just bail out of handshaking even if the Read could have
// returned data, no point in continuing on EOF/etc.
return
}
c.receiveBuffer.Write(hsBuf[:n])
@ -354,7 +357,7 @@ func (c *Obfs4Conn) Write(b []byte) (n int, err error) {
}
}()
// XXX: Change this to write directly to c.conn skipping frameBuf.
// TODO: Change this to write directly to c.conn skipping frameBuf.
chopBuf := bytes.NewBuffer(b)
var payload [maxPacketPayloadLength]byte
var frameBuf bytes.Buffer

View File

@ -125,21 +125,15 @@ func (c *Obfs4Conn) consumeFramedPackets(w io.Writer) (n int, err error) {
}
var buf [consumeReadSize]byte
var rdLen int
rdLen, err = c.conn.Read(buf[:])
if err != nil {
return
}
rdLen, rdErr := c.conn.Read(buf[:])
c.receiveBuffer.Write(buf[:rdLen])
var decoded [framing.MaximumFramePayloadLength]byte
for c.receiveBuffer.Len() > 0 {
// Decrypt an AEAD frame.
decLen := 0
decLen, err = c.decoder.Decode(decoded[:], &c.receiveBuffer)
if err == framing.ErrAgain {
// The accumulated payload does not make up a full frame.
return
break
} else if err != nil {
break
} else if decLen < packetOverhead {
@ -187,8 +181,12 @@ func (c *Obfs4Conn) consumeFramedPackets(w io.Writer) (n int, err error) {
}
}
// All errors that reach this point are fatal.
if err != nil {
// Read errors and non-framing.ErrAgain errors are all fatal.
if (err != nil && err != framing.ErrAgain) || rdErr != nil {
// Propagate read errors correctly.
if err == nil && rdErr != nil {
err = rdErr
}
c.setBroken()
}