Fix infinite loop. Baseline

pull/2/head
Qian Wang 6 years ago
parent a8786a5576
commit 02fa072964

@ -6,7 +6,10 @@ import (
"io"
"log"
"net"
"net/http"
_ "net/http/pprof"
"os"
"runtime"
"strings"
"time"
@ -19,9 +22,8 @@ var version string
func pipe(dst io.ReadWriteCloser, src io.ReadWriteCloser) {
for {
_, err := io.Copy(dst, src)
if err != nil {
log.Println(err)
i, err := io.Copy(dst, src)
if err != nil || i == 0 {
go dst.Close()
go src.Close()
return
@ -102,10 +104,12 @@ func dispatchConnection(conn net.Conn, sta *server.State) {
newStream, err := sesh.AcceptStream()
if err != nil {
log.Printf("Failed to get new stream: %v", err)
continue
}
ssConn, err := net.Dial("tcp", sta.SS_LOCAL_HOST+":"+sta.SS_LOCAL_PORT)
if err != nil {
log.Printf("Failed to connect to ssserver: %v", err)
continue
}
go pipe(ssConn, newStream)
go pipe(newStream, ssConn)
@ -116,6 +120,10 @@ func dispatchConnection(conn net.Conn, sta *server.State) {
}
func main() {
runtime.SetBlockProfileRate(2)
go func() {
log.Println(http.ListenAndServe("0.0.0.0:8001", nil))
}()
// Should be 127.0.0.1 to listen to ss-server on this machine
var localHost string
// server_port in ss config, same as remotePort in plugin mode

@ -1,4 +1,4 @@
{
"WebServerAddr":"204.79.197.200:443",
"Key":"CN+VRP9OqZR0+Im2X/1y6FvaK7+GBnX6qCiovbo+eVo="
"Key":"UGUmcEmxWf0pKxfkZ/8EoP35Ht+wQnqf3L0xYgyQFlQ="
}

@ -2,6 +2,7 @@ package multiplex
import (
"container/heap"
"log"
)
// The data is multiplexed through several TCP connections, therefore the
@ -57,8 +58,10 @@ func (s *Stream) recvNewFrame() {
for {
f := <-s.newFrameCh
if f == nil {
log.Println("nil frame")
continue
}
// For the ease of demonstration, assume seq is uint8, i.e. it wraps around after 255
fs := &frameNode{
f.Seq,

@ -8,7 +8,7 @@ import (
const (
// Copied from smux
errBrokenPipe = "broken pipe"
errBrokenPipe = "broken stream"
errRepeatStreamClosing = "trying to close a closed stream"
acceptBacklog = 1024
@ -84,9 +84,9 @@ func (sesh *Session) AcceptStream() (*Stream, error) {
}
func (sesh *Session) delStream(id uint32) {
sesh.streamsM.RLock()
sesh.streamsM.Lock()
delete(sesh.streams, id)
sesh.streamsM.RUnlock()
sesh.streamsM.Unlock()
}
func (sesh *Session) isStream(id uint32) bool {

@ -8,7 +8,7 @@ import (
)
const (
readBuffer = 102400
readBuffer = 20480
)
type Stream struct {
@ -50,21 +50,29 @@ func makeStream(id uint32, sesh *Session) *Stream {
}
func (stream *Stream) Read(buf []byte) (n int, err error) {
if len(buf) != 0 {
if len(buf) == 0 {
select {
case <-stream.die:
log.Printf("Stream %v dying\n", stream.id)
return 0, errors.New(errBrokenPipe)
case data := <-stream.sortedBufCh:
if len(data) > 0 {
copy(buf, data)
return len(data), nil
} else {
// TODO: close stream here or not?
return 0, io.EOF
}
default:
return 0, nil
}
}
return 0, errors.New(errBrokenPipe)
select {
case <-stream.die:
log.Printf("Stream %v dying\n", stream.id)
return 0, errors.New(errBrokenPipe)
default:
}
data := <-stream.sortedBufCh
if len(data) > 0 {
copy(buf, data)
return len(data), nil
} else {
// TODO: close stream here or not?
return 0, io.EOF
}
}
@ -111,8 +119,8 @@ func (stream *Stream) Close() error {
return errors.New(errRepeatStreamClosing)
}
stream.closing = true
stream.session.delStream(stream.id)
close(stream.die)
stream.session.delStream(stream.id)
stream.session.closeQCh <- stream.id
return nil
}

@ -3,7 +3,7 @@ package multiplex
import (
"log"
"net"
"sort"
//"sort"
)
const (
@ -79,30 +79,36 @@ type sentNotifier struct {
func (ce *connEnclave) send(data []byte) {
// TODO: error handling
n, err := ce.remoteConn.Write(data)
_, err := ce.remoteConn.Write(data)
if err != nil {
ce.sb.closingCECh <- ce
log.Println(err)
}
sn := &sentNotifier{
ce,
n,
}
ce.sb.sentNotifyCh <- sn
/*
sn := &sentNotifier{
ce,
n,
}
ce.sb.sentNotifyCh <- sn
*/
}
// Dispatcher sends data coming from a stream to a remote connection
// I used channels here because I didn't want to use mutex
func (sb *switchboard) dispatch() {
var nextCE int
for {
select {
// dispatCh receives data from stream.Write
case data := <-sb.dispatCh:
go sb.ces[0].send(data)
sb.ces[0].sendQueue += len(data)
case notified := <-sb.sentNotifyCh:
notified.ce.sendQueue -= notified.sent
sort.Sort(byQ(sb.ces))
go sb.ces[nextCE%len(sb.ces)].send(data)
//sb.ces[0].sendQueue += len(data)
nextCE += 1
/*case notified := <-sb.sentNotifyCh:
notified.ce.sendQueue -= notified.sent
sort.Sort(byQ(sb.ces))*/
case conn := <-sb.newConnCh:
log.Println("newConn")
newCe := &connEnclave{
sb: sb,
remoteConn: conn,
@ -110,8 +116,9 @@ func (sb *switchboard) dispatch() {
}
sb.ces = append(sb.ces, newCe)
go sb.deplex(newCe)
sort.Sort(byQ(sb.ces))
//sort.Sort(byQ(sb.ces))
case closing := <-sb.closingCECh:
log.Println("Closing conn")
for i, ce := range sb.ces {
if closing == ce {
sb.ces = append(sb.ces[:i], sb.ces[i+1:]...)
@ -124,7 +131,7 @@ func (sb *switchboard) dispatch() {
}
func (sb *switchboard) deplex(ce *connEnclave) {
buf := make([]byte, 204800)
buf := make([]byte, 20480)
for {
i, err := sb.session.obfsedReader(ce.remoteConn, buf)
if err != nil {

@ -6,6 +6,7 @@ import (
"io"
prand "math/rand"
"net"
"strconv"
"time"
)
@ -45,15 +46,15 @@ func ReadTillDrain(conn net.Conn, buffer []byte) (n int, err error) {
}
dataLength := BtoInt(buffer[3:5])
if dataLength > len(buffer) {
err = errors.New("Reading TLS message: message size greater than buffer. message size: " + strconv.Itoa(dataLength))
return
}
left := dataLength
readPtr := 5
conn.SetReadDeadline(time.Now().Add(3 * time.Second))
for left != 0 {
if readPtr > len(buffer) || readPtr+left > len(buffer) {
err = errors.New("Reading TLS message: message size greater than buffer")
return
}
// If left > buffer size (i.e. our message got segmented), the entire MTU is read
// if left = buffer size, the entire buffer is all there left to read
// if left < buffer size (i.e. multiple messages came together),

Loading…
Cancel
Save