mirror of
https://gitlab.com/yawning/obfs4.git
synced 2024-11-15 12:12:53 +00:00
First pass at cleaning up the write code.
This commit is contained in:
parent
731a926172
commit
557e746815
76
obfs4.go
76
obfs4.go
@ -101,9 +101,9 @@ func (c *Obfs4Conn) closeAfterDelay() {
|
||||
// Consume and discard data on this connection until either the specified
|
||||
// interval passes or a certain size has been reached.
|
||||
discarded := 0
|
||||
buf := make([]byte, defaultReadSize)
|
||||
var buf [framing.MaximumSegmentLength]byte
|
||||
for discarded < int(toDiscard) {
|
||||
n, err := c.conn.Read(buf)
|
||||
n, err := c.conn.Read(buf[:])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@ -281,12 +281,12 @@ func (c *Obfs4Conn) WriteTo(w io.Writer) (n int64, err error) {
|
||||
// If there is buffered payload from earlier Read() calls, write.
|
||||
if c.receiveDecodedBuffer.Len() > 0 {
|
||||
wrLen, err = w.Write(c.receiveDecodedBuffer.Bytes())
|
||||
if wrLen < int(c.receiveDecodedBuffer.Len()) {
|
||||
c.isOk = false
|
||||
return int64(wrLen), io.ErrShortWrite
|
||||
} else if err != nil {
|
||||
if err != nil {
|
||||
c.isOk = false
|
||||
return int64(wrLen), err
|
||||
} else if wrLen < int(c.receiveDecodedBuffer.Len()) {
|
||||
c.isOk = false
|
||||
return int64(wrLen), io.ErrShortWrite
|
||||
}
|
||||
c.receiveDecodedBuffer.Reset()
|
||||
}
|
||||
@ -308,66 +308,58 @@ func (c *Obfs4Conn) WriteTo(w io.Writer) (n int64, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Obfs4Conn) Write(b []byte) (int, error) {
|
||||
func (c *Obfs4Conn) Write(b []byte) (n int, err error) {
|
||||
chopBuf := bytes.NewBuffer(b)
|
||||
buf := make([]byte, maxPacketPayloadLength)
|
||||
nSent := 0
|
||||
var payload [maxPacketPayloadLength]byte
|
||||
var frameBuf bytes.Buffer
|
||||
|
||||
for chopBuf.Len() > 0 {
|
||||
// Send maximum sized frames.
|
||||
n, err := chopBuf.Read(buf)
|
||||
rdLen := 0
|
||||
rdLen, err = chopBuf.Read(payload[:])
|
||||
if err != nil {
|
||||
c.isOk = false
|
||||
return 0, err
|
||||
} else if n == 0 {
|
||||
} else if rdLen == 0 {
|
||||
panic(fmt.Sprintf("BUG: Write(), chopping length was 0"))
|
||||
}
|
||||
nSent += n
|
||||
n += rdLen
|
||||
|
||||
_, frame, err := c.makeAndEncryptPacket(packetTypePayload, buf[:n], 0)
|
||||
err = c.producePacket(&frameBuf, packetTypePayload, payload[:rdLen], 0)
|
||||
if err != nil {
|
||||
c.isOk = false
|
||||
return 0, err
|
||||
}
|
||||
|
||||
frameBuf.Write(frame)
|
||||
}
|
||||
|
||||
// Insert random padding. In theory it's possible to inline padding for
|
||||
// certain framesizes into the last AEAD packet, but always sending 1 or 2
|
||||
// padding frames is considerably easier.
|
||||
padLen := c.calcPadLen(frameBuf.Len())
|
||||
if padLen > 0 {
|
||||
if padLen > headerLength {
|
||||
_, frame, err := c.makeAndEncryptPacket(packetTypePayload, []byte{},
|
||||
uint16(padLen-headerLength))
|
||||
if err != nil {
|
||||
c.isOk = false
|
||||
return 0, err
|
||||
}
|
||||
frameBuf.Write(frame)
|
||||
} else {
|
||||
_, frame, err := c.makeAndEncryptPacket(packetTypePayload, []byte{},
|
||||
maxPacketPayloadLength)
|
||||
if err != nil {
|
||||
c.isOk = false
|
||||
return 0, err
|
||||
}
|
||||
frameBuf.Write(frame)
|
||||
|
||||
_, frame, err = c.makeAndEncryptPacket(packetTypePayload, []byte{},
|
||||
uint16(padLen))
|
||||
if err != nil {
|
||||
c.isOk = false
|
||||
return 0, err
|
||||
}
|
||||
frameBuf.Write(frame)
|
||||
if padLen > headerLength {
|
||||
err = c.producePacket(&frameBuf, packetTypePayload, []byte{},
|
||||
uint16(padLen-headerLength))
|
||||
if err != nil {
|
||||
c.isOk = false
|
||||
return 0, err
|
||||
}
|
||||
} else if padLen > 0 {
|
||||
err = c.producePacket(&frameBuf, packetTypePayload, []byte{},
|
||||
maxPacketPayloadLength)
|
||||
if err != nil {
|
||||
c.isOk = false
|
||||
return 0, err
|
||||
}
|
||||
err = c.producePacket(&frameBuf, packetTypePayload, []byte{},
|
||||
uint16(padLen))
|
||||
if err != nil {
|
||||
c.isOk = false
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
// Send the frame(s).
|
||||
_, err := c.conn.Write(frameBuf.Bytes())
|
||||
_, err = c.conn.Write(frameBuf.Bytes())
|
||||
if err != nil {
|
||||
// Partial writes are fatal because the frame encoder state is advanced
|
||||
// at this point. It's possible to keep frameBuf around, but fuck it.
|
||||
@ -376,7 +368,7 @@ func (c *Obfs4Conn) Write(b []byte) (int, error) {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return nSent, nil
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Obfs4Conn) Close() error {
|
||||
|
41
packet.go
41
packet.go
@ -67,8 +67,8 @@ func (e InvalidPayloadLengthError) Error() string {
|
||||
|
||||
var zeroPadBytes [maxPacketPaddingLength]byte
|
||||
|
||||
func makePacket(pkt []byte, pktType uint8, data []byte, padLen uint16) int {
|
||||
pktLen := packetOverhead + len(data) + int(padLen)
|
||||
func (c *Obfs4Conn) producePacket(w io.Writer, pktType uint8, data []byte, padLen uint16) error {
|
||||
var pkt [framing.MaximumFramePayloadLength]byte
|
||||
|
||||
if len(data)+int(padLen) > maxPacketPayloadLength {
|
||||
panic(fmt.Sprintf("BUG: makePacket() len(data) + padLen > maxPacketPayloadLength: %d + %d > %d",
|
||||
@ -80,7 +80,6 @@ func makePacket(pkt []byte, pktType uint8, data []byte, padLen uint16) int {
|
||||
// uint16_t length Length of the payload (Big Endian).
|
||||
// uint8_t[] payload Data payload.
|
||||
// uint8_t[] padding Padding.
|
||||
|
||||
pkt[0] = pktType
|
||||
binary.BigEndian.PutUint16(pkt[1:], uint16(len(data)))
|
||||
if len(data) > 0 {
|
||||
@ -88,18 +87,26 @@ func makePacket(pkt []byte, pktType uint8, data []byte, padLen uint16) int {
|
||||
}
|
||||
copy(pkt[3+len(data):], zeroPadBytes[:padLen])
|
||||
|
||||
return pktLen
|
||||
}
|
||||
|
||||
func (c *Obfs4Conn) makeAndEncryptPacket(pktType uint8, data []byte, padLen uint16) (int, []byte, error) {
|
||||
var pkt [framing.MaximumFramePayloadLength]byte
|
||||
|
||||
// Wrap the payload in a packet.
|
||||
n := makePacket(pkt[:], pktType, data[:], padLen)
|
||||
pktLen := packetOverhead + len(data) + int(padLen)
|
||||
|
||||
// Encode the packet in an AEAD frame.
|
||||
n, frame, err := c.encoder.Encode(pkt[:n])
|
||||
return n, frame, err
|
||||
// TODO: Change Encode to write into frame directly
|
||||
_, frame, err := c.encoder.Encode(pkt[:pktLen])
|
||||
if err != nil {
|
||||
// All encoder errors are fatal.
|
||||
c.isOk = false
|
||||
return err
|
||||
}
|
||||
wrLen, err := w.Write(frame)
|
||||
if err != nil {
|
||||
c.isOk = false
|
||||
return err
|
||||
} else if wrLen < len(frame) {
|
||||
c.isOk = false
|
||||
return io.ErrShortWrite
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Obfs4Conn) consumeFramedPackets(w io.Writer) (n int, err error) {
|
||||
@ -116,7 +123,7 @@ func (c *Obfs4Conn) consumeFramedPackets(w io.Writer) (n int, err error) {
|
||||
|
||||
for c.receiveBuffer.Len() > 0 {
|
||||
// Decrypt an AEAD frame.
|
||||
// TODO: Change decode to write into packet directly
|
||||
// TODO: Change Decode to write into packet directly
|
||||
var pkt []byte
|
||||
_, pkt, err = c.decoder.Decode(&c.receiveBuffer)
|
||||
if err == framing.ErrAgain {
|
||||
@ -145,10 +152,10 @@ func (c *Obfs4Conn) consumeFramedPackets(w io.Writer) (n int, err error) {
|
||||
// c.WriteTo() skips buffering in c.receiveDecodedBuffer
|
||||
wrLen, err := w.Write(payload)
|
||||
n += wrLen
|
||||
if wrLen < int(payloadLen) {
|
||||
err = io.ErrShortWrite
|
||||
if err != nil {
|
||||
break
|
||||
} else if err != nil {
|
||||
} else if wrLen < int(payloadLen) {
|
||||
err = io.ErrShortWrite
|
||||
break
|
||||
}
|
||||
} else {
|
||||
|
Loading…
Reference in New Issue
Block a user