Framing in Stream.Write to prevent silent short write

pull/110/head
Andy Wang 5 years ago
parent 17d57d9369
commit e9243a2e9f

@ -77,17 +77,17 @@ func Copy(dst net.Conn, src net.Conn, srcReadTimeout time.Duration) (written int
}
nr, er := src.Read(buf)
if nr > 0 {
var offset int
for offset < nr {
nw, ew := dst.Write(buf[offset:nr])
if nw > 0 {
written += int64(nw)
}
if ew != nil {
err = ew
break
}
offset += nw
nw, ew := dst.Write(buf[0:nr])
if nw > 0 {
written += int64(nw)
}
if ew != nil {
err = ew
break
}
if nr != nw {
err = io.ErrShortWrite
break
}
}
if er != nil {

@ -59,6 +59,8 @@ type Session struct {
closed uint32
terminalMsg atomic.Value
maxStreamUnitWrite int // the max size passed to Write calls before it splits it into multiple frames
}
func MakeSession(id uint32, config SessionConfig) *Session {
@ -82,6 +84,7 @@ func MakeSession(id uint32, config SessionConfig) *Session {
if config.MaxFrameSize <= 0 {
sesh.MaxFrameSize = defaultSendRecvBufSize - 1024
}
sesh.maxStreamUnitWrite = sesh.MaxFrameSize - HEADER_LEN - sesh.Obfuscator.minOverhead
sbConfig := switchboardConfig{
valve: sesh.Valve,

@ -96,37 +96,39 @@ func (s *Stream) Write(in []byte) (n int, err error) {
return 0, ErrBrokenStream
}
var payload []byte
maxDataLen := s.session.MaxFrameSize - HEADER_LEN - s.session.minOverhead
if len(in) <= maxDataLen {
payload = in
} else {
//TODO: short write isn't the correct behaviour
payload = in[:maxDataLen]
}
for n < len(in) {
var framePayload []byte
if len(in)-n <= s.session.maxStreamUnitWrite {
framePayload = in[n:]
} else {
framePayload = in[n : s.session.maxStreamUnitWrite+n]
}
f := &Frame{
StreamID: s.id,
Seq: atomic.AddUint64(&s.nextSendSeq, 1) - 1,
Closing: C_NOOP,
Payload: payload,
}
f := &Frame{
StreamID: s.id,
Seq: atomic.AddUint64(&s.nextSendSeq, 1) - 1,
Closing: C_NOOP,
Payload: framePayload,
}
i, err := s.session.Obfs(f, s.obfsBuf)
if err != nil {
return i, err
}
n, err = s.session.sb.send(s.obfsBuf[:i], &s.assignedConnId)
log.Tracef("%v sent to remote through stream %v with err %v", len(payload), s.id, err)
if err != nil {
if err == errBrokenSwitchboard {
s.session.SetTerminalMsg(err.Error())
s.session.passiveClose()
var cipherTextLen int
cipherTextLen, err = s.session.Obfs(f, s.obfsBuf)
if err != nil {
return 0, err
}
return
}
return len(payload), nil
_, err = s.session.sb.send(s.obfsBuf[:cipherTextLen], &s.assignedConnId)
log.Tracef("%v sent to remote through stream %v with err %v", len(framePayload), s.id, err)
if err != nil {
if err == errBrokenSwitchboard {
s.session.SetTerminalMsg(err.Error())
s.session.passiveClose()
}
return
}
n += len(framePayload)
}
return
}
func (s *Stream) passiveClose() error {

Loading…
Cancel
Save