mirror of
https://github.com/cbeuw/Cloak.git
synced 2024-11-03 23:15:18 +00:00
Framing in Stream.Write to prevent silent short write
This commit is contained in:
parent
17d57d9369
commit
e9243a2e9f
@ -77,17 +77,17 @@ func Copy(dst net.Conn, src net.Conn, srcReadTimeout time.Duration) (written int
|
||||
}
|
||||
nr, er := src.Read(buf)
|
||||
if nr > 0 {
|
||||
var offset int
|
||||
for offset < nr {
|
||||
nw, ew := dst.Write(buf[offset:nr])
|
||||
if nw > 0 {
|
||||
written += int64(nw)
|
||||
}
|
||||
if ew != nil {
|
||||
err = ew
|
||||
break
|
||||
}
|
||||
offset += nw
|
||||
nw, ew := dst.Write(buf[0:nr])
|
||||
if nw > 0 {
|
||||
written += int64(nw)
|
||||
}
|
||||
if ew != nil {
|
||||
err = ew
|
||||
break
|
||||
}
|
||||
if nr != nw {
|
||||
err = io.ErrShortWrite
|
||||
break
|
||||
}
|
||||
}
|
||||
if er != nil {
|
||||
|
@ -59,6 +59,8 @@ type Session struct {
|
||||
closed uint32
|
||||
|
||||
terminalMsg atomic.Value
|
||||
|
||||
maxStreamUnitWrite int // the max size passed to Write calls before it splits it into multiple frames
|
||||
}
|
||||
|
||||
func MakeSession(id uint32, config SessionConfig) *Session {
|
||||
@ -82,6 +84,7 @@ func MakeSession(id uint32, config SessionConfig) *Session {
|
||||
if config.MaxFrameSize <= 0 {
|
||||
sesh.MaxFrameSize = defaultSendRecvBufSize - 1024
|
||||
}
|
||||
sesh.maxStreamUnitWrite = sesh.MaxFrameSize - HEADER_LEN - sesh.Obfuscator.minOverhead
|
||||
|
||||
sbConfig := switchboardConfig{
|
||||
valve: sesh.Valve,
|
||||
|
@ -96,37 +96,39 @@ func (s *Stream) Write(in []byte) (n int, err error) {
|
||||
return 0, ErrBrokenStream
|
||||
}
|
||||
|
||||
var payload []byte
|
||||
maxDataLen := s.session.MaxFrameSize - HEADER_LEN - s.session.minOverhead
|
||||
if len(in) <= maxDataLen {
|
||||
payload = in
|
||||
} else {
|
||||
//TODO: short write isn't the correct behaviour
|
||||
payload = in[:maxDataLen]
|
||||
}
|
||||
|
||||
f := &Frame{
|
||||
StreamID: s.id,
|
||||
Seq: atomic.AddUint64(&s.nextSendSeq, 1) - 1,
|
||||
Closing: C_NOOP,
|
||||
Payload: payload,
|
||||
}
|
||||
|
||||
i, err := s.session.Obfs(f, s.obfsBuf)
|
||||
if err != nil {
|
||||
return i, err
|
||||
}
|
||||
n, err = s.session.sb.send(s.obfsBuf[:i], &s.assignedConnId)
|
||||
log.Tracef("%v sent to remote through stream %v with err %v", len(payload), s.id, err)
|
||||
if err != nil {
|
||||
if err == errBrokenSwitchboard {
|
||||
s.session.SetTerminalMsg(err.Error())
|
||||
s.session.passiveClose()
|
||||
for n < len(in) {
|
||||
var framePayload []byte
|
||||
if len(in)-n <= s.session.maxStreamUnitWrite {
|
||||
framePayload = in[n:]
|
||||
} else {
|
||||
framePayload = in[n : s.session.maxStreamUnitWrite+n]
|
||||
}
|
||||
return
|
||||
}
|
||||
return len(payload), nil
|
||||
|
||||
f := &Frame{
|
||||
StreamID: s.id,
|
||||
Seq: atomic.AddUint64(&s.nextSendSeq, 1) - 1,
|
||||
Closing: C_NOOP,
|
||||
Payload: framePayload,
|
||||
}
|
||||
|
||||
var cipherTextLen int
|
||||
cipherTextLen, err = s.session.Obfs(f, s.obfsBuf)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
_, err = s.session.sb.send(s.obfsBuf[:cipherTextLen], &s.assignedConnId)
|
||||
log.Tracef("%v sent to remote through stream %v with err %v", len(framePayload), s.id, err)
|
||||
if err != nil {
|
||||
if err == errBrokenSwitchboard {
|
||||
s.session.SetTerminalMsg(err.Error())
|
||||
s.session.passiveClose()
|
||||
}
|
||||
return
|
||||
}
|
||||
n += len(framePayload)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Stream) passiveClose() error {
|
||||
|
Loading…
Reference in New Issue
Block a user