diff --git a/internal/common/tls.go b/internal/common/tls.go index 4917e8e..fd2fce4 100644 --- a/internal/common/tls.go +++ b/internal/common/tls.go @@ -1,6 +1,7 @@ package common import ( + "bytes" "encoding/binary" "io" "net" @@ -36,18 +37,15 @@ func AddRecordLayer(input []byte, typ byte, ver uint16) []byte { type TLSConn struct { net.Conn - writeM sync.Mutex - writeBuf []byte + writeBufPool sync.Pool } func NewTLSConn(conn net.Conn) *TLSConn { - writeBuf := make([]byte, initialWriteBufSize) - writeBuf[0] = ApplicationData - writeBuf[1] = byte(VersionTLS13 >> 8) - writeBuf[2] = byte(VersionTLS13 & 0xFF) return &TLSConn{ - Conn: conn, - writeBuf: writeBuf, + Conn: conn, + writeBufPool: sync.Pool{New: func() interface{} { + return new(bytes.Buffer) + }}, } } @@ -95,13 +93,16 @@ func (tls *TLSConn) Read(buffer []byte) (n int, err error) { func (tls *TLSConn) Write(in []byte) (n int, err error) { msgLen := len(in) - tls.writeM.Lock() - tls.writeBuf = append(tls.writeBuf[:5], in...) - tls.writeBuf[3] = byte(msgLen >> 8) - tls.writeBuf[4] = byte(msgLen & 0xFF) - n, err = tls.Conn.Write(tls.writeBuf[:recordLayerLength+msgLen]) - tls.writeM.Unlock() - return n - recordLayerLength, err + writeBuf := tls.writeBufPool.Get().(*bytes.Buffer) + writeBuf.WriteByte(ApplicationData) + writeBuf.WriteByte(byte(VersionTLS13 >> 8)) + writeBuf.WriteByte(byte(VersionTLS13 & 0xFF)) + writeBuf.WriteByte(byte(msgLen >> 8)) + writeBuf.WriteByte(byte(msgLen & 0xFF)) + writeBuf.Write(in) + i, err := writeBuf.WriteTo(tls.Conn) + tls.writeBufPool.Put(writeBuf) + return int(i - recordLayerLength), err } func (tls *TLSConn) Close() error {