You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
goproxy/utils/structs.go

463 lines
11 KiB
Go

package utils
import (
"bytes"
"crypto/tls"
"encoding/base64"
"fmt"
"io"
"io/ioutil"
"log"
"net"
"net/url"
"strings"
"time"
)
type Checker struct {
data ConcurrentMap
blockedMap ConcurrentMap
directMap ConcurrentMap
interval int64
timeout int
}
type CheckerItem struct {
IsHTTPS bool
Method string
URL string
Domain string
Host string
Data []byte
SuccessCount uint
FailCount uint
}
//NewChecker args:
//timeout : tcp timeout milliseconds ,connect to host
//interval: recheck domain interval seconds
func NewChecker(timeout int, interval int64, blockedFile, directFile string) Checker {
ch := Checker{
data: NewConcurrentMap(),
interval: interval,
timeout: timeout,
}
ch.blockedMap = ch.loadMap(blockedFile)
ch.directMap = ch.loadMap(directFile)
if !ch.blockedMap.IsEmpty() {
log.Printf("blocked file loaded , domains : %d", ch.blockedMap.Count())
}
if !ch.directMap.IsEmpty() {
log.Printf("direct file loaded , domains : %d", ch.directMap.Count())
}
ch.start()
return ch
}
func (c *Checker) loadMap(f string) (dataMap ConcurrentMap) {
dataMap = NewConcurrentMap()
if PathExists(f) {
_contents, err := ioutil.ReadFile(f)
if err != nil {
log.Printf("load file err:%s", err)
return
}
for _, line := range strings.Split(string(_contents), "\n") {
line = strings.Trim(line, "\r \t")
if line != "" {
dataMap.Set(line, true)
}
}
}
return
}
func (c *Checker) start() {
go func() {
for {
for _, v := range c.data.Items() {
go func(item CheckerItem) {
if c.isNeedCheck(item) {
//log.Printf("check %s", item.Domain)
var conn net.Conn
var err error
if item.IsHTTPS {
conn, err = ConnectHost(item.Host, c.timeout)
if err == nil {
conn.SetDeadline(time.Now().Add(time.Millisecond))
conn.Close()
}
} else {
err = HTTPGet(item.URL, c.timeout)
}
if err != nil {
item.FailCount = item.FailCount + 1
} else {
item.SuccessCount = item.SuccessCount + 1
}
c.data.Set(item.Host, item)
}
}(v.(CheckerItem))
}
time.Sleep(time.Second * time.Duration(c.interval))
}
}()
}
func (c *Checker) isNeedCheck(item CheckerItem) bool {
var minCount uint = 5
if (item.SuccessCount >= minCount && item.SuccessCount > item.FailCount) ||
(item.FailCount >= minCount && item.SuccessCount > item.FailCount) ||
c.domainIsInMap(item.Host, false) ||
c.domainIsInMap(item.Host, true) {
return false
}
return true
}
func (c *Checker) IsBlocked(address string) (blocked bool, failN, successN uint) {
if c.domainIsInMap(address, true) {
//log.Printf("%s in blocked ? true", address)
return true, 0, 0
}
if c.domainIsInMap(address, false) {
//log.Printf("%s in direct ? true", address)
return false, 0, 0
}
_item, ok := c.data.Get(address)
if !ok {
//log.Printf("%s not in map, blocked true", address)
return true, 0, 0
}
item := _item.(CheckerItem)
return item.FailCount >= item.SuccessCount, item.FailCount, item.SuccessCount
}
func (c *Checker) domainIsInMap(address string, blockedMap bool) bool {
u, err := url.Parse("http://" + address)
if err != nil {
log.Printf("blocked check , url parse err:%s", err)
return true
}
domainSlice := strings.Split(u.Hostname(), ".")
if len(domainSlice) > 1 {
subSlice := domainSlice[:len(domainSlice)-1]
topDomain := strings.Join(domainSlice[len(domainSlice)-1:], ".")
checkDomain := topDomain
for i := len(subSlice) - 1; i >= 0; i-- {
checkDomain = subSlice[i] + "." + checkDomain
if !blockedMap && c.directMap.Has(checkDomain) {
return true
}
if blockedMap && c.blockedMap.Has(checkDomain) {
return true
}
}
}
return false
}
func (c *Checker) Add(address string, isHTTPS bool, method, URL string, data []byte) {
if c.domainIsInMap(address, false) || c.domainIsInMap(address, true) {
return
}
if !isHTTPS && strings.ToLower(method) != "get" {
return
}
var item CheckerItem
u := strings.Split(address, ":")
item = CheckerItem{
URL: URL,
Domain: u[0],
Host: address,
Data: data,
IsHTTPS: isHTTPS,
Method: method,
}
c.data.SetIfAbsent(item.Host, item)
}
type BasicAuth struct {
data ConcurrentMap
}
func NewBasicAuth() BasicAuth {
return BasicAuth{
data: NewConcurrentMap(),
}
}
func (ba *BasicAuth) AddFromFile(file string) (n int, err error) {
_content, err := ioutil.ReadFile(file)
if err != nil {
return
}
userpassArr := strings.Split(strings.Replace(string(_content), "\r", "", -1), "\n")
for _, userpass := range userpassArr {
if strings.HasPrefix("#", userpass) {
continue
}
u := strings.Split(strings.Trim(userpass, " "), ":")
if len(u) == 2 {
ba.data.Set(u[0], u[1])
n++
}
}
return
}
func (ba *BasicAuth) Add(userpassArr []string) (n int) {
for _, userpass := range userpassArr {
u := strings.Split(userpass, ":")
if len(u) == 2 {
ba.data.Set(u[0], u[1])
n++
}
}
return
}
func (ba *BasicAuth) Check(userpass string) (ok bool) {
u := strings.Split(strings.Trim(userpass, " "), ":")
if len(u) == 2 {
if p, _ok := ba.data.Get(u[0]); _ok {
return p.(string) == u[1]
}
}
return
}
func (ba *BasicAuth) Total() (n int) {
n = ba.data.Count()
return
}
type HTTPRequest struct {
HeadBuf []byte
conn *net.Conn
Host string
Method string
URL string
hostOrURL string
isBasicAuth bool
basicAuth *BasicAuth
}
func NewHTTPRequest(inConn *net.Conn, bufSize int, isBasicAuth bool, basicAuth *BasicAuth) (req HTTPRequest, err error) {
buf := make([]byte, bufSize)
len := 0
req = HTTPRequest{
conn: inConn,
}
len, err = (*inConn).Read(buf[:])
if err != nil {
if err != io.EOF {
err = fmt.Errorf("http decoder read err:%s", err)
}
CloseConn(inConn)
return
}
req.HeadBuf = buf[:len]
index := bytes.IndexByte(req.HeadBuf, '\n')
if index == -1 {
err = fmt.Errorf("http decoder data line err:%s", string(req.HeadBuf)[:50])
CloseConn(inConn)
return
}
fmt.Sscanf(string(req.HeadBuf[:index]), "%s%s", &req.Method, &req.hostOrURL)
if req.Method == "" || req.hostOrURL == "" {
err = fmt.Errorf("http decoder data err:%s", string(req.HeadBuf)[:50])
CloseConn(inConn)
return
}
req.Method = strings.ToUpper(req.Method)
req.isBasicAuth = isBasicAuth
req.basicAuth = basicAuth
log.Printf("%s:%s", req.Method, req.hostOrURL)
if req.IsHTTPS() {
err = req.HTTPS()
} else {
err = req.HTTP()
}
return
}
func (req *HTTPRequest) HTTP() (err error) {
if req.isBasicAuth {
err = req.BasicAuth()
if err != nil {
return
}
}
req.URL, err = req.getHTTPURL()
if err == nil {
u, _ := url.Parse(req.URL)
req.Host = u.Host
req.addPortIfNot()
}
return
}
func (req *HTTPRequest) HTTPS() (err error) {
req.Host = req.hostOrURL
req.addPortIfNot()
//_, err = fmt.Fprint(*req.conn, "HTTP/1.1 200 Connection established\r\n\r\n")
return
}
func (req *HTTPRequest) HTTPSReply() (err error) {
_, err = fmt.Fprint(*req.conn, "HTTP/1.1 200 Connection established\r\n\r\n")
return
}
func (req *HTTPRequest) IsHTTPS() bool {
return req.Method == "CONNECT"
}
func (req *HTTPRequest) BasicAuth() (err error) {
//log.Printf("request :%s", string(b[:n]))
authorization, err := req.getHeader("Authorization")
if err != nil {
fmt.Fprint((*req.conn), "HTTP/1.1 401 Unauthorized\r\nWWW-Authenticate: Basic realm=\"\"\r\n\r\nUnauthorized")
CloseConn(req.conn)
return
}
//log.Printf("Authorization:%s", authorization)
basic := strings.Fields(authorization)
if len(basic) != 2 {
err = fmt.Errorf("authorization data error,ERR:%s", authorization)
CloseConn(req.conn)
return
}
user, err := base64.StdEncoding.DecodeString(basic[1])
if err != nil {
err = fmt.Errorf("authorization data parse error,ERR:%s", err)
CloseConn(req.conn)
return
}
authOk := (*req.basicAuth).Check(string(user))
//log.Printf("auth %s,%v", string(user), authOk)
if !authOk {
fmt.Fprint((*req.conn), "HTTP/1.1 401 Unauthorized\r\n\r\nUnauthorized")
CloseConn(req.conn)
err = fmt.Errorf("basic auth fail")
return
}
return
}
func (req *HTTPRequest) getHTTPURL() (URL string, err error) {
if !strings.HasPrefix(req.hostOrURL, "/") {
return req.hostOrURL, nil
}
_host, err := req.getHeader("host")
if err != nil {
return
}
URL = fmt.Sprintf("http://%s%s", _host, req.hostOrURL)
return
}
func (req *HTTPRequest) getHeader(key string) (val string, err error) {
key = strings.ToUpper(key)
lines := strings.Split(string(req.HeadBuf), "\r\n")
for _, line := range lines {
line := strings.SplitN(strings.Trim(line, "\r\n "), ":", 2)
if len(line) == 2 {
k := strings.ToUpper(strings.Trim(line[0], " "))
v := strings.Trim(line[1], " ")
if key == k {
val = v
return
}
}
}
err = fmt.Errorf("can not find HOST header")
return
}
func (req *HTTPRequest) addPortIfNot() (newHost string) {
//newHost = req.Host
port := "80"
if req.IsHTTPS() {
port = "443"
}
if (!strings.HasPrefix(req.Host, "[") && strings.Index(req.Host, ":") == -1) || (strings.HasPrefix(req.Host, "[") && strings.HasSuffix(req.Host, "]")) {
//newHost = req.Host + ":" + port
//req.headBuf = []byte(strings.Replace(string(req.headBuf), req.Host, newHost, 1))
req.Host = req.Host + ":" + port
}
return
}
type OutPool struct {
Pool ConnPool
dur int
isTLS bool
certBytes []byte
keyBytes []byte
address string
timeout int
}
func NewOutPool(dur int, isTLS bool, certBytes, keyBytes []byte, address string, timeout int, InitialCap int, MaxCap int) (op OutPool) {
op = OutPool{
dur: dur,
isTLS: isTLS,
certBytes: certBytes,
keyBytes: keyBytes,
address: address,
timeout: timeout,
}
var err error
op.Pool, err = NewConnPool(poolConfig{
IsActive: func(conn interface{}) bool { return true },
Release: func(conn interface{}) {
if conn != nil {
conn.(net.Conn).SetDeadline(time.Now().Add(time.Millisecond))
conn.(net.Conn).Close()
// log.Println("conn released")
}
},
InitialCap: InitialCap,
MaxCap: MaxCap,
Factory: func() (conn interface{}, err error) {
conn, err = op.getConn()
return
},
})
if err != nil {
log.Fatalf("init conn pool fail ,%s", err)
} else {
if InitialCap > 0 {
log.Printf("init conn pool success")
op.initPoolDeamon()
} else {
log.Printf("conn pool closed")
}
}
return
}
func (op *OutPool) getConn() (conn interface{}, err error) {
if op.isTLS {
var _conn tls.Conn
_conn, err = TlsConnectHost(op.address, op.timeout, op.certBytes, op.keyBytes)
if err == nil {
conn = net.Conn(&_conn)
}
} else {
conn, err = ConnectHost(op.address, op.timeout)
}
return
}
func (op *OutPool) initPoolDeamon() {
go func() {
if op.dur <= 0 {
return
}
log.Printf("pool deamon started")
for {
time.Sleep(time.Second * time.Duration(op.dur))
conn, err := op.getConn()
if err != nil {
log.Printf("pool deamon err %s , release pool", err)
op.Pool.ReleaseAll()
} else {
conn.(net.Conn).SetDeadline(time.Now().Add(time.Millisecond))
conn.(net.Conn).Close()
}
}
}()
}