diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 0000000..b687be7 --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,120 @@ +linters: + disable-all: true + enable: + # Re-enable the default linters + - errcheck + - gosimple + - govet + - ineffassign + - staticcheck + - typecheck + - unused + + # Enable the "always useful" linters as of 1.53.3 + - asasalint + - asciicheck + - bidichk + - decorder + - dogsled + - dupl + - dupword + - errchkjson + - errname + - errorlint + - exhaustive + - exportloopref + - forbidigo + - forcetypeassert + - gci + - gocheckcompilerdirectives + - gochecknoinits + - goconst + - gocritic + - godot + - godox + - gofumpt + - gomoddirectives + - goprintffuncname + - gosec + - gosmopolitan + - importas + - interfacebloat + - makezero + - mirror + - misspell + - musttag + - nakedret + - nestif + - nilerr + - nilnil + - nolintlint + - nonamedreturns + - prealloc + - predeclared + - reassign + - revive + - tagalign + - tenv + - testableexamples + - unconvert + - unparam + - usestdlibvars + - wastedassign + - whitespace + + # Disabled: Run periodically, but too many places to annotate + # - gomnd + + # Disabled: Not how I do things + # - exhaustruct # Zero value is fine. + # - funlen # I'm not breaking up my math. + # - gochecknoglobals # How else am I supposed to declare constants. + # - lll # The 70s called and wants their ttys back. + # - paralleltest + # - varnamelen # The papers use short variable names. + # - tagliatelle # I want my tags to match the files. + # - thelper + # - tparallel + # - testpackage + # - wsl # Nice idea, not how I like to write code. + # - goerr113 # Nice idea, this package has too much legacy bs. + # - ireturn # By virtue of the PT API we are interface heavy. + + # Disabled: Annoying/Useless + # - cyclop + # - gocognit + # - gocyclo + # - maintidx + # - wrapcheck + + # Disabled: Irrelevant/redundant + # - bodyclose + # - containedctx + # - contextcheck + # - depguard + # - durationcheck + # - execinquery + # - ginkgolinter + # - gofmt + # - goheader + # - goimports + # - gomodguard + # - grouper + # - loggercheck + # - nlreturn + # - noctx + # - nosprintfhostport + # - promlinter + # - rowserrcheck + # - sqlclosecheck + # - stylecheck + # - zerologlint + +linters-settings: + gci: + sections: + - standard + - default + - prefix(gitlab.com/yawning/obfs4.git) + skip-generated: true + custom-order: true diff --git a/common/csrand/csrand.go b/common/csrand/csrand.go index ddf14e8..e22c184 100644 --- a/common/csrand/csrand.go +++ b/common/csrand/csrand.go @@ -45,7 +45,7 @@ var ( csRandSourceInstance csRandSource // Rand is a math/rand instance backed by crypto/rand CSPRNG. - Rand = rand.New(csRandSourceInstance) + Rand = rand.New(csRandSourceInstance) //nolint:gosec ) type csRandSource struct { @@ -63,7 +63,7 @@ func (r csRandSource) Int63() int64 { return int64(val) } -func (r csRandSource) Seed(seed int64) { +func (r csRandSource) Seed(_ int64) { // No-op. } diff --git a/common/drbg/hash_drbg.go b/common/drbg/hash_drbg.go index 5a9cc7f..543345e 100644 --- a/common/drbg/hash_drbg.go +++ b/common/drbg/hash_drbg.go @@ -30,12 +30,14 @@ package drbg // import "gitlab.com/yawning/obfs4.git/common/drbg" import ( + "bytes" "encoding/binary" "encoding/hex" "fmt" "hash" "github.com/dchest/siphash" + "gitlab.com/yawning/obfs4.git/common/csrand" ) @@ -60,33 +62,33 @@ func (seed *Seed) Hex() string { } // NewSeed returns a Seed initialized with the runtime CSPRNG. -func NewSeed() (seed *Seed, err error) { - seed = new(Seed) - if err = csrand.Bytes(seed.Bytes()[:]); err != nil { +func NewSeed() (*Seed, error) { + seed := new(Seed) + if err := csrand.Bytes(seed.Bytes()[:]); err != nil { return nil, err } - return + return seed, nil } // SeedFromBytes creates a Seed from the raw bytes, truncating to SeedLength as // appropriate. -func SeedFromBytes(src []byte) (seed *Seed, err error) { +func SeedFromBytes(src []byte) (*Seed, error) { if len(src) < SeedLength { return nil, InvalidSeedLengthError(len(src)) } - seed = new(Seed) + seed := new(Seed) copy(seed.Bytes()[:], src) - return + return seed, nil } // SeedFromHex creates a Seed from the hexdecimal representation, truncating to // SeedLength as appropriate. -func SeedFromHex(encoded string) (seed *Seed, err error) { - var raw []byte - if raw, err = hex.DecodeString(encoded); err != nil { +func SeedFromHex(encoded string) (*Seed, error) { + raw, err := hex.DecodeString(encoded) + if err != nil { return nil, err } @@ -133,7 +135,7 @@ func (drbg *HashDrbg) Int63() int64 { } // Seed does nothing, call NewHashDrbg if you want to reseed. -func (drbg *HashDrbg) Seed(seed int64) { +func (drbg *HashDrbg) Seed(_ int64) { // No-op. } @@ -142,7 +144,5 @@ func (drbg *HashDrbg) NextBlock() []byte { _, _ = drbg.sip.Write(drbg.ofb[:]) copy(drbg.ofb[:], drbg.sip.Sum(nil)) - ret := make([]byte, Size) - copy(ret, drbg.ofb[:]) - return ret + return bytes.Clone(drbg.ofb[:]) } diff --git a/common/log/log.go b/common/log/log.go index 5e08c64..b8a5eaa 100644 --- a/common/log/log.go +++ b/common/log/log.go @@ -30,8 +30,9 @@ package log // import "gitlab.com/yawning/obfs4.git/common/log" import ( + "errors" "fmt" - "io/ioutil" + "io" "log" "net" "os" @@ -54,20 +55,22 @@ const ( LevelDebug ) -var logLevel = LevelInfo -var enableLogging bool -var unsafeLogging bool +var ( + logLevel = LevelInfo + enableLogging bool + unsafeLogging bool +) // Init initializes logging with the given path, and log safety options. func Init(enable bool, logFilePath string, unsafe bool) error { if enable { - f, err := os.OpenFile(logFilePath, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0600) + f, err := os.OpenFile(logFilePath, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o600) if err != nil { return err } log.SetOutput(f) } else { - log.SetOutput(ioutil.Discard) + log.SetOutput(io.Discard) } enableLogging = enable unsafeLogging = unsafe @@ -163,8 +166,8 @@ func ElideError(err error) string { // If err is not a net.Error, just return the string representation, // presumably transport authors know what they are doing. - netErr, ok := err.(net.Error) - if !ok { + var netErr net.Error + if !errors.As(err, &netErr) { return err.Error() } diff --git a/common/ntor/ntor.go b/common/ntor/ntor.go index 17a9ff7..346507a 100644 --- a/common/ntor/ntor.go +++ b/common/ntor/ntor.go @@ -70,15 +70,17 @@ const ( // KeySeedLength is the length of the derived KEY_SEED. KeySeedLength = sha256.Size - // AuthLength is the lenght of the derived AUTH. + // AuthLength is the length of the derived AUTH. AuthLength = sha256.Size ) -var protoID = []byte("ntor-curve25519-sha256-1") -var tMac = append(protoID, []byte(":mac")...) -var tKey = append(protoID, []byte(":key_extract")...) -var tVerify = append(protoID, []byte(":key_verify")...) -var mExpand = append(protoID, []byte(":key_expand")...) +var ( + protoID = []byte("ntor-curve25519-sha256-1") + tMac = append(protoID, []byte(":mac")...) + tKey = append(protoID, []byte(":key_extract")...) + tVerify = append(protoID, []byte(":key_verify")...) + mExpand = append(protoID, []byte(":key_expand")...) +) // PublicKeyLengthError is the error returned when the public key being // imported is an invalid length. @@ -320,48 +322,43 @@ func KeypairFromHex(encoded string) (*Keypair, error) { // ServerHandshake does the server side of a ntor handshake and returns status, // KEY_SEED, and AUTH. If status is not true, the handshake MUST be aborted. -func ServerHandshake(clientPublic *PublicKey, serverKeypair *Keypair, idKeypair *Keypair, id *NodeID) (ok bool, keySeed *KeySeed, auth *Auth) { +func ServerHandshake(clientPublic *PublicKey, serverKeypair *Keypair, idKeypair *Keypair, id *NodeID) (bool, *KeySeed, *Auth) { var notOk int var secretInput bytes.Buffer // Server side uses EXP(X,y) | EXP(X,b) var exp [SharedSecretLength]byte - curve25519.ScalarMult(&exp, serverKeypair.private.Bytes(), - clientPublic.Bytes()) + curve25519.ScalarMult(&exp, serverKeypair.private.Bytes(), clientPublic.Bytes()) //nolint:staticcheck notOk |= constantTimeIsZero(exp[:]) secretInput.Write(exp[:]) - curve25519.ScalarMult(&exp, idKeypair.private.Bytes(), - clientPublic.Bytes()) + curve25519.ScalarMult(&exp, idKeypair.private.Bytes(), clientPublic.Bytes()) //nolint:staticcheck notOk |= constantTimeIsZero(exp[:]) secretInput.Write(exp[:]) - keySeed, auth = ntorCommon(secretInput, id, idKeypair.public, + keySeed, auth := ntorCommon(secretInput, id, idKeypair.public, clientPublic, serverKeypair.public) return notOk == 0, keySeed, auth } // ClientHandshake does the client side of a ntor handshake and returnes // status, KEY_SEED, and AUTH. If status is not true or AUTH does not match -// the value recieved from the server, the handshake MUST be aborted. -func ClientHandshake(clientKeypair *Keypair, serverPublic *PublicKey, idPublic *PublicKey, id *NodeID) (ok bool, keySeed *KeySeed, auth *Auth) { +// the value received from the server, the handshake MUST be aborted. +func ClientHandshake(clientKeypair *Keypair, serverPublic *PublicKey, idPublic *PublicKey, id *NodeID) (bool, *KeySeed, *Auth) { var notOk int var secretInput bytes.Buffer // Client side uses EXP(Y,x) | EXP(B,x) var exp [SharedSecretLength]byte - curve25519.ScalarMult(&exp, clientKeypair.private.Bytes(), - serverPublic.Bytes()) + curve25519.ScalarMult(&exp, clientKeypair.private.Bytes(), serverPublic.Bytes()) //nolint:staticcheck notOk |= constantTimeIsZero(exp[:]) secretInput.Write(exp[:]) - curve25519.ScalarMult(&exp, clientKeypair.private.Bytes(), - idPublic.Bytes()) + curve25519.ScalarMult(&exp, clientKeypair.private.Bytes(), idPublic.Bytes()) //nolint:staticcheck notOk |= constantTimeIsZero(exp[:]) secretInput.Write(exp[:]) - keySeed, auth = ntorCommon(secretInput, id, idPublic, - clientKeypair.public, serverPublic) + keySeed, auth := ntorCommon(secretInput, id, idPublic, clientKeypair.public, serverPublic) return notOk == 0, keySeed, auth } @@ -402,7 +399,7 @@ func ntorCommon(secretInput bytes.Buffer, id *NodeID, b *PublicKey, x *PublicKey // auth_input = verify | ID | B | Y | X | PROTOID | "Server" authInput := bytes.NewBuffer(verify) _, _ = authInput.Write(suffix.Bytes()) - _, _ = authInput.Write([]byte("Server")) + _, _ = authInput.WriteString("Server") h = hmac.New(sha256.New, tMac) _, _ = h.Write(authInput.Bytes()) tmp = h.Sum(nil) diff --git a/common/probdist/weighted_dist.go b/common/probdist/weighted_dist.go index d0d380c..72f1e22 100644 --- a/common/probdist/weighted_dist.go +++ b/common/probdist/weighted_dist.go @@ -64,8 +64,8 @@ type WeightedDist struct { // based on a HashDrbg initialized with seed. Optionally, bias the weight // generation to match the ScrambleSuit non-uniform distribution from // obfsproxy. -func New(seed *drbg.Seed, min, max int, biased bool) (w *WeightedDist) { - w = &WeightedDist{minValue: min, maxValue: max, biased: biased} +func New(seed *drbg.Seed, min, max int, biased bool) *WeightedDist { + w := &WeightedDist{minValue: min, maxValue: max, biased: biased} if max <= min { panic(fmt.Sprintf("wDist.Reset(): min >= max (%d, %d)", min, max)) @@ -73,7 +73,7 @@ func New(seed *drbg.Seed, min, max int, biased bool) (w *WeightedDist) { w.Reset(seed) - return + return w } // genValues creates a slice containing a random number of random values @@ -132,7 +132,7 @@ func (w *WeightedDist) genTables() { scaled := make([]float64, n) for i, weight := range w.weights { // Multiply each probability by $n$. - p_i := weight * float64(n) / sum + p_i := weight * float64(n) / sum //nolint:revive scaled[i] = p_i // For each scaled probability $p_i$: @@ -148,9 +148,9 @@ func (w *WeightedDist) genTables() { // While $Small$ and $Large$ are not empty: ($Large$ might be emptied first) for small.Len() > 0 && large.Len() > 0 { // Remove the first element from $Small$; call it $l$. - l := small.Remove(small.Front()).(int) + l, _ := small.Remove(small.Front()).(int) // Remove the first element from $Large$; call it $g$. - g := large.Remove(large.Front()).(int) + g, _ := large.Remove(large.Front()).(int) // Set $Prob[l] = p_l$. prob[l] = scaled[l] @@ -172,7 +172,7 @@ func (w *WeightedDist) genTables() { // While $Large$ is not empty: for large.Len() > 0 { // Remove the first element from $Large$; call it $g$. - g := large.Remove(large.Front()).(int) + g, _ := large.Remove(large.Front()).(int) // Set $Prob[g] = 1$. prob[g] = 1.0 } @@ -180,7 +180,7 @@ func (w *WeightedDist) genTables() { // While $Small$ is not empty: This is only possible due to numerical instability. for small.Len() > 0 { // Remove the first element from $Small$; call it $l$. - l := small.Remove(small.Front()).(int) + l, _ := small.Remove(small.Front()).(int) // Set $Prob[l] = 1$. prob[l] = 1.0 } @@ -194,7 +194,7 @@ func (w *WeightedDist) genTables() { func (w *WeightedDist) Reset(seed *drbg.Seed) { // Initialize the deterministic random number generator. drbg, _ := drbg.NewHashDrbg(seed) - rng := rand.New(drbg) + rng := rand.New(drbg) //nolint:gosec w.Lock() defer w.Unlock() diff --git a/common/probdist/weighted_dist_test.go b/common/probdist/weighted_dist_test.go index 4619212..883485c 100644 --- a/common/probdist/weighted_dist_test.go +++ b/common/probdist/weighted_dist_test.go @@ -28,7 +28,6 @@ package probdist import ( - "fmt" "testing" "gitlab.com/yawning/obfs4.git/common/drbg" @@ -49,7 +48,7 @@ func TestWeightedDist(t *testing.T) { w := New(seed, 0, 999, true) if debug { // Dump a string representation of the probability table. - fmt.Println("Table:") + t.Logf("Table:") var sum float64 for _, weight := range w.weights { sum += weight @@ -57,10 +56,9 @@ func TestWeightedDist(t *testing.T) { for i, weight := range w.weights { p := weight / sum if p > 0.000001 { // Filter out tiny values. - fmt.Printf(" [%d]: %f\n", w.minValue+w.values[i], p) + t.Logf(" [%d]: %f", w.minValue+w.values[i], p) } } - fmt.Println() } for i := 0; i < nrTrials; i++ { @@ -69,11 +67,11 @@ func TestWeightedDist(t *testing.T) { } if debug { - fmt.Println("Generated:") + t.Logf("Generated:") for value, count := range hist { if count != 0 { p := float64(count) / float64(nrTrials) - fmt.Printf(" [%d]: %f (%d)\n", value, p, count) + t.Logf(" [%d]: %f (%d)", value, p, count) } } } diff --git a/common/replayfilter/replay_filter.go b/common/replayfilter/replay_filter.go index 00ee3e9..3823d9b 100644 --- a/common/replayfilter/replay_filter.go +++ b/common/replayfilter/replay_filter.go @@ -39,6 +39,7 @@ import ( "time" "github.com/dchest/siphash" + "gitlab.com/yawning/obfs4.git/common/csrand" ) @@ -67,21 +68,21 @@ type ReplayFilter struct { } // New creates a new ReplayFilter instance. -func New(ttl time.Duration) (filter *ReplayFilter, err error) { +func New(ttl time.Duration) (*ReplayFilter, error) { // Initialize the SipHash-2-4 instance with a random key. var key [16]byte - if err = csrand.Bytes(key[:]); err != nil { - return + if err := csrand.Bytes(key[:]); err != nil { + return nil, err } - filter = new(ReplayFilter) + filter := new(ReplayFilter) filter.filter = make(map[uint64]*entry) filter.fifo = list.New() filter.key[0] = binary.BigEndian.Uint64(key[0:8]) filter.key[1] = binary.BigEndian.Uint64(key[8:16]) filter.ttl = ttl - return + return filter, nil } // TestAndSet queries the filter for a given byte sequence, inserts the diff --git a/common/socks5/args.go b/common/socks5/args.go index a5efb43..479d710 100644 --- a/common/socks5/args.go +++ b/common/socks5/args.go @@ -29,7 +29,8 @@ package socks5 import ( "fmt" - "git.torproject.org/pluggable-transports/goptlib.git" + + "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/goptlib" ) // parseClientParameters takes a client parameter string formatted according to @@ -37,14 +38,14 @@ import ( // specification, and returns it as a goptlib Args structure. // // This is functionally identical to the equivalently named goptlib routine. -func parseClientParameters(argStr string) (args pt.Args, err error) { - args = make(pt.Args) +func parseClientParameters(argStr string) (pt.Args, error) { + args := make(pt.Args) if len(argStr) == 0 { - return + return args, nil } var key string - var acc []byte + acc := make([]byte, 0, len(argStr)) prevIsEscape := false for idx, ch := range []byte(argStr) { switch ch { diff --git a/common/socks5/args_test.go b/common/socks5/args_test.go index d9d3f22..8683e8c 100644 --- a/common/socks5/args_test.go +++ b/common/socks5/args_test.go @@ -5,7 +5,7 @@ package socks5 import ( "testing" - "git.torproject.org/pluggable-transports/goptlib.git" + "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/goptlib" ) func stringSlicesEqual(a, b []string) bool { diff --git a/common/socks5/rfc1929.go b/common/socks5/rfc1929.go index 93a2c16..42162e1 100644 --- a/common/socks5/rfc1929.go +++ b/common/socks5/rfc1929.go @@ -35,12 +35,12 @@ const ( authRFC1929Fail = 0x01 ) -func (req *Request) authRFC1929() (err error) { - sendErrResp := func() { +func (req *Request) authRFC1929() error { + sendErrResp := func(err error) error { // Swallow write/flush errors, the auth failure is the relevant error. - resp := []byte{authRFC1929Ver, authRFC1929Fail} - _, _ = req.rw.Write(resp[:]) + _, _ = req.rw.Write([]byte{authRFC1929Ver, authRFC1929Fail}) _ = req.flushBuffers() + return err // Pass this through from the arg. } // The client sends a Username/Password request. @@ -50,39 +50,35 @@ func (req *Request) authRFC1929() (err error) { // uint8_t plen (>= 1) // uint8_t passwd[plen] - if err = req.readByteVerify("auth version", authRFC1929Ver); err != nil { - sendErrResp() - return + if err := req.readByteVerify("auth version", authRFC1929Ver); err != nil { + return sendErrResp(err) } // Read the username. - var ulen byte + var ( + ulen byte + err error + ) if ulen, err = req.readByte(); err != nil { - sendErrResp() - return + return sendErrResp(err) } else if ulen < 1 { - sendErrResp() - return fmt.Errorf("username with 0 length") + return sendErrResp(fmt.Errorf("username with 0 length")) } var uname []byte if uname, err = req.readBytes(int(ulen)); err != nil { - sendErrResp() - return + return sendErrResp(err) } // Read the password. var plen byte if plen, err = req.readByte(); err != nil { - sendErrResp() - return + return sendErrResp(err) } else if plen < 1 { - sendErrResp() - return fmt.Errorf("password with 0 length") + return sendErrResp(fmt.Errorf("password with 0 length")) } var passwd []byte if passwd, err = req.readBytes(int(plen)); err != nil { - sendErrResp() - return + return sendErrResp(err) } // Pluggable transports use the username/password field to pass @@ -95,11 +91,10 @@ func (req *Request) authRFC1929() (err error) { argStr += string(passwd) } if req.Args, err = parseClientParameters(argStr); err != nil { - sendErrResp() - return + return sendErrResp(err) } resp := []byte{authRFC1929Ver, authRFC1929Success} - _, err = req.rw.Write(resp[:]) - return + _, err = req.rw.Write(resp) + return err } diff --git a/common/socks5/socks5.go b/common/socks5/socks5.go index 7630d3d..28976a1 100644 --- a/common/socks5/socks5.go +++ b/common/socks5/socks5.go @@ -30,23 +30,24 @@ // 1929. // // Notes: -// * GSSAPI authentication, is NOT supported. -// * Only the CONNECT command is supported. -// * The authentication provided by the client is always accepted as it is -// used as a channel to pass information rather than for authentication for -// pluggable transports. +// - GSSAPI authentication, is NOT supported. +// - Only the CONNECT command is supported. +// - The authentication provided by the client is always accepted as it is +// used as a channel to pass information rather than for authentication for +// pluggable transports. package socks5 // import "gitlab.com/yawning/obfs4.git/common/socks5" import ( "bufio" "bytes" + "errors" "fmt" "io" "net" "syscall" "time" - "git.torproject.org/pluggable-transports/goptlib.git" + "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/goptlib" ) const ( @@ -89,16 +90,16 @@ func Version() string { // ErrorToReplyCode converts an error to the "best" reply code. func ErrorToReplyCode(err error) ReplyCode { - opErr, ok := err.(*net.OpError) - if !ok { + var opErr *net.OpError + if !errors.As(err, &opErr) { return ReplyGeneralFailure } - errno, ok := opErr.Err.(syscall.Errno) - if !ok { + var errno syscall.Errno + if !errors.As(opErr.Err, &errno) { return ReplyGeneralFailure } - switch errno { + switch errno { //nolint:exhaustive case syscall.EADDRNOTAVAIL: return ReplyAddressNotSupported case syscall.ETIMEDOUT: @@ -307,7 +308,7 @@ func (req *Request) readCommand() error { return err } addr := make(net.IP, net.IPv6len) - copy(addr[:], rawAddr[:]) + copy(addr[:], rawAddr) host = fmt.Sprintf("[%s]", addr.String()) default: _ = req.Reply(ReplyAddressNotSupported) diff --git a/common/socks5/socks_test.go b/common/socks5/socks_test.go index ace7fd3..433904e 100644 --- a/common/socks5/socks_test.go +++ b/common/socks5/socks_test.go @@ -48,11 +48,11 @@ type testReadWriter struct { writeBuf bytes.Buffer } -func (c *testReadWriter) Read(buf []byte) (n int, err error) { +func (c *testReadWriter) Read(buf []byte) (int, error) { return c.readBuf.Read(buf) } -func (c *testReadWriter) Write(buf []byte) (n int, err error) { +func (c *testReadWriter) Write(buf []byte) (int, error) { return c.writeBuf.Write(buf) } @@ -96,11 +96,11 @@ func TestAuthInvalidVersion(t *testing.T) { // VER = 03, NMETHODS = 01, METHODS = [00] c.writeHex("030100") if _, err := req.negotiateAuth(); err == nil { - t.Error("negotiateAuth(InvalidVersion) succeded") + t.Error("negotiateAuth(InvalidVersion) succeeded") } } -// TestAuthInvalidNMethods tests auth negotiaton with no methods. +// TestAuthInvalidNMethods tests auth negotiation with no methods. func TestAuthInvalidNMethods(t *testing.T) { c := new(testReadWriter) req := c.toRequest() @@ -120,7 +120,7 @@ func TestAuthInvalidNMethods(t *testing.T) { } } -// TestAuthNoneRequired tests auth negotiaton with NO AUTHENTICATION REQUIRED. +// TestAuthNoneRequired tests auth negotiation with NO AUTHENTICATION REQUIRED. func TestAuthNoneRequired(t *testing.T) { c := new(testReadWriter) req := c.toRequest() @@ -230,7 +230,7 @@ func TestRFC1929InvalidVersion(t *testing.T) { // VER = 03, ULEN = 5, UNAME = "ABCDE", PLEN = 5, PASSWD = "abcde" c.writeHex("03054142434445056162636465") if err := req.authenticate(authUsernamePassword); err == nil { - t.Error("authenticate(InvalidVersion) succeded") + t.Error("authenticate(InvalidVersion) succeeded") } if msg := c.readHex(); msg != "0101" { t.Error("authenticate(InvalidVersion) invalid response:", msg) @@ -245,7 +245,7 @@ func TestRFC1929InvalidUlen(t *testing.T) { // VER = 01, ULEN = 0, UNAME = "", PLEN = 5, PASSWD = "abcde" c.writeHex("0100056162636465") if err := req.authenticate(authUsernamePassword); err == nil { - t.Error("authenticate(InvalidUlen) succeded") + t.Error("authenticate(InvalidUlen) succeeded") } if msg := c.readHex(); msg != "0101" { t.Error("authenticate(InvalidUlen) invalid response:", msg) @@ -260,7 +260,7 @@ func TestRFC1929InvalidPlen(t *testing.T) { // VER = 01, ULEN = 5, UNAME = "ABCDE", PLEN = 0, PASSWD = "" c.writeHex("0105414243444500") if err := req.authenticate(authUsernamePassword); err == nil { - t.Error("authenticate(InvalidPlen) succeded") + t.Error("authenticate(InvalidPlen) succeeded") } if msg := c.readHex(); msg != "0101" { t.Error("authenticate(InvalidPlen) invalid response:", msg) @@ -275,7 +275,7 @@ func TestRFC1929InvalidPTArgs(t *testing.T) { // VER = 01, ULEN = 5, UNAME = "ABCDE", PLEN = 5, PASSWD = "abcde" c.writeHex("01054142434445056162636465") if err := req.authenticate(authUsernamePassword); err == nil { - t.Error("authenticate(InvalidArgs) succeded") + t.Error("authenticate(InvalidArgs) succeeded") } if msg := c.readHex(); msg != "0101" { t.Error("authenticate(InvalidArgs) invalid response:", msg) @@ -301,7 +301,7 @@ func TestRFC1929Success(t *testing.T) { } } -// TestRequestInvalidHdr tests SOCKS5 requests with invalid VER/CMD/RSV/ATYPE +// TestRequestInvalidHdr tests SOCKS5 requests with invalid VER/CMD/RSV/ATYPE. func TestRequestInvalidHdr(t *testing.T) { c := new(testReadWriter) req := c.toRequest() @@ -309,7 +309,7 @@ func TestRequestInvalidHdr(t *testing.T) { // VER = 03, CMD = 01, RSV = 00, ATYPE = 01, DST.ADDR = 127.0.0.1, DST.PORT = 9050 c.writeHex("030100017f000001235a") if err := req.readCommand(); err == nil { - t.Error("readCommand(InvalidVer) succeded") + t.Error("readCommand(InvalidVer) succeeded") } if msg := c.readHex(); msg != "05010001000000000000" { t.Error("readCommand(InvalidVer) invalid response:", msg) @@ -319,7 +319,7 @@ func TestRequestInvalidHdr(t *testing.T) { // VER = 05, CMD = 05, RSV = 00, ATYPE = 01, DST.ADDR = 127.0.0.1, DST.PORT = 9050 c.writeHex("050500017f000001235a") if err := req.readCommand(); err == nil { - t.Error("readCommand(InvalidCmd) succeded") + t.Error("readCommand(InvalidCmd) succeeded") } if msg := c.readHex(); msg != "05070001000000000000" { t.Error("readCommand(InvalidCmd) invalid response:", msg) @@ -329,7 +329,7 @@ func TestRequestInvalidHdr(t *testing.T) { // VER = 05, CMD = 01, RSV = 30, ATYPE = 01, DST.ADDR = 127.0.0.1, DST.PORT = 9050 c.writeHex("050130017f000001235a") if err := req.readCommand(); err == nil { - t.Error("readCommand(InvalidRsv) succeded") + t.Error("readCommand(InvalidRsv) succeeded") } if msg := c.readHex(); msg != "05010001000000000000" { t.Error("readCommand(InvalidRsv) invalid response:", msg) @@ -339,7 +339,7 @@ func TestRequestInvalidHdr(t *testing.T) { // VER = 05, CMD = 01, RSV = 01, ATYPE = 05, DST.ADDR = 127.0.0.1, DST.PORT = 9050 c.writeHex("050100057f000001235a") if err := req.readCommand(); err == nil { - t.Error("readCommand(InvalidAtype) succeded") + t.Error("readCommand(InvalidAtype) succeeded") } if msg := c.readHex(); msg != "05080001000000000000" { t.Error("readCommand(InvalidAtype) invalid response:", msg) diff --git a/common/uniformdh/uniformdh.go b/common/uniformdh/uniformdh.go index 1d500a3..47109b6 100644 --- a/common/uniformdh/uniformdh.go +++ b/common/uniformdh/uniformdh.go @@ -32,6 +32,7 @@ package uniformdh // import "gitlab.com/yawning/obfs4.git/common/uniformdh" import ( + "bytes" "fmt" "io" "math/big" @@ -54,8 +55,16 @@ const ( g = 2 ) -var modpGroup *big.Int -var gen *big.Int +var ( + modpGroup = func() *big.Int { + n, ok := new(big.Int).SetString(modpStr, 16) + if !ok { + panic("Failed to load the RFC3526 MODP Group") + } + return n + }() + gen = big.NewInt(g) +) // A PrivateKey represents a UniformDH private key. type PrivateKey struct { @@ -70,14 +79,11 @@ type PublicKey struct { } // Bytes returns the byte representation of a PublicKey. -func (pub *PublicKey) Bytes() (pubBytes []byte, err error) { +func (pub *PublicKey) Bytes() ([]byte, error) { if len(pub.bytes) != Size || pub.bytes == nil { return nil, fmt.Errorf("public key is not initialized") } - pubBytes = make([]byte, Size) - copy(pubBytes, pub.bytes) - - return + return bytes.Clone(pub.bytes), nil } // SetBytes sets the PublicKey from a byte slice. @@ -85,25 +91,22 @@ func (pub *PublicKey) SetBytes(pubBytes []byte) error { if len(pubBytes) != Size { return fmt.Errorf("public key length %d is not %d", len(pubBytes), Size) } - pub.bytes = make([]byte, Size) - copy(pub.bytes, pubBytes) + pub.bytes = bytes.Clone(pubBytes) pub.publicKey = new(big.Int).SetBytes(pub.bytes) return nil } // GenerateKey generates a UniformDH keypair using the random source random. -func GenerateKey(random io.Reader) (priv *PrivateKey, err error) { - privBytes := make([]byte, Size) - if _, err = io.ReadFull(random, privBytes); err != nil { - return +func GenerateKey(random io.Reader) (*PrivateKey, error) { + var privBytes [Size]byte + if _, err := io.ReadFull(random, privBytes[:]); err != nil { + return nil, err } - priv, err = generateKey(privBytes) - - return + return generateKey(privBytes[:]) } -func generateKey(privBytes []byte) (priv *PrivateKey, err error) { +func generateKey(privBytes []byte) (*PrivateKey, error) { // This function does all of the actual heavy lifting of creating a public // key from a raw 192 byte private key. It is split so that the KAT tests // can be written easily, and not exposed since non-ephemeral keys are a @@ -132,52 +135,26 @@ func generateKey(privBytes []byte) (priv *PrivateKey, err error) { // to the key so that it is always exactly Size bytes. pubBytes := make([]byte, Size) if wasEven { - err = prependZeroBytes(pubBytes, pubBn.Bytes()) + pubBn.FillBytes(pubBytes) } else { - err = prependZeroBytes(pubBytes, pubAlt.Bytes()) - } - if err != nil { - return + pubAlt.FillBytes(pubBytes) } - priv = new(PrivateKey) + priv := new(PrivateKey) priv.PublicKey.bytes = pubBytes priv.PublicKey.publicKey = pubBn priv.privateKey = privBn - return + return priv, nil } // Handshake generates a shared secret given a PrivateKey and PublicKey. -func Handshake(privateKey *PrivateKey, publicKey *PublicKey) (sharedSecret []byte, err error) { +func Handshake(privateKey *PrivateKey, publicKey *PublicKey) ([]byte, error) { // When a party wants to calculate the shared secret, she raises the // foreign public key to her private key. secretBn := new(big.Int).Exp(publicKey.publicKey, privateKey.privateKey, modpGroup) - sharedSecret = make([]byte, Size) - err = prependZeroBytes(sharedSecret, secretBn.Bytes()) - - return -} + sharedSecret := make([]byte, Size) + secretBn.FillBytes(sharedSecret) -func prependZeroBytes(dst, src []byte) error { - zeros := len(dst) - len(src) - if zeros < 0 { - return fmt.Errorf("src length is greater than destination: %d", zeros) - } - for i := 0; i < zeros; i++ { - dst[i] = 0 - } - copy(dst[zeros:], src) - - return nil -} - -func init() { - // Load the MODP group and the generator. - var ok bool - modpGroup, ok = new(big.Int).SetString(modpStr, 16) - if !ok { - panic("Failed to load the RFC3526 MODP Group") - } - gen = big.NewInt(g) + return sharedSecret, nil } diff --git a/common/uniformdh/uniformdh_test.go b/common/uniformdh/uniformdh_test.go index 326293b..ab09a8e 100644 --- a/common/uniformdh/uniformdh_test.go +++ b/common/uniformdh/uniformdh_test.go @@ -101,7 +101,14 @@ const ( "a81359543e77e4a4cfa7598a4152e4c0" ) -var xPriv, xPub, yPriv, yPub, ss []byte +var ( + // Load the test vectors into byte slices. + xPriv, _ = hex.DecodeString(xPrivStr) + xPub, _ = hex.DecodeString(xPubStr) + yPriv, _ = hex.DecodeString(yPrivStr) + yPub, _ = hex.DecodeString(yPubStr) + ss, _ = hex.DecodeString(ssStr) +) // TestGenerateKeyOdd tests creating a UniformDH keypair with a odd private // key. @@ -137,7 +144,7 @@ func TestGenerateKeyEven(t *testing.T) { } } -// TestHandshake tests conductiong a UniformDH handshake with know values. +// TestHandshake tests conducting a UniformDH handshake with know values. func TestHandshake(t *testing.T) { xX, err := generateKey(xPriv) if err != nil { @@ -193,28 +200,3 @@ func BenchmarkHandshake(b *testing.B) { _ = yX } } - -func init() { - // Load the test vectors into byte slices. - var err error - xPriv, err = hex.DecodeString(xPrivStr) - if err != nil { - panic("hex.DecodeString(xPrivStr) failed") - } - xPub, err = hex.DecodeString(xPubStr) - if err != nil { - panic("hex.DecodeString(xPubStr) failed") - } - yPriv, err = hex.DecodeString(yPrivStr) - if err != nil { - panic("hex.DecodeString(yPrivStr) failed") - } - yPub, err = hex.DecodeString(yPubStr) - if err != nil { - panic("hex.DecodeString(yPubStr) failed") - } - ss, err = hex.DecodeString(ssStr) - if err != nil { - panic("hex.DecodeString(ssStr) failed") - } -} diff --git a/go.mod b/go.mod index 96e5cee..3f24c77 100644 --- a/go.mod +++ b/go.mod @@ -2,11 +2,13 @@ module gitlab.com/yawning/obfs4.git require ( filippo.io/edwards25519 v1.0.0 - git.torproject.org/pluggable-transports/goptlib.git v1.3.0 github.com/dchest/siphash v1.2.3 - gitlab.com/yawning/edwards25519-extra.git v0.0.0-20211229043746-2f91fcc9fbdb - golang.org/x/crypto v0.9.0 - golang.org/x/net v0.10.0 + gitlab.com/yawning/edwards25519-extra.git v0.0.0-20220726154925-def713fd18e4 + gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/goptlib v1.4.0 + golang.org/x/crypto v0.11.0 + golang.org/x/net v0.12.0 ) +require golang.org/x/sys v0.10.0 // indirect + go 1.20 diff --git a/go.sum b/go.sum index c67246b..e6623c8 100644 --- a/go.sum +++ b/go.sum @@ -1,48 +1,22 @@ filippo.io/edwards25519 v1.0.0-rc.1.0.20210721174708-390f27c3be20/go.mod h1:N1IkdkCkiLB6tki+MYJoSx2JTY9NUlxZE7eHn5EwJns= filippo.io/edwards25519 v1.0.0 h1:0wAIcmJUqRdI8IJ/3eGi5/HwXZWPujYXXlkrQogz0Ek= filippo.io/edwards25519 v1.0.0/go.mod h1:N1IkdkCkiLB6tki+MYJoSx2JTY9NUlxZE7eHn5EwJns= -git.torproject.org/pluggable-transports/goptlib.git v1.3.0 h1:G+iuRUblCCC2xnO+0ag1/4+aaM98D5mjWP1M0v9s8a0= -git.torproject.org/pluggable-transports/goptlib.git v1.3.0/go.mod h1:4PBMl1dg7/3vMWSoWb46eGWlrxkUyn/CAJmxhDLAlDs= github.com/dchest/siphash v1.2.3 h1:QXwFc8cFOR2dSa/gE6o/HokBMWtLUaNDVd+22aKHeEA= github.com/dchest/siphash v1.2.3/go.mod h1:0NvQU092bT0ipiFN++/rXm69QG9tVxLAlQHIXMPAkHc= -github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= -gitlab.com/yawning/edwards25519-extra.git v0.0.0-20211229043746-2f91fcc9fbdb h1:qRSZHsODmAP5qDvb3YsO7Qnf3TRiVbGxNG/WYnlM4/o= -gitlab.com/yawning/edwards25519-extra.git v0.0.0-20211229043746-2f91fcc9fbdb/go.mod h1:gvdJuZuO/tPZyhEV8K3Hmoxv/DWud5L4qEQxfYjEUTo= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +gitlab.com/yawning/edwards25519-extra.git v0.0.0-20220726154925-def713fd18e4 h1:LeXiZggivkDGgmkl7+r+m/2xj3rd+K/30/0obRKayAU= +gitlab.com/yawning/edwards25519-extra.git v0.0.0-20220726154925-def713fd18e4/go.mod h1:gvdJuZuO/tPZyhEV8K3Hmoxv/DWud5L4qEQxfYjEUTo= +gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/goptlib v1.4.0 h1:Y7fHDMy11yyjM+YlHfcM3svaujdL+m5DqS444wbj8o4= +gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/goptlib v1.4.0/go.mod h1:70bhd4JKW/+1HLfm+TMrgHJsUHG4coelMWwiVEJ2gAg= golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g= -golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= -golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= -golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/crypto v0.11.0 h1:6Ewdq3tDic1mg5xRO4milcWCfMVQhI4NkqWWvqejpuA= +golang.org/x/crypto v0.11.0/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= -golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= -golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= -golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= -golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/net v0.12.0 h1:cfawfvKITfUsFCeJIHJrbSxpeu/E81khclypR0GVT50= +golang.org/x/net v0.12.0/go.mod h1:zEVYFnQC7m/vmpQFELhcD1EWkZlX69l4oqgmer6hfKA= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU= -golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA= +golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= -golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= -golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= -golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= -golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/internal/x25519ell2/x25519ell2.go b/internal/x25519ell2/x25519ell2.go index eb2b1dd..8d02141 100644 --- a/internal/x25519ell2/x25519ell2.go +++ b/internal/x25519ell2/x25519ell2.go @@ -22,7 +22,6 @@ import ( "filippo.io/edwards25519" "filippo.io/edwards25519/field" - "gitlab.com/yawning/edwards25519-extra.git/elligator2" ) @@ -52,7 +51,7 @@ var ( 0xbb, 0x4a, 0xde, 0x38, 0x32, 0x99, 0x33, 0xe9, 0x28, 0x4a, 0x39, 0x06, 0xa0, 0xb9, 0xd5, 0x1f, }) - // Low order point Edwards y-coordinate `-lop_x * sqrtm1` + // Low order point Edwards y-coordinate `-lop_x * sqrtm1`. feLopY = mustFeFromBytes([]byte{ 0x26, 0xe8, 0x95, 0x8f, 0xc2, 0xb2, 0x27, 0xb0, 0x45, 0xc3, 0xf4, 0x89, 0xf2, 0xef, 0x98, 0xf0, 0xd5, 0xdf, 0xac, 0x05, 0xd3, 0xc6, 0x33, 0x39, 0xb1, 0x38, 0x02, 0x88, 0x6d, 0x53, 0xfc, 0x05, diff --git a/obfs4proxy/obfs4proxy.go b/obfs4proxy/obfs4proxy.go index 32359d0..a695514 100644 --- a/obfs4proxy/obfs4proxy.go +++ b/obfs4proxy/obfs4proxy.go @@ -41,12 +41,13 @@ import ( "sync" "syscall" - "git.torproject.org/pluggable-transports/goptlib.git" + "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/goptlib" + "golang.org/x/net/proxy" + "gitlab.com/yawning/obfs4.git/common/log" "gitlab.com/yawning/obfs4.git/common/socks5" "gitlab.com/yawning/obfs4.git/transports" "gitlab.com/yawning/obfs4.git/transports/base" - "golang.org/x/net/proxy" ) const ( @@ -55,23 +56,27 @@ const ( socksAddr = "127.0.0.1:0" ) -var stateDir string -var termMon *termMonitor +var ( + stateDir string + termMon *termMonitor +) -func clientSetup() (launched bool, listeners []net.Listener) { +func clientSetup() (bool, []net.Listener) { ptClientInfo, err := pt.ClientSetup(transports.Transports()) if err != nil { golog.Fatal(err) } - ptClientProxy, err := ptGetProxy() + ptClientProxy, err := ptGetProxy(&ptClientInfo) if err != nil { golog.Fatal(err) } else if ptClientProxy != nil { - ptProxyDone() + pt.ProxyDone() } // Launch each of the client listeners. + var launched bool + listeners := make([]net.Listener, 0, len(ptClientInfo.MethodNames)) for _, name := range ptClientInfo.MethodNames { t := transports.Get(name) if t == nil { @@ -103,7 +108,7 @@ func clientSetup() (launched bool, listeners []net.Listener) { } pt.CmethodsDone() - return + return launched, listeners } func clientAcceptLoop(f base.ClientFactory, ln net.Listener, proxyURI *url.URL) error { @@ -111,10 +116,7 @@ func clientAcceptLoop(f base.ClientFactory, ln net.Listener, proxyURI *url.URL) for { conn, err := ln.Accept() if err != nil { - if e, ok := err.(net.Error); ok && !e.Temporary() { - return err - } - continue + return err } go clientHandler(f, conn, proxyURI) } @@ -176,12 +178,14 @@ func clientHandler(f base.ClientFactory, conn net.Conn, proxyURI *url.URL) { } } -func serverSetup() (launched bool, listeners []net.Listener) { +func serverSetup() (bool, []net.Listener) { ptServerInfo, err := pt.ServerSetup(transports.Transports()) if err != nil { golog.Fatal(err) } + var launched bool + listeners := make([]net.Listener, 0, len(ptServerInfo.Bindaddrs)) for _, bindaddr := range ptServerInfo.Bindaddrs { name := bindaddr.MethodName t := transports.Get(name) @@ -218,7 +222,7 @@ func serverSetup() (launched bool, listeners []net.Listener) { } pt.SmethodsDone() - return + return launched, listeners } func serverAcceptLoop(f base.ServerFactory, ln net.Listener, info *pt.ServerInfo) error { @@ -226,10 +230,7 @@ func serverAcceptLoop(f base.ServerFactory, ln net.Listener, info *pt.ServerInfo for { conn, err := ln.Accept() if err != nil { - if e, ok := err.(net.Error); ok && !e.Temporary() { - return err - } - continue + return err } go serverHandler(f, conn, info) } @@ -317,7 +318,7 @@ func main() { flag.Parse() if *showVer { - fmt.Printf("%s\n", getVersion()) + fmt.Printf("%s\n", getVersion()) //nolint:forbidigo os.Exit(0) } if err := log.SetLogLevel(*logLevelStr); err != nil { diff --git a/obfs4proxy/proxy_http.go b/obfs4proxy/proxy_http.go index 1adadf8..e263935 100644 --- a/obfs4proxy/proxy_http.go +++ b/obfs4proxy/proxy_http.go @@ -30,6 +30,7 @@ package main import ( "bufio" "encoding/base64" + "errors" "fmt" "net" "net/http" @@ -69,14 +70,14 @@ func (s *httpProxy) Dial(network, addr string) (net.Conn, error) { return nil, err } conn := new(httpConn) - conn.httpConn = httputil.NewClientConn(c, nil) // nolint: staticcheck + conn.httpConn = httputil.NewClientConn(c, nil) //nolint:staticcheck conn.remoteAddr, err = net.ResolveTCPAddr(network, addr) if err != nil { conn.httpConn.Close() return nil, err } - // HACK HACK HACK HACK. http.ReadRequest also does this. + // HACK: http.ReadRequest also does this. reqURL, err := url.Parse("http://" + addr) if err != nil { conn.httpConn.Close() @@ -84,7 +85,7 @@ func (s *httpProxy) Dial(network, addr string) (net.Conn, error) { } reqURL.Scheme = "" - req, err := http.NewRequest("CONNECT", reqURL.String(), nil) + req, err := http.NewRequest(http.MethodConnect, reqURL.String(), nil) if err != nil { conn.httpConn.Close() return nil, err @@ -93,16 +94,16 @@ func (s *httpProxy) Dial(network, addr string) (net.Conn, error) { if s.haveAuth { // SetBasicAuth doesn't quite do what is appropriate, because // the correct header is `Proxy-Authorization`. - req.Header.Set("Proxy-Authorization", "Basic " + base64.StdEncoding.EncodeToString([]byte(s.username+":"+s.password))) + req.Header.Set("Proxy-Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(s.username+":"+s.password))) } req.Header.Set("User-Agent", "") resp, err := conn.httpConn.Do(req) - if err != nil && err != httputil.ErrPersistEOF { // nolint: staticcheck + if err != nil && !errors.Is(err, httputil.ErrPersistEOF) { //nolint:staticcheck conn.httpConn.Close() return nil, err } - if resp.StatusCode != 200 { + if resp.StatusCode != http.StatusOK { conn.httpConn.Close() return nil, fmt.Errorf("proxy error: %s", resp.Status) } @@ -113,7 +114,7 @@ func (s *httpProxy) Dial(network, addr string) (net.Conn, error) { type httpConn struct { remoteAddr *net.TCPAddr - httpConn *httputil.ClientConn // nolint: staticcheck + httpConn *httputil.ClientConn //nolint:staticcheck hijackedConn net.Conn staleReader *bufio.Reader } @@ -156,6 +157,6 @@ func (c *httpConn) SetWriteDeadline(t time.Time) error { return c.hijackedConn.SetWriteDeadline(t) } -func init() { +func init() { //nolint:gochecknoinits proxy.RegisterDialerType("http", newHTTP) } diff --git a/obfs4proxy/proxy_socks4.go b/obfs4proxy/proxy_socks4.go index ac8be9e..2e623f6 100644 --- a/obfs4proxy/proxy_socks4.go +++ b/obfs4proxy/proxy_socks4.go @@ -150,7 +150,7 @@ func socks4ErrorToString(code byte) string { case socks4Rejected: return "request rejected or failed" case socks4RejectedIdentdFailed: - return "request rejected becasue SOCKS server cannot connect to identd on the client" + return "request rejected because SOCKS server cannot connect to identd on the client" case socks4RejectedIdentdMismatch: return "request rejected because the client program and identd report different user-ids" default: @@ -158,7 +158,7 @@ func socks4ErrorToString(code byte) string { } } -func init() { +func init() { //nolint:gochecknoinits // Despite the scheme name, this really is SOCKS4. proxy.RegisterDialerType("socks4a", newSOCKS4) } diff --git a/obfs4proxy/pt_extras.go b/obfs4proxy/pt_extras.go index 18bc2df..5f10f14 100644 --- a/obfs4proxy/pt_extras.go +++ b/obfs4proxy/pt_extras.go @@ -35,11 +35,11 @@ import ( "os" "strconv" - "git.torproject.org/pluggable-transports/goptlib.git" + "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/goptlib" ) // This file contains things that probably should be in goptlib but are not -// yet or are not finalized. +// yet or not exposed. func ptEnvError(msg string) error { line := []byte(fmt.Sprintf("ENV-ERROR %s\n", msg)) @@ -47,89 +47,61 @@ func ptEnvError(msg string) error { return errors.New(msg) } -func ptProxyError(msg string) error { - line := []byte(fmt.Sprintf("PROXY-ERROR %s\n", msg)) - _, _ = pt.Stdout.Write(line) - return errors.New(msg) -} - -func ptProxyDone() { - line := []byte("PROXY DONE\n") - _, _ = pt.Stdout.Write(line) -} - func ptIsClient() (bool, error) { clientEnv := os.Getenv("TOR_PT_CLIENT_TRANSPORTS") serverEnv := os.Getenv("TOR_PT_SERVER_TRANSPORTS") - if clientEnv != "" && serverEnv != "" { + switch { + case clientEnv != "" && serverEnv != "": return false, ptEnvError("TOR_PT_[CLIENT,SERVER]_TRANSPORTS both set") - } else if clientEnv != "" { + case clientEnv != "": return true, nil - } else if serverEnv != "" { + case serverEnv != "": return false, nil } return false, errors.New("not launched as a managed transport") } -func ptGetProxy() (*url.URL, error) { - specString := os.Getenv("TOR_PT_PROXY") - if specString == "" { - return nil, nil - } - spec, err := url.Parse(specString) - if err != nil { - return nil, ptProxyError(fmt.Sprintf("failed to parse proxy config: %s", err)) +func ptGetProxy(info *pt.ClientInfo) (*url.URL, error) { + proxyURL := info.ProxyURL + if proxyURL == nil { + return nil, nil //nolint:nilnil } - // Validate the TOR_PT_PROXY uri. - if !spec.IsAbs() { - return nil, ptProxyError("proxy URI is relative, must be absolute") - } - if spec.Path != "" { - return nil, ptProxyError("proxy URI has a path defined") - } - if spec.RawQuery != "" { - return nil, ptProxyError("proxy URI has a query defined") - } - if spec.Fragment != "" { - return nil, ptProxyError("proxy URI has a fragment defined") - } - - switch spec.Scheme { + // Validate the arguments. + switch proxyURL.Scheme { case "http": // The most forgiving of proxies. case "socks4a": - if spec.User != nil { - _, isSet := spec.User.Password() + if proxyURL.User != nil { + _, isSet := proxyURL.User.Password() if isSet { - return nil, ptProxyError("proxy URI specified SOCKS4a and a password") + return nil, pt.ProxyError("proxy URI proxyURLified SOCKS4a and a password") } } case "socks5": - if spec.User != nil { + if proxyURL.User != nil { // UNAME/PASSWD both must be between 1 and 255 bytes long. (RFC1929) - user := spec.User.Username() - passwd, isSet := spec.User.Password() + user := proxyURL.User.Username() + passwd, isSet := proxyURL.User.Password() if len(user) < 1 || len(user) > 255 { - return nil, ptProxyError("proxy URI specified a invalid SOCKS5 username") + return nil, pt.ProxyError("proxy URI proxyURLified a invalid SOCKS5 username") } if !isSet || len(passwd) < 1 || len(passwd) > 255 { - return nil, ptProxyError("proxy URI specified a invalid SOCKS5 password") + return nil, pt.ProxyError("proxy URI proxyURLified a invalid SOCKS5 password") } } default: - return nil, ptProxyError(fmt.Sprintf("proxy URI has invalid scheme: %s", spec.Scheme)) + return nil, pt.ProxyError(fmt.Sprintf("proxy URI has invalid scheme: %s", proxyURL.Scheme)) } - _, err = resolveAddrStr(spec.Host) - if err != nil { - return nil, ptProxyError(fmt.Sprintf("proxy URI has invalid host: %s", err)) + if _, err := resolveAddrStr(proxyURL.Host); err != nil { + return nil, pt.ProxyError(fmt.Sprintf("proxy URI has invalid host: %s", err)) } - return spec, nil + return proxyURL, nil } // Sigh, pt.resolveAddr() isn't exported. Include our own getto version that diff --git a/obfs4proxy/termmon.go b/obfs4proxy/termmon.go index 59304c9..aa89b82 100644 --- a/obfs4proxy/termmon.go +++ b/obfs4proxy/termmon.go @@ -29,7 +29,6 @@ package main import ( "io" - "io/ioutil" "os" "os/signal" "runtime" @@ -73,7 +72,7 @@ func (m *termMonitor) wait(termOnNoHandlers bool) os.Signal { } func (m *termMonitor) termOnStdinClose() { - _, err := io.Copy(ioutil.Discard, os.Stdin) + _, err := io.Copy(io.Discard, os.Stdin) // io.Copy() will return a nil on EOF, since reaching EOF is // expected behavior. No matter what, if this unblocks, assume @@ -103,9 +102,9 @@ func (m *termMonitor) termOnPPIDChange(ppid int) { m.sigChan <- syscall.SIGTERM } -func newTermMonitor() (m *termMonitor) { +func newTermMonitor() *termMonitor { ppid := os.Getppid() - m = new(termMonitor) + m := new(termMonitor) m.sigChan = make(chan os.Signal) m.handlerChan = make(chan int) signal.Notify(m.sigChan, syscall.SIGINT, syscall.SIGTERM) @@ -113,7 +112,7 @@ func newTermMonitor() (m *termMonitor) { // If tor supports feature #15435, we can use Stdin being closed as an // indication that tor has died, or wants the PT to shutdown for any // reason. - if ptShouldExitOnStdinClose() { + if ptShouldExitOnStdinClose() { //nolint:nestif go m.termOnStdinClose() } else { // Instead of feature #15435, use various kludges and hacks: @@ -124,12 +123,12 @@ func newTermMonitor() (m *termMonitor) { // Errors here are non-fatal, since it might still be // possible to fall back to a generic implementation. if err := termMonitorOSInit(m); err == nil { - return + return m } } if runtime.GOOS != "windows" { go m.termOnPPIDChange(ppid) } } - return + return m } diff --git a/obfs4proxy/termmon_linux.go b/obfs4proxy/termmon_linux.go index 926e630..13f8c54 100644 --- a/obfs4proxy/termmon_linux.go +++ b/obfs4proxy/termmon_linux.go @@ -32,18 +32,18 @@ import ( "syscall" ) -func termMonitorInitLinux(m *termMonitor) error { +func termMonitorInitLinux(_ *termMonitor) error { // Use prctl() to have the kernel deliver a SIGTERM if the parent // process dies. This beats anything else that can be done before // #15435 is implemented. _, _, errno := syscall.Syscall(syscall.SYS_PRCTL, syscall.PR_SET_PDEATHSIG, uintptr(syscall.SIGTERM), 0) if errno != 0 { var err error = errno - return fmt.Errorf("prctl(PR_SET_PDEATHSIG, SIGTERM) returned: %s", err) + return fmt.Errorf("prctl(PR_SET_PDEATHSIG, SIGTERM) returned: %w", err) } return nil } -func init() { +func init() { //nolint:gochecknoinits termMonitorOSInit = termMonitorInitLinux } diff --git a/transports/base/base.go b/transports/base/base.go index bc6e025..ab7bdbc 100644 --- a/transports/base/base.go +++ b/transports/base/base.go @@ -32,7 +32,7 @@ package base // import "gitlab.com/yawning/obfs4.git/transports/base" import ( "net" - "git.torproject.org/pluggable-transports/goptlib.git" + "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/goptlib" ) type DialFunc func(string, string) (net.Conn, error) @@ -48,12 +48,12 @@ type ClientFactory interface { // for use with WrapConn. This routine is called before the outgoing // TCP/IP connection is created to allow doing things (like keypair // generation) to be hidden from third parties. - ParseArgs(args *pt.Args) (interface{}, error) + ParseArgs(args *pt.Args) (any, error) // Dial creates an outbound net.Conn, and does whatever is required // (eg: handshaking) to get the connection to the point where it is // ready to relay data. - Dial(network, address string, dialFn DialFunc, args interface{}) (net.Conn, error) + Dial(network, address string, dialFn DialFunc, args any) (net.Conn, error) } // ServerFactory is the interface that defines the factory for creating diff --git a/transports/meeklite/base.go b/transports/meeklite/base.go index e9d8ca1..b8f3732 100644 --- a/transports/meeklite/base.go +++ b/transports/meeklite/base.go @@ -36,7 +36,8 @@ import ( "fmt" "net" - "git.torproject.org/pluggable-transports/goptlib.git" + "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/goptlib" + "gitlab.com/yawning/obfs4.git/transports/base" ) @@ -51,15 +52,13 @@ func (t *Transport) Name() string { } // ClientFactory returns a new meekClientFactory instance. -func (t *Transport) ClientFactory(stateDir string) (base.ClientFactory, error) { +func (t *Transport) ClientFactory(_ string) (base.ClientFactory, error) { cf := &meekClientFactory{transport: t} return cf, nil } // ServerFactory will one day return a new meekServerFactory instance. -func (t *Transport) ServerFactory(stateDir string, args *pt.Args) (base.ServerFactory, error) { - // TODO: Fill this in eventually, though for servers people should - // just use the real thing. +func (t *Transport) ServerFactory(_ string, _ *pt.Args) (base.ServerFactory, error) { return nil, fmt.Errorf("server not supported") } @@ -71,18 +70,18 @@ func (cf *meekClientFactory) Transport() base.Transport { return cf.transport } -func (cf *meekClientFactory) ParseArgs(args *pt.Args) (interface{}, error) { +func (cf *meekClientFactory) ParseArgs(args *pt.Args) (any, error) { return newClientArgs(args) } -func (cf *meekClientFactory) Dial(network, addr string, dialFn base.DialFunc, args interface{}) (net.Conn, error) { +func (cf *meekClientFactory) Dial(_, _ string, dialFn base.DialFunc, args any) (net.Conn, error) { // Validate args before opening outgoing connection. ca, ok := args.(*meekClientArgs) if !ok { return nil, fmt.Errorf("invalid argument type for args") } - return newMeekConn(network, addr, dialFn, ca) + return newMeekConn(dialFn, ca) } var ( diff --git a/transports/meeklite/meek.go b/transports/meeklite/meek.go index 17c7a67..3407f9b 100644 --- a/transports/meeklite/meek.go +++ b/transports/meeklite/meek.go @@ -35,7 +35,6 @@ import ( "errors" "fmt" "io" - "io/ioutil" "net" "net/http" gourl "net/url" @@ -44,7 +43,8 @@ import ( "sync" "time" - "git.torproject.org/pluggable-transports/goptlib.git" + "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/goptlib" + "gitlab.com/yawning/obfs4.git/transports/base" ) @@ -83,8 +83,11 @@ func (ca *meekClientArgs) String() string { return transportName + ":" + ca.front + ":" + ca.url.String() } -func newClientArgs(args *pt.Args) (ca *meekClientArgs, err error) { - ca = &meekClientArgs{} +func newClientArgs(args *pt.Args) (*meekClientArgs, error) { + var ( + ca meekClientArgs + err error + ) // Parse the URL argument. str, ok := args.Get(urlArg) @@ -104,7 +107,7 @@ func newClientArgs(args *pt.Args) (ca *meekClientArgs, err error) { // Parse the (optional) front argument. ca.front, _ = args.Get(frontArg) - return ca, nil + return &ca, nil } type meekConn struct { @@ -119,18 +122,18 @@ type meekConn struct { rdBuf *bytes.Buffer } -func (c *meekConn) Read(p []byte) (n int, err error) { +func (c *meekConn) Read(p []byte) (int, error) { // If there is data left over from the previous read, // service the request using the buffered data. if c.rdBuf != nil { if c.rdBuf.Len() == 0 { panic("empty read buffer") } - n, err = c.rdBuf.Read(p) + n, err := c.rdBuf.Read(p) if c.rdBuf.Len() == 0 { c.rdBuf = nil } - return + return n, err } // Wait for the worker to enqueue more incoming data. @@ -142,16 +145,16 @@ func (c *meekConn) Read(p []byte) (n int, err error) { // Ew, an extra copy, but who am I kidding, it's meek. buf := bytes.NewBuffer(b) - n, err = buf.Read(p) + n, err := buf.Read(p) if buf.Len() > 0 { // If there's data pending, stash the buffer so the next // Read() call will use it to fulfuill the Read(). c.rdBuf = buf } - return + return n, err } -func (c *meekConn) Write(b []byte) (n int, err error) { +func (c *meekConn) Write(b []byte) (int, error) { // Check to see if the connection is actually open. select { case <-c.workerCloseChan: @@ -196,19 +199,19 @@ func (c *meekConn) RemoteAddr() net.Addr { return c.args } -func (c *meekConn) SetDeadline(t time.Time) error { +func (c *meekConn) SetDeadline(_ time.Time) error { return ErrNotSupported } -func (c *meekConn) SetReadDeadline(t time.Time) error { +func (c *meekConn) SetReadDeadline(_ time.Time) error { return ErrNotSupported } -func (c *meekConn) SetWriteDeadline(t time.Time) error { +func (c *meekConn) SetWriteDeadline(_ time.Time) error { return ErrNotSupported } -func (c *meekConn) enqueueWrite(b []byte) (ok bool) { +func (c *meekConn) enqueueWrite(b []byte) (ok bool) { //nolint:nonamedreturns defer func() { if err := recover(); err != nil { ok = false @@ -218,21 +221,26 @@ func (c *meekConn) enqueueWrite(b []byte) (ok bool) { return true } -func (c *meekConn) roundTrip(sndBuf []byte) (recvBuf []byte, err error) { - var req *http.Request - var resp *http.Response +func (c *meekConn) roundTrip(sndBuf []byte) ([]byte, error) { + var ( + req *http.Request + resp *http.Response + err error + ) + + url := *c.args.url + host := url.Host + if c.args.front != "" { + url.Host = c.args.front + } + urlStr := url.String() for retries := 0; retries < maxRetries; retries++ { - url := *c.args.url - host := url.Host - if c.args.front != "" { - url.Host = c.args.front - } var body io.Reader if len(sndBuf) > 0 { body = bytes.NewReader(sndBuf) } - req, err = http.NewRequest("POST", url.String(), body) + req, err = http.NewRequest(http.MethodPost, urlStr, body) if err != nil { return nil, err } @@ -248,16 +256,17 @@ func (c *meekConn) roundTrip(sndBuf []byte) (recvBuf []byte, err error) { } if resp.StatusCode == http.StatusOK { - recvBuf, err = ioutil.ReadAll(io.LimitReader(resp.Body, maxPayloadLength)) + var recvBuf []byte + recvBuf, err = io.ReadAll(io.LimitReader(resp.Body, maxPayloadLength)) resp.Body.Close() - return + return recvBuf, err } resp.Body.Close() err = fmt.Errorf("status code was %d, not %d", resp.StatusCode, http.StatusOK) time.Sleep(retryDelay) } - return + return nil, err } func (c *meekConn) ioWorker() { @@ -305,19 +314,20 @@ loop: } // Determine the next poll interval. - if len(rdBuf) > 0 { + switch { + case len(rdBuf) > 0: // Received data, enqueue the read. c.workerRdChan <- rdBuf // And poll immediately. interval = 0 - } else if wrSz > 0 { + case wrSz > 0: // Sent data, poll immediately. interval = 0 - } else if interval == 0 { + case interval == 0: // Neither sent nor received data after a poll, re-initialize the delay. interval = initPollInterval - } else { + default: // Apply a multiplicative backoff. interval = time.Duration(float64(interval) * pollIntervalMultiplier) if interval > maxPollInterval { @@ -337,7 +347,7 @@ loop: _ = c.Close() } -func newMeekConn(network, addr string, dialFn base.DialFunc, ca *meekClientArgs) (net.Conn, error) { +func newMeekConn(dialFn base.DialFunc, ca *meekClientArgs) (net.Conn, error) { id, err := newSessionID() if err != nil { return nil, err diff --git a/transports/obfs2/obfs2.go b/transports/obfs2/obfs2.go index 531bcd4..ff0e4f6 100644 --- a/transports/obfs2/obfs2.go +++ b/transports/obfs2/obfs2.go @@ -40,7 +40,8 @@ import ( "net" "time" - "git.torproject.org/pluggable-transports/goptlib.git" + "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/goptlib" + "gitlab.com/yawning/obfs4.git/common/csrand" "gitlab.com/yawning/obfs4.git/transports/base" ) @@ -81,13 +82,13 @@ func (t *Transport) Name() string { } // ClientFactory returns a new obfs2ClientFactory instance. -func (t *Transport) ClientFactory(stateDir string) (base.ClientFactory, error) { +func (t *Transport) ClientFactory(_ string) (base.ClientFactory, error) { cf := &obfs2ClientFactory{transport: t} return cf, nil } // ServerFactory returns a new obfs2ServerFactory instance. -func (t *Transport) ServerFactory(stateDir string, args *pt.Args) (base.ServerFactory, error) { +func (t *Transport) ServerFactory(_ string, args *pt.Args) (base.ServerFactory, error) { if err := validateArgs(args); err != nil { return nil, err } @@ -104,11 +105,11 @@ func (cf *obfs2ClientFactory) Transport() base.Transport { return cf.transport } -func (cf *obfs2ClientFactory) ParseArgs(args *pt.Args) (interface{}, error) { +func (cf *obfs2ClientFactory) ParseArgs(args *pt.Args) (any, error) { return nil, validateArgs(args) } -func (cf *obfs2ClientFactory) Dial(network, addr string, dialFn base.DialFunc, args interface{}) (net.Conn, error) { +func (cf *obfs2ClientFactory) Dial(network, addr string, dialFn base.DialFunc, _ any) (net.Conn, error) { conn, err := dialFn(network, addr) if err != nil { return nil, err @@ -154,46 +155,46 @@ func (conn *obfs2Conn) Write(b []byte) (int, error) { return conn.tx.Write(b) } -func newObfs2ClientConn(conn net.Conn) (c *obfs2Conn, err error) { +func newObfs2ClientConn(conn net.Conn) (*obfs2Conn, error) { // Initialize a client connection, and start the handshake timeout. - c = &obfs2Conn{conn, true, nil, nil} + c := &obfs2Conn{conn, true, nil, nil} deadline := time.Now().Add(clientHandshakeTimeout) - if err = c.SetDeadline(deadline); err != nil { + if err := c.SetDeadline(deadline); err != nil { return nil, err } // Handshake. - if err = c.handshake(); err != nil { + if err := c.handshake(); err != nil { return nil, err } // Disarm the handshake timer. - if err = c.SetDeadline(time.Time{}); err != nil { + if err := c.SetDeadline(time.Time{}); err != nil { return nil, err } - return + return c, nil } -func newObfs2ServerConn(conn net.Conn) (c *obfs2Conn, err error) { +func newObfs2ServerConn(conn net.Conn) (*obfs2Conn, error) { // Initialize a server connection, and start the handshake timeout. - c = &obfs2Conn{conn, false, nil, nil} + c := &obfs2Conn{conn, false, nil, nil} deadline := time.Now().Add(serverHandshakeTimeout) - if err = c.SetDeadline(deadline); err != nil { + if err := c.SetDeadline(deadline); err != nil { return nil, err } // Handshake. - if err = c.handshake(); err != nil { + if err := c.handshake(); err != nil { return nil, err } // Disarm the handshake timer. - if err = c.SetDeadline(time.Time{}); err != nil { + if err := c.SetDeadline(time.Time{}); err != nil { return nil, err } - return + return c, nil } func (conn *obfs2Conn) handshake() error { @@ -220,7 +221,7 @@ func (conn *obfs2Conn) handshake() error { } else { padMagic = []byte(responderPadString) } - padKey, padIV := hsKdf(padMagic, seed[:], conn.isInitiator) + padKey, padIV := hsKdf(padMagic, seed[:]) padLen := uint32(csrand.IntRange(0, maxPadding)) hsBlob := make([]byte, hsLen+padLen) @@ -265,7 +266,7 @@ func (conn *obfs2Conn) handshake() error { } else { peerPadMagic = []byte(initiatorPadString) } - peerKey, peerIV := hsKdf(peerPadMagic, peerSeed[:], !conn.isInitiator) + peerKey, peerIV := hsKdf(peerPadMagic, peerSeed[:]) rxBlock, err := aes.NewCipher(peerKey) if err != nil { return err @@ -273,7 +274,7 @@ func (conn *obfs2Conn) handshake() error { rxStream := cipher.NewCTR(rxBlock, peerIV) conn.rx = &cipher.StreamReader{S: rxStream, R: conn.Conn} hsHdr := make([]byte, hsLen) - if _, err := io.ReadFull(conn, hsHdr[:]); err != nil { + if _, err := io.ReadFull(conn, hsHdr); err != nil { return err } @@ -296,11 +297,7 @@ func (conn *obfs2Conn) handshake() error { } // Derive the actual keys. - if err := conn.kdf(seed[:], peerSeed[:]); err != nil { - return err - } - - return nil + return conn.kdf(seed[:], peerSeed[:]) } func (conn *obfs2Conn) kdf(seed, peerSeed []byte) error { @@ -321,14 +318,14 @@ func (conn *obfs2Conn) kdf(seed, peerSeed []byte) error { combSeed = append(combSeed, seed...) } - initKey, initIV := hsKdf([]byte(initiatorKdfString), combSeed, true) + initKey, initIV := hsKdf([]byte(initiatorKdfString), combSeed) initBlock, err := aes.NewCipher(initKey) if err != nil { return err } initStream := cipher.NewCTR(initBlock, initIV) - respKey, respIV := hsKdf([]byte(responderKdfString), combSeed, false) + respKey, respIV := hsKdf([]byte(responderKdfString), combSeed) respBlock, err := aes.NewCipher(respKey) if err != nil { return err @@ -346,16 +343,16 @@ func (conn *obfs2Conn) kdf(seed, peerSeed []byte) error { return nil } -func hsKdf(magic, seed []byte, isInitiator bool) (padKey, padIV []byte) { +func hsKdf(magic, seed []byte) ([]byte, []byte) { // The actual key/IV is derived in the form of: // m = MAC(magic, seed) // KEY = m[:KEYLEN] // IV = m[KEYLEN:] m := mac(magic, seed) - padKey = m[:keyLen] - padIV = m[keyLen:] + padKey := m[:keyLen] + padIV := m[keyLen:] - return + return padKey, padIV } func mac(s, x []byte) []byte { @@ -368,7 +365,9 @@ func mac(s, x []byte) []byte { return h.Sum(nil) } -var _ base.ClientFactory = (*obfs2ClientFactory)(nil) -var _ base.ServerFactory = (*obfs2ServerFactory)(nil) -var _ base.Transport = (*Transport)(nil) -var _ net.Conn = (*obfs2Conn)(nil) +var ( + _ base.ClientFactory = (*obfs2ClientFactory)(nil) + _ base.ServerFactory = (*obfs2ServerFactory)(nil) + _ base.Transport = (*Transport)(nil) + _ net.Conn = (*obfs2Conn)(nil) +) diff --git a/transports/obfs3/obfs3.go b/transports/obfs3/obfs3.go index 42bdd90..fb5d99a 100644 --- a/transports/obfs3/obfs3.go +++ b/transports/obfs3/obfs3.go @@ -40,7 +40,8 @@ import ( "net" "time" - "git.torproject.org/pluggable-transports/goptlib.git" + "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/goptlib" + "gitlab.com/yawning/obfs4.git/common/csrand" "gitlab.com/yawning/obfs4.git/common/uniformdh" "gitlab.com/yawning/obfs4.git/transports/base" @@ -69,13 +70,13 @@ func (t *Transport) Name() string { } // ClientFactory returns a new obfs3ClientFactory instance. -func (t *Transport) ClientFactory(stateDir string) (base.ClientFactory, error) { +func (t *Transport) ClientFactory(_ string) (base.ClientFactory, error) { cf := &obfs3ClientFactory{transport: t} return cf, nil } // ServerFactory returns a new obfs3ServerFactory instance. -func (t *Transport) ServerFactory(stateDir string, args *pt.Args) (base.ServerFactory, error) { +func (t *Transport) ServerFactory(_ string, _ *pt.Args) (base.ServerFactory, error) { sf := &obfs3ServerFactory{transport: t} return sf, nil } @@ -88,11 +89,11 @@ func (cf *obfs3ClientFactory) Transport() base.Transport { return cf.transport } -func (cf *obfs3ClientFactory) ParseArgs(args *pt.Args) (interface{}, error) { - return nil, nil +func (cf *obfs3ClientFactory) ParseArgs(_ *pt.Args) (any, error) { + return nil, nil //nolint:nilnil } -func (cf *obfs3ClientFactory) Dial(network, addr string, dialFn base.DialFunc, args interface{}) (net.Conn, error) { +func (cf *obfs3ClientFactory) Dial(network, addr string, dialFn base.DialFunc, _ any) (net.Conn, error) { conn, err := dialFn(network, addr) if err != nil { return nil, err @@ -133,46 +134,46 @@ type obfs3Conn struct { tx *cipher.StreamWriter } -func newObfs3ClientConn(conn net.Conn) (c *obfs3Conn, err error) { +func newObfs3ClientConn(conn net.Conn) (*obfs3Conn, error) { // Initialize a client connection, and start the handshake timeout. - c = &obfs3Conn{conn, true, nil, nil, new(bytes.Buffer), nil, nil} + c := &obfs3Conn{conn, true, nil, nil, new(bytes.Buffer), nil, nil} deadline := time.Now().Add(clientHandshakeTimeout) - if err = c.SetDeadline(deadline); err != nil { + if err := c.SetDeadline(deadline); err != nil { return nil, err } // Handshake. - if err = c.handshake(); err != nil { + if err := c.handshake(); err != nil { return nil, err } // Disarm the handshake timer. - if err = c.SetDeadline(time.Time{}); err != nil { + if err := c.SetDeadline(time.Time{}); err != nil { return nil, err } - return + return c, nil } -func newObfs3ServerConn(conn net.Conn) (c *obfs3Conn, err error) { +func newObfs3ServerConn(conn net.Conn) (*obfs3Conn, error) { // Initialize a server connection, and start the handshake timeout. - c = &obfs3Conn{conn, false, nil, nil, new(bytes.Buffer), nil, nil} + c := &obfs3Conn{conn, false, nil, nil, new(bytes.Buffer), nil, nil} deadline := time.Now().Add(serverHandshakeTimeout) - if err = c.SetDeadline(deadline); err != nil { + if err := c.SetDeadline(deadline); err != nil { return nil, err } // Handshake. - if err = c.handshake(); err != nil { + if err := c.handshake(); err != nil { return nil, err } // Disarm the handshake timer. - if err = c.SetDeadline(time.Time{}); err != nil { + if err := c.SetDeadline(time.Time{}); err != nil { return nil, err } - return + return c, nil } func (conn *obfs3Conn) handshake() error { @@ -217,11 +218,7 @@ func (conn *obfs3Conn) handshake() error { if err != nil { return err } - if err := conn.kdf(sharedSecret); err != nil { - return err - } - - return nil + return conn.kdf(sharedSecret) } func (conn *obfs3Conn) kdf(sharedSecret []byte) error { @@ -313,13 +310,13 @@ func (conn *obfs3Conn) findPeerMagic() error { } } -func (conn *obfs3Conn) Read(b []byte) (n int, err error) { +func (conn *obfs3Conn) Read(b []byte) (int, error) { // If this is the first time we read data post handshake, scan for the // magic value. if conn.rxMagic != nil { - if err = conn.findPeerMagic(); err != nil { + if err := conn.findPeerMagic(); err != nil { conn.Close() - return + return 0, err } conn.rxMagic = nil } @@ -339,20 +336,20 @@ func (conn *obfs3Conn) Read(b []byte) (n int, err error) { return conn.rx.Read(b) } -func (conn *obfs3Conn) Write(b []byte) (n int, err error) { +func (conn *obfs3Conn) Write(b []byte) (int, error) { // If this is the first time we write data post handshake, send the // padding/magic value. if conn.txMagic != nil { padLen := csrand.IntRange(0, maxPadding/2) blob := make([]byte, padLen+len(conn.txMagic)) - if err = csrand.Bytes(blob[:padLen]); err != nil { + if err := csrand.Bytes(blob[:padLen]); err != nil { conn.Close() - return + return 0, err } copy(blob[padLen:], conn.txMagic) - if _, err = conn.Conn.Write(blob); err != nil { + if _, err := conn.Conn.Write(blob); err != nil { conn.Close() - return + return 0, err } conn.txMagic = nil } @@ -360,7 +357,9 @@ func (conn *obfs3Conn) Write(b []byte) (n int, err error) { return conn.tx.Write(b) } -var _ base.ClientFactory = (*obfs3ClientFactory)(nil) -var _ base.ServerFactory = (*obfs3ServerFactory)(nil) -var _ base.Transport = (*Transport)(nil) -var _ net.Conn = (*obfs3Conn)(nil) +var ( + _ base.ClientFactory = (*obfs3ClientFactory)(nil) + _ base.ServerFactory = (*obfs3ServerFactory)(nil) + _ base.Transport = (*Transport)(nil) + _ net.Conn = (*obfs3Conn)(nil) +) diff --git a/transports/obfs4/framing/framing.go b/transports/obfs4/framing/framing.go index 10604a9..679348b 100644 --- a/transports/obfs4/framing/framing.go +++ b/transports/obfs4/framing/framing.go @@ -25,39 +25,40 @@ * POSSIBILITY OF SUCH DAMAGE. */ -// // Package framing implements the obfs4 link framing and cryptography. // // The Encoder/Decoder shared secret format is: -// uint8_t[32] NaCl secretbox key -// uint8_t[16] NaCl Nonce prefix -// uint8_t[16] SipHash-2-4 key (used to obfsucate length) -// uint8_t[8] SipHash-2-4 IV +// +// uint8_t[32] NaCl secretbox key +// uint8_t[16] NaCl Nonce prefix +// uint8_t[16] SipHash-2-4 key (used to obfsucate length) +// uint8_t[8] SipHash-2-4 IV // // The frame format is: -// uint16_t length (obfsucated, big endian) -// NaCl secretbox (Poly1305/XSalsa20) containing: -// uint8_t[16] tag (Part of the secretbox construct) -// uint8_t[] payload +// +// uint16_t length (obfsucated, big endian) +// NaCl secretbox (Poly1305/XSalsa20) containing: +// uint8_t[16] tag (Part of the secretbox construct) +// uint8_t[] payload // // The length field is length of the NaCl secretbox XORed with the truncated // SipHash-2-4 digest ran in OFB mode. // -// Initialize K, IV[0] with values from the shared secret. -// On each packet, IV[n] = H(K, IV[n - 1]) -// mask[n] = IV[n][0:2] -// obfsLen = length ^ mask[n] +// Initialize K, IV[0] with values from the shared secret. +// On each packet, IV[n] = H(K, IV[n - 1]) +// mask[n] = IV[n][0:2] +// obfsLen = length ^ mask[n] // // The NaCl secretbox (Poly1305/XSalsa20) nonce format is: -// uint8_t[24] prefix (Fixed) -// uint64_t counter (Big endian) +// +// uint8_t[24] prefix (Fixed) +// uint64_t counter (Big endian) // // The counter is initialized to 1, and is incremented on each frame. Since // the protocol is designed to be used over a reliable medium, the nonce is not // transmitted over the wire as both sides of the conversation know the prefix // and the initial counter value. It is imperative that the counter does not // wrap, and sessions MUST terminate before 2^64 frames are sent. -// package framing // import "gitlab.com/yawning/obfs4.git/transports/obfs4/framing" import ( @@ -67,9 +68,10 @@ import ( "fmt" "io" + "golang.org/x/crypto/nacl/secretbox" + "gitlab.com/yawning/obfs4.git/common/csrand" "gitlab.com/yawning/obfs4.git/common/drbg" - "golang.org/x/crypto/nacl/secretbox" ) const ( @@ -175,7 +177,7 @@ func NewEncoder(key []byte) *Encoder { // Encode encodes a single frame worth of payload and returns the encoded // length. InvalidPayloadLengthError is recoverable, all other errors MUST be // treated as fatal and the session aborted. -func (encoder *Encoder) Encode(frame, payload []byte) (n int, err error) { +func (encoder *Encoder) Encode(frame, payload []byte) (int, error) { payloadLen := len(payload) if MaximumFramePayloadLength < payloadLen { return 0, InvalidPayloadLengthError(payloadLen) @@ -186,7 +188,7 @@ func (encoder *Encoder) Encode(frame, payload []byte) (n int, err error) { // Generate a new nonce. var nonce [nonceLength]byte - if err = encoder.nonce.bytes(&nonce); err != nil { + if err := encoder.nonce.bytes(&nonce); err != nil { return 0, err } encoder.nonce.counter++ diff --git a/transports/obfs4/framing/framing_test.go b/transports/obfs4/framing/framing_test.go index d830625..585f6ed 100644 --- a/transports/obfs4/framing/framing_test.go +++ b/transports/obfs4/framing/framing_test.go @@ -30,6 +30,7 @@ package framing import ( "bytes" "crypto/rand" + "errors" "testing" ) @@ -89,7 +90,9 @@ func TestEncoder_Encode_Oversize(t *testing.T) { var buf [MaximumFramePayloadLength + 1]byte _, _ = rand.Read(buf[:]) // YOLO _, err := encoder.Encode(frame[:], buf[:]) - if _, ok := err.(InvalidPayloadLengthError); !ok { + + var payloadErr InvalidPayloadLengthError + if !errors.As(err, &payloadErr) { t.Error("Encoder.encode() returned unexpected error:", err) } } @@ -150,7 +153,7 @@ func BenchmarkEncoder_Encode(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - transfered := 0 + var xfered int buffer := bytes.NewBuffer(payload) for 0 < buffer.Len() { n, err := buffer.Read(chopBuf[:]) @@ -159,11 +162,10 @@ func BenchmarkEncoder_Encode(b *testing.B) { } n, _ = encoder.Encode(frame[:], chopBuf[:n]) - transfered += n - FrameOverhead + xfered += n - FrameOverhead } - if transfered != len(payload) { - b.Fatalf("Transfered length mismatch: %d != %d", transfered, - len(payload)) + if xfered != len(payload) { + b.Fatalf("Xfered length mismatch: %d != %d", xfered, len(payload)) } } } diff --git a/transports/obfs4/handshake_ntor.go b/transports/obfs4/handshake_ntor.go index c5d0e6c..c39a116 100644 --- a/transports/obfs4/handshake_ntor.go +++ b/transports/obfs4/handshake_ntor.go @@ -280,7 +280,7 @@ func (hs *serverHandshake) parseClientHandshake(filter *replayfilter.ReplayFilte macFound := false for _, off := range []int64{0, -1, 1} { // Allow epoch to be off by up to a hour in either direction. - epochHour := []byte(strconv.FormatInt(getEpochHour()+int64(off), 10)) + epochHour := []byte(strconv.FormatInt(getEpochHour()+off, 10)) hs.mac.Reset() _, _ = hs.mac.Write(resp[:pos+markLength]) _, _ = hs.mac.Write(epochHour) @@ -367,7 +367,7 @@ func getEpochHour() int64 { return time.Now().Unix() / 3600 } -func findMarkMac(mark, buf []byte, startPos, maxPos int, fromTail bool) (pos int) { +func findMarkMac(mark, buf []byte, startPos, maxPos int, fromTail bool) int { if len(mark) != markLength { panic(fmt.Sprintf("BUG: Invalid mark length: %d", len(mark))) } @@ -387,19 +387,19 @@ func findMarkMac(mark, buf []byte, startPos, maxPos int, fromTail bool) (pos int // The server can optimize the search process by only examining the // tail of the buffer. The client can't send valid data past M_C | // MAC_C as it does not have the server's public key yet. - pos = endPos - (markLength + macLength) + pos := endPos - (markLength + macLength) if !hmac.Equal(buf[pos:pos+markLength], mark) { return -1 } - return + return pos } // The client has to actually do a substring search since the server can // and will send payload trailing the response. // // XXX: bytes.Index() uses a naive search, which kind of sucks. - pos = bytes.Index(buf[startPos:endPos], mark) + pos := bytes.Index(buf[startPos:endPos], mark) if pos == -1 { return -1 } @@ -411,7 +411,7 @@ func findMarkMac(mark, buf []byte, startPos, maxPos int, fromTail bool) (pos int // Return the index relative to the start of the slice. pos += startPos - return + return pos } func makePad(padLen int) ([]byte, error) { diff --git a/transports/obfs4/handshake_ntor_test.go b/transports/obfs4/handshake_ntor_test.go index 701c610..3d26d74 100644 --- a/transports/obfs4/handshake_ntor_test.go +++ b/transports/obfs4/handshake_ntor_test.go @@ -115,7 +115,7 @@ func TestHandshakeNtorClient(t *testing.T) { serverHs := newServerHandshake(nodeID, idKeypair, serverKeypair) _, err = serverHs.parseClientHandshake(serverFilter, clientBlob) if err == nil { - t.Fatalf("serverHandshake.parseClientHandshake() succeded (oversized)") + t.Fatalf("serverHandshake.parseClientHandshake() succeeded (oversized)") } // Test undersized client padding. @@ -127,7 +127,7 @@ func TestHandshakeNtorClient(t *testing.T) { serverHs = newServerHandshake(nodeID, idKeypair, serverKeypair) _, err = serverHs.parseClientHandshake(serverFilter, clientBlob) if err == nil { - t.Fatalf("serverHandshake.parseClientHandshake() succeded (undersized)") + t.Fatalf("serverHandshake.parseClientHandshake() succeeded (undersized)") } } @@ -204,7 +204,7 @@ func TestHandshakeNtorServer(t *testing.T) { serverHs := newServerHandshake(nodeID, idKeypair, serverKeypair) _, err = serverHs.parseClientHandshake(serverFilter, clientBlob) if err == nil { - t.Fatalf("serverHandshake.parseClientHandshake() succeded (oversized)") + t.Fatalf("serverHandshake.parseClientHandshake() succeeded (oversized)") } // Test undersized client padding. @@ -216,7 +216,7 @@ func TestHandshakeNtorServer(t *testing.T) { serverHs = newServerHandshake(nodeID, idKeypair, serverKeypair) _, err = serverHs.parseClientHandshake(serverFilter, clientBlob) if err == nil { - t.Fatalf("serverHandshake.parseClientHandshake() succeded (undersized)") + t.Fatalf("serverHandshake.parseClientHandshake() succeeded (undersized)") } // Test oversized server padding. @@ -243,6 +243,6 @@ func TestHandshakeNtorServer(t *testing.T) { } _, _, err = clientHs.parseServerHandshake(serverBlob) if err == nil { - t.Fatalf("clientHandshake.parseServerHandshake() succeded (oversized)") + t.Fatalf("clientHandshake.parseServerHandshake() succeeded (oversized)") } } diff --git a/transports/obfs4/obfs4.go b/transports/obfs4/obfs4.go index 9723735..e053ddc 100644 --- a/transports/obfs4/obfs4.go +++ b/transports/obfs4/obfs4.go @@ -32,17 +32,18 @@ package obfs4 // import "gitlab.com/yawning/obfs4.git/transports/obfs4" import ( "bytes" "crypto/sha256" + "errors" "flag" "fmt" "io" - "io/ioutil" "math/rand" "net" "strconv" "syscall" "time" - "git.torproject.org/pluggable-transports/goptlib.git" + "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/goptlib" + "gitlab.com/yawning/obfs4.git/common/drbg" "gitlab.com/yawning/obfs4.git/common/ntor" "gitlab.com/yawning/obfs4.git/common/probdist" @@ -81,7 +82,7 @@ const ( // biasedDist controls if the probability table will be ScrambleSuit style or // uniformly distributed. -var biasedDist bool +var biasedDist = flag.Bool(biasCmdArg, false, "Enable obfs4 using ScrambleSuit style table generation") type obfs4ClientArgs struct { nodeID *ntor.NodeID @@ -99,7 +100,7 @@ func (t *Transport) Name() string { } // ClientFactory returns a new obfs4ClientFactory instance. -func (t *Transport) ClientFactory(stateDir string) (base.ClientFactory, error) { +func (t *Transport) ClientFactory(_ string) (base.ClientFactory, error) { cf := &obfs4ClientFactory{transport: t} return cf, nil } @@ -137,7 +138,7 @@ func (t *Transport) ServerFactory(stateDir string, args *pt.Args) (base.ServerFa if err != nil { return nil, err } - rng := rand.New(drbg) + rng := rand.New(drbg) //nolint:gosec sf := &obfs4ServerFactory{t, &ptArgs, st.nodeID, st.identityKey, st.drbgSeed, iatSeed, st.iatMode, filter, rng.Intn(maxCloseDelay)} return sf, nil @@ -151,14 +152,14 @@ func (cf *obfs4ClientFactory) Transport() base.Transport { return cf.transport } -func (cf *obfs4ClientFactory) ParseArgs(args *pt.Args) (interface{}, error) { +func (cf *obfs4ClientFactory) ParseArgs(args *pt.Args) (any, error) { var nodeID *ntor.NodeID var publicKey *ntor.PublicKey // The "new" (version >= 0.0.3) bridge lines use a unified "cert" argument // for the Node ID and Public Key. certStr, ok := args.Get(certArg) - if ok { + if ok { //nolint:nestif cert, err := serverCertFromString(certStr) if err != nil { return nil, err @@ -195,7 +196,7 @@ func (cf *obfs4ClientFactory) ParseArgs(args *pt.Args) (interface{}, error) { return nil, fmt.Errorf("invalid iat-mode '%d'", iatMode) } - // Generate the session key pair before connectiong to hide the Elligator2 + // Generate the session key pair before connecting to hide the Elligator2 // rejection sampling from network observers. sessionKey, err := ntor.NewKeypair(true) if err != nil { @@ -205,7 +206,7 @@ func (cf *obfs4ClientFactory) ParseArgs(args *pt.Args) (interface{}, error) { return &obfs4ClientArgs{nodeID, publicKey, sessionKey, iatMode}, nil } -func (cf *obfs4ClientFactory) Dial(network, addr string, dialFn base.DialFunc, args interface{}) (net.Conn, error) { +func (cf *obfs4ClientFactory) Dial(network, addr string, dialFn base.DialFunc, args any) (net.Conn, error) { // Validate args before bothering to open connection. ca, ok := args.(*obfs4ClientArgs) if !ok { @@ -259,10 +260,10 @@ func (sf *obfs4ServerFactory) WrapConn(conn net.Conn) (net.Conn, error) { return nil, err } - lenDist := probdist.New(sf.lenSeed, 0, framing.MaximumSegmentLength, biasedDist) + lenDist := probdist.New(sf.lenSeed, 0, framing.MaximumSegmentLength, *biasedDist) var iatDist *probdist.WeightedDist if sf.iatSeed != nil { - iatDist = probdist.New(sf.iatSeed, 0, maxIATDelay, biasedDist) + iatDist = probdist.New(sf.iatSeed, 0, maxIATDelay, *biasedDist) } c := &obfs4Conn{conn, true, lenDist, iatDist, sf.iatMode, bytes.NewBuffer(nil), bytes.NewBuffer(nil), make([]byte, consumeReadSize), nil, nil} @@ -294,25 +295,28 @@ type obfs4Conn struct { decoder *framing.Decoder } -func newObfs4ClientConn(conn net.Conn, args *obfs4ClientArgs) (c *obfs4Conn, err error) { +func newObfs4ClientConn(conn net.Conn, args *obfs4ClientArgs) (*obfs4Conn, error) { // Generate the initial protocol polymorphism distribution(s). - var seed *drbg.Seed + var ( + seed *drbg.Seed + err error + ) if seed, err = drbg.NewSeed(); err != nil { - return + return nil, err } - lenDist := probdist.New(seed, 0, framing.MaximumSegmentLength, biasedDist) + lenDist := probdist.New(seed, 0, framing.MaximumSegmentLength, *biasedDist) var iatDist *probdist.WeightedDist if args.iatMode != iatNone { var iatSeed *drbg.Seed iatSeedSrc := sha256.Sum256(seed.Bytes()[:]) if iatSeed, err = drbg.SeedFromBytes(iatSeedSrc[:]); err != nil { - return + return nil, err } - iatDist = probdist.New(iatSeed, 0, maxIATDelay, biasedDist) + iatDist = probdist.New(iatSeed, 0, maxIATDelay, *biasedDist) } // Allocate the client structure. - c = &obfs4Conn{conn, false, lenDist, iatDist, args.iatMode, bytes.NewBuffer(nil), bytes.NewBuffer(nil), make([]byte, consumeReadSize), nil, nil} + c := &obfs4Conn{conn, false, lenDist, iatDist, args.iatMode, bytes.NewBuffer(nil), bytes.NewBuffer(nil), make([]byte, consumeReadSize), nil, nil} // Start the handshake timeout. deadline := time.Now().Add(clientHandshakeTimeout) @@ -329,7 +333,7 @@ func newObfs4ClientConn(conn net.Conn, args *obfs4ClientArgs) (c *obfs4Conn, err return nil, err } - return + return c, nil } func (conn *obfs4Conn) clientHandshake(nodeID *ntor.NodeID, peerIdentityKey *ntor.PublicKey, sessionKey *ntor.Keypair) error { @@ -359,14 +363,14 @@ func (conn *obfs4Conn) clientHandshake(nodeID *ntor.NodeID, peerIdentityKey *nto conn.receiveBuffer.Write(hsBuf[:n]) n, seed, err := hs.parseServerHandshake(conn.receiveBuffer.Bytes()) - if err == ErrMarkNotFoundYet { + if errors.Is(err, ErrMarkNotFoundYet) { continue } else if err != nil { return err } _ = conn.receiveBuffer.Next(n) - // Use the derived key material to intialize the link crypto. + // Use the derived key material to initialize the link crypto. okm := ntor.Kdf(seed, framing.KeyLength*2) conn.encoder = framing.NewEncoder(okm[:framing.KeyLength]) conn.decoder = framing.NewDecoder(okm[framing.KeyLength:]) @@ -398,7 +402,7 @@ func (conn *obfs4Conn) serverHandshake(sf *obfs4ServerFactory, sessionKey *ntor. conn.receiveBuffer.Write(hsBuf[:n]) seed, err := hs.parseClientHandshake(sf.replayFilter, conn.receiveBuffer.Bytes()) - if err == ErrMarkNotFoundYet { + if errors.Is(err, ErrMarkNotFoundYet) { continue } else if err != nil { return err @@ -406,10 +410,10 @@ func (conn *obfs4Conn) serverHandshake(sf *obfs4ServerFactory, sessionKey *ntor. conn.receiveBuffer.Reset() if err := conn.Conn.SetDeadline(time.Time{}); err != nil { - return nil + return err } - // Use the derived key material to intialize the link crypto. + // Use the derived key material to initialize the link crypto. okm := ntor.Kdf(seed, framing.KeyLength*2) conn.encoder = framing.NewEncoder(okm[framing.KeyLength:]) conn.decoder = framing.NewDecoder(okm[:framing.KeyLength]) @@ -421,7 +425,7 @@ func (conn *obfs4Conn) serverHandshake(sf *obfs4ServerFactory, sessionKey *ntor. // the length obfuscation, this makes the amount of data received from the // server inconsistent with the length sent from the client. // - // Rebalance this by tweaking the client mimimum padding/server maximum + // Rebalance this by tweaking the client minimum padding/server maximum // padding, and sending the PRNG seed unpadded (As in, treat the PRNG seed // as part of the server response). See inlineSeedFrameLength in // handshake_ntor.go. @@ -447,13 +451,14 @@ func (conn *obfs4Conn) serverHandshake(sf *obfs4ServerFactory, sessionKey *ntor. return nil } -func (conn *obfs4Conn) Read(b []byte) (n int, err error) { +func (conn *obfs4Conn) Read(b []byte) (int, error) { // If there is no payload from the previous Read() calls, consume data off // the network. Not all data received is guaranteed to be usable payload, // so do this in a loop till data is present or an error occurs. + var err error for conn.receiveDecodedBuffer.Len() == 0 { err = conn.readPackets() - if err == framing.ErrAgain { + if errors.Is(err, framing.ErrAgain) { // Don't proagate this back up the call stack if we happen to break // out of the loop. err = nil @@ -465,6 +470,7 @@ func (conn *obfs4Conn) Read(b []byte) (n int, err error) { // Even if err is set, attempt to do the read anyway so that all decoded // data gets relayed before the connection is torn down. + var n int if conn.receiveDecodedBuffer.Len() > 0 { var berr error n, berr = conn.receiveDecodedBuffer.Read(b) @@ -475,28 +481,29 @@ func (conn *obfs4Conn) Read(b []byte) (n int, err error) { } } - return + return n, err } -func (conn *obfs4Conn) Write(b []byte) (n int, err error) { +func (conn *obfs4Conn) Write(b []byte) (int, error) { chopBuf := bytes.NewBuffer(b) - var payload [maxPacketPayloadLength]byte - var frameBuf bytes.Buffer + var ( + payload [maxPacketPayloadLength]byte + frameBuf bytes.Buffer + n int + ) // Chop the pending data into payload frames. for chopBuf.Len() > 0 { // Send maximum sized frames. - rdLen := 0 - rdLen, err = chopBuf.Read(payload[:]) + rdLen, err := chopBuf.Read(payload[:]) if err != nil { return 0, err } else if rdLen == 0 { - panic(fmt.Sprintf("BUG: Write(), chopping length was 0")) + panic("BUG: Write(), chopping length was 0") } n += rdLen - err = conn.makePacket(&frameBuf, packetTypePayload, payload[:rdLen], 0) - if err != nil { + if err = conn.makePacket(&frameBuf, packetTypePayload, payload[:rdLen], 0); err != nil { return 0, err } } @@ -504,7 +511,7 @@ func (conn *obfs4Conn) Write(b []byte) (n int, err error) { if conn.iatMode != iatParanoid { // For non-paranoid IAT, pad once per burst. Paranoid IAT handles // things differently. - if err = conn.padBurst(&frameBuf, conn.lenDist.Sample()); err != nil { + if err := conn.padBurst(&frameBuf, conn.lenDist.Sample()); err != nil { return 0, err } } @@ -513,10 +520,11 @@ func (conn *obfs4Conn) Write(b []byte) (n int, err error) { // because the frame encoder state is advanced, and the code doesn't keep // frameBuf around. In theory, write timeouts and whatnot could be // supported if this wasn't the case, but that complicates the code. - if conn.iatMode != iatNone { + var err error + if conn.iatMode != iatNone { //nolint:nestif var iatFrame [framing.MaximumSegmentLength]byte for frameBuf.Len() > 0 { - iatWrLen := 0 + var iatWrLen int switch conn.iatMode { case iatEnabled: @@ -549,7 +557,7 @@ func (conn *obfs4Conn) Write(b []byte) (n int, err error) { if err != nil { return 0, err } else if iatWrLen == 0 { - panic(fmt.Sprintf("BUG: Write(), iat length was 0")) + panic("BUG: Write(), iat length was 0") } // Calculate the delay. The delay resolution is 100 usec, leading @@ -557,8 +565,7 @@ func (conn *obfs4Conn) Write(b []byte) (n int, err error) { iatDelta := time.Duration(conn.iatDist.Sample() * 100) // Write then sleep. - _, err = conn.Conn.Write(iatFrame[:iatWrLen]) - if err != nil { + if _, err = conn.Conn.Write(iatFrame[:iatWrLen]); err != nil { return 0, err } time.Sleep(iatDelta * time.Microsecond) @@ -567,14 +574,14 @@ func (conn *obfs4Conn) Write(b []byte) (n int, err error) { _, err = conn.Conn.Write(frameBuf.Bytes()) } - return + return n, err } -func (conn *obfs4Conn) SetDeadline(t time.Time) error { +func (conn *obfs4Conn) SetDeadline(_ time.Time) error { return syscall.ENOTSUP } -func (conn *obfs4Conn) SetWriteDeadline(t time.Time) error { +func (conn *obfs4Conn) SetWriteDeadline(_ time.Time) error { return syscall.ENOTSUP } @@ -594,13 +601,13 @@ func (conn *obfs4Conn) closeAfterDelay(sf *obfs4ServerFactory, startTime time.Ti // Consume and discard data on this connection until the specified interval // passes. - _, _ = io.Copy(ioutil.Discard, conn.Conn) + _, _ = io.Copy(io.Discard, conn.Conn) } -func (conn *obfs4Conn) padBurst(burst *bytes.Buffer, toPadTo int) (err error) { +func (conn *obfs4Conn) padBurst(burst *bytes.Buffer, toPadTo int) error { tailLen := burst.Len() % framing.MaximumSegmentLength - padLen := 0 + var padLen int if toPadTo >= tailLen { padLen = toPadTo - tailLen } else { @@ -608,32 +615,24 @@ func (conn *obfs4Conn) padBurst(burst *bytes.Buffer, toPadTo int) (err error) { } if padLen > headerLength { - err = conn.makePacket(burst, packetTypePayload, []byte{}, - uint16(padLen-headerLength)) - if err != nil { - return + if err := conn.makePacket(burst, packetTypePayload, []byte{}, uint16(padLen-headerLength)); err != nil { + return err } } else if padLen > 0 { - err = conn.makePacket(burst, packetTypePayload, []byte{}, - maxPacketPayloadLength) - if err != nil { - return + if err := conn.makePacket(burst, packetTypePayload, []byte{}, maxPacketPayloadLength); err != nil { + return err } - err = conn.makePacket(burst, packetTypePayload, []byte{}, - uint16(padLen)) - if err != nil { - return + if err := conn.makePacket(burst, packetTypePayload, []byte{}, uint16(padLen)); err != nil { + return err } } - return -} - -func init() { - flag.BoolVar(&biasedDist, biasCmdArg, false, "Enable obfs4 using ScrambleSuit style table generation") + return nil } -var _ base.ClientFactory = (*obfs4ClientFactory)(nil) -var _ base.ServerFactory = (*obfs4ServerFactory)(nil) -var _ base.Transport = (*Transport)(nil) -var _ net.Conn = (*obfs4Conn)(nil) +var ( + _ base.ClientFactory = (*obfs4ClientFactory)(nil) + _ base.ServerFactory = (*obfs4ServerFactory)(nil) + _ base.Transport = (*Transport)(nil) + _ net.Conn = (*obfs4Conn)(nil) +) diff --git a/transports/obfs4/packet.go b/transports/obfs4/packet.go index 8d47af9..9b19190 100644 --- a/transports/obfs4/packet.go +++ b/transports/obfs4/packet.go @@ -30,6 +30,7 @@ package obfs4 import ( "crypto/sha256" "encoding/binary" + "errors" "fmt" "io" @@ -52,7 +53,7 @@ const ( ) // InvalidPacketLengthError is the error returned when decodePacket detects a -// invalid packet length/ +// invalid packet length. type InvalidPacketLengthError int func (e InvalidPacketLengthError) Error() string { @@ -85,7 +86,7 @@ func (conn *obfs4Conn) makePacket(w io.Writer, pktType uint8, data []byte, padLe pkt[0] = pktType binary.BigEndian.PutUint16(pkt[1:], uint16(len(data))) if len(data) > 0 { - copy(pkt[3:], data[:]) + copy(pkt[3:], data) } copy(pkt[3+len(data):], zeroPadBytes[:padLen]) @@ -108,23 +109,28 @@ func (conn *obfs4Conn) makePacket(w io.Writer, pktType uint8, data []byte, padLe return nil } -func (conn *obfs4Conn) readPackets() (err error) { +func (conn *obfs4Conn) readPackets() error { // Attempt to read off the network. rdLen, rdErr := conn.Conn.Read(conn.readBuffer) conn.receiveBuffer.Write(conn.readBuffer[:rdLen]) - var decoded [framing.MaximumFramePayloadLength]byte + var ( + decoded [framing.MaximumFramePayloadLength]byte + err error + ) +bufferLoop: for conn.receiveBuffer.Len() > 0 { // Decrypt an AEAD frame. - decLen := 0 + var decLen int decLen, err = conn.decoder.Decode(decoded[:], conn.receiveBuffer) - if err == framing.ErrAgain { - break - } else if err != nil { - break - } else if decLen < packetOverhead { + switch { + case errors.Is(err, framing.ErrAgain): + break bufferLoop + case err != nil: + break bufferLoop + case decLen < packetOverhead: err = InvalidPacketLengthError(decLen) - break + break bufferLoop } // Decode the packet. @@ -171,5 +177,5 @@ func (conn *obfs4Conn) readPackets() (err error) { return rdErr } - return + return err } diff --git a/transports/obfs4/statefile.go b/transports/obfs4/statefile.go index cbf1d6e..20fc5bd 100644 --- a/transports/obfs4/statefile.go +++ b/transports/obfs4/statefile.go @@ -28,16 +28,17 @@ package obfs4 import ( + "bytes" "encoding/base64" "encoding/json" "fmt" - "io/ioutil" "os" "path" "strconv" "strings" - "git.torproject.org/pluggable-transports/goptlib.git" + "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/goptlib" + "gitlab.com/yawning/obfs4.git/common/csrand" "gitlab.com/yawning/obfs4.git/common/drbg" "gitlab.com/yawning/obfs4.git/common/ntor" @@ -81,7 +82,7 @@ func (cert *obfs4ServerCert) unpack() (*ntor.NodeID, *ntor.PublicKey) { func serverCertFromString(encoded string) (*obfs4ServerCert, error) { decoded, err := base64.StdEncoding.DecodeString(encoded + certSuffix) if err != nil { - return nil, fmt.Errorf("failed to decode cert: %s", err) + return nil, fmt.Errorf("failed to decode cert: %w", err) } if len(decoded) != certLength { @@ -93,7 +94,10 @@ func serverCertFromString(encoded string) (*obfs4ServerCert, error) { func serverCertFromState(st *obfs4ServerState) *obfs4ServerCert { cert := new(obfs4ServerCert) - cert.raw = append(st.nodeID.Bytes()[:], st.identityKey.Public().Bytes()[:]...) + + cert.raw = bytes.Clone(st.nodeID.Bytes()[:]) + cert.raw = append(cert.raw, st.identityKey.Public().Bytes()[:]...) + return cert } @@ -121,15 +125,16 @@ func serverStateFromArgs(stateDir string, args *pt.Args) (*obfs4ServerState, err // Either a private key, node id, and seed are ALL specified, or // they should be loaded from the state file. - if !privKeyOk && !nodeIDOk && !seedOk { + switch { + case !privKeyOk && !nodeIDOk && !seedOk: if err := jsonServerStateFromFile(stateDir, &js); err != nil { return nil, err } - } else if !privKeyOk { + case !privKeyOk: return nil, fmt.Errorf("missing argument '%s'", privateKeyArg) - } else if !nodeIDOk { + case !nodeIDOk: return nil, fmt.Errorf("missing argument '%s'", nodeIDArg) - } else if !seedOk { + case !seedOk: return nil, fmt.Errorf("missing argument '%s'", seedArg) } @@ -177,7 +182,7 @@ func serverStateFromJSONServerState(stateDir string, js *jsonServerState) (*obfs func jsonServerStateFromFile(stateDir string, js *jsonServerState) error { fPath := path.Join(stateDir, stateFile) - f, err := ioutil.ReadFile(fPath) + f, err := os.ReadFile(fPath) if err != nil { if os.IsNotExist(err) { if err = newJSONServerState(stateDir, js); err == nil { @@ -188,27 +193,29 @@ func jsonServerStateFromFile(stateDir string, js *jsonServerState) error { } if err := json.Unmarshal(f, js); err != nil { - return fmt.Errorf("failed to load statefile '%s': %s", fPath, err) + return fmt.Errorf("failed to load statefile '%s': %w", fPath, err) } return nil } -func newJSONServerState(stateDir string, js *jsonServerState) (err error) { +func newJSONServerState(stateDir string, js *jsonServerState) error { // Generate everything a server needs, using the cryptographic PRNG. var st obfs4ServerState rawID := make([]byte, ntor.NodeIDLength) - if err = csrand.Bytes(rawID); err != nil { - return + if err := csrand.Bytes(rawID); err != nil { + return err } + + var err error if st.nodeID, err = ntor.NewNodeID(rawID); err != nil { - return + return err } if st.identityKey, err = ntor.NewKeypair(false); err != nil { - return + return err } if st.drbgSeed, err = drbg.NewSeed(); err != nil { - return + return err } st.iatMode = iatNone @@ -228,11 +235,7 @@ func writeJSONServerState(stateDir string, js *jsonServerState) error { if encoded, err = json.Marshal(js); err != nil { return err } - if err = ioutil.WriteFile(path.Join(stateDir, stateFile), encoded, 0600); err != nil { - return err - } - - return nil + return os.WriteFile(path.Join(stateDir, stateFile), encoded, 0o600) } func newBridgeFile(stateDir string, st *obfs4ServerState) error { @@ -252,9 +255,5 @@ func newBridgeFile(stateDir string, st *obfs4ServerState) error { st.clientString()) tmp := []byte(prefix + bridgeLine) - if err := ioutil.WriteFile(path.Join(stateDir, bridgeFile), tmp, 0600); err != nil { - return err - } - - return nil + return os.WriteFile(path.Join(stateDir, bridgeFile), tmp, 0o600) } diff --git a/transports/scramblesuit/base.go b/transports/scramblesuit/base.go index 655ad7a..9947573 100644 --- a/transports/scramblesuit/base.go +++ b/transports/scramblesuit/base.go @@ -33,7 +33,8 @@ import ( "fmt" "net" - "git.torproject.org/pluggable-transports/goptlib.git" + "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/goptlib" + "gitlab.com/yawning/obfs4.git/transports/base" ) @@ -58,8 +59,7 @@ func (t *Transport) ClientFactory(stateDir string) (base.ClientFactory, error) { } // ServerFactory will one day return a new ssServerFactory instance. -func (t *Transport) ServerFactory(stateDir string, args *pt.Args) (base.ServerFactory, error) { - // TODO: Fill this in eventually, though obfs4 is better. +func (t *Transport) ServerFactory(_ string, _ *pt.Args) (base.ServerFactory, error) { return nil, fmt.Errorf("server not supported") } @@ -72,11 +72,11 @@ func (cf *ssClientFactory) Transport() base.Transport { return cf.transport } -func (cf *ssClientFactory) ParseArgs(args *pt.Args) (interface{}, error) { +func (cf *ssClientFactory) ParseArgs(args *pt.Args) (any, error) { return newClientArgs(args) } -func (cf *ssClientFactory) Dial(network, addr string, dialFn base.DialFunc, args interface{}) (net.Conn, error) { +func (cf *ssClientFactory) Dial(network, addr string, dialFn base.DialFunc, args any) (net.Conn, error) { // Validate args before opening outgoing connection. ca, ok := args.(*ssClientArgs) if !ok { @@ -95,5 +95,7 @@ func (cf *ssClientFactory) Dial(network, addr string, dialFn base.DialFunc, args return conn, nil } -var _ base.ClientFactory = (*ssClientFactory)(nil) -var _ base.Transport = (*Transport)(nil) +var ( + _ base.ClientFactory = (*ssClientFactory)(nil) + _ base.Transport = (*Transport)(nil) +) diff --git a/transports/scramblesuit/conn.go b/transports/scramblesuit/conn.go index cc18e18..7b0f66c 100644 --- a/transports/scramblesuit/conn.go +++ b/transports/scramblesuit/conn.go @@ -42,7 +42,9 @@ import ( "net" "time" - "git.torproject.org/pluggable-transports/goptlib.git" + "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/goptlib" + "golang.org/x/crypto/hkdf" + "gitlab.com/yawning/obfs4.git/common/csrand" "gitlab.com/yawning/obfs4.git/common/drbg" "gitlab.com/yawning/obfs4.git/common/probdist" @@ -87,8 +89,11 @@ type ssClientArgs struct { sessionKey *uniformdh.PrivateKey } -func newClientArgs(args *pt.Args) (ca *ssClientArgs, err error) { - ca = &ssClientArgs{} +func newClientArgs(args *pt.Args) (*ssClientArgs, error) { + var ( + ca ssClientArgs + err error + ) if ca.kB, err = parsePasswordArg(args); err != nil { return nil, err } @@ -99,7 +104,7 @@ func newClientArgs(args *pt.Args) (ca *ssClientArgs, err error) { if ca.sessionKey, err = uniformdh.GenerateKey(csrand.Reader); err != nil { return nil, err } - return + return &ca, nil } func parsePasswordArg(args *pt.Args) (*ssSharedSecret, error) { @@ -112,7 +117,7 @@ func parsePasswordArg(args *pt.Args) (*ssSharedSecret, error) { // shared secret (k_B) used for handshaking. decoded, err := base32.StdEncoding.DecodeString(str) if err != nil { - return nil, fmt.Errorf("failed to decode password: %s", err) + return nil, fmt.Errorf("failed to decode password: %w", err) } if len(decoded) != sharedSecretLength { return nil, fmt.Errorf("password length %d is invalid", len(decoded)) @@ -131,7 +136,7 @@ func newCryptoState(aesKey []byte, ivPrefix []byte, macKey []byte) (*ssCryptoSta // The ScrambleSuit CTR-AES256 link crypto uses an 8 byte prefix from the // KDF, and a 64 bit counter initialized to 1 as the IV. The initial value // of the counter isn't documented in the spec either. - var initialCtr = []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01} + initialCtr := []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01} iv := make([]byte, 0, aes.BlockSize) iv = append(iv, ivPrefix...) iv = append(iv, initialCtr...) @@ -168,7 +173,8 @@ type ssRxState struct { payloadLen int } -func (conn *ssConn) Read(b []byte) (n int, err error) { +func (conn *ssConn) Read(b []byte) (int, error) { + var err error // If the receive payload buffer is empty, consume data off the network. for conn.receiveDecodedBuffer.Len() == 0 { if err = conn.readPackets(); err != nil { @@ -177,17 +183,19 @@ func (conn *ssConn) Read(b []byte) (n int, err error) { } // Service the read request using buffered payload. + var n int if conn.receiveDecodedBuffer.Len() > 0 { n, _ = conn.receiveDecodedBuffer.Read(b) } - return + return n, err } -func (conn *ssConn) Write(b []byte) (n int, err error) { +func (conn *ssConn) Write(b []byte) (int, error) { var frameBuf bytes.Buffer p := b toSend := len(p) + var n int for toSend > 0 { // Send as much payload as will fit into each frame as possible. wrLen := len(p) @@ -195,7 +203,7 @@ func (conn *ssConn) Write(b []byte) (n int, err error) { wrLen = maxPayloadLength } payload := p[:wrLen] - if err = conn.makePacket(&frameBuf, pktPayload, payload, 0); err != nil { + if err := conn.makePayloadPacket(&frameBuf, payload, 0); err != nil { return 0, err } @@ -205,28 +213,28 @@ func (conn *ssConn) Write(b []byte) (n int, err error) { } // Pad out the burst as appropriate. - if err = conn.padBurst(&frameBuf, conn.lenDist.Sample()); err != nil { + if err := conn.padBurst(&frameBuf, conn.lenDist.Sample()); err != nil { return 0, err } // Write and return. - _, err = conn.Conn.Write(frameBuf.Bytes()) - return + _, err := conn.Conn.Write(frameBuf.Bytes()) + return n, err } -func (conn *ssConn) SetDeadline(t time.Time) error { +func (conn *ssConn) SetDeadline(_ time.Time) error { return ErrNotSupported } -func (conn *ssConn) SetReadDeadline(t time.Time) error { +func (conn *ssConn) SetReadDeadline(_ time.Time) error { return ErrNotSupported } -func (conn *ssConn) SetWriteDeadline(t time.Time) error { +func (conn *ssConn) SetWriteDeadline(_ time.Time) error { return ErrNotSupported } -func (conn *ssConn) makePacket(w io.Writer, pktType byte, data []byte, padLen int) error { +func (conn *ssConn) makePayloadPacket(w io.Writer, data []byte, padLen int) error { payloadLen := len(data) totalLen := payloadLen + padLen if totalLen > maxPayloadLength { @@ -238,7 +246,7 @@ func (conn *ssConn) makePacket(w io.Writer, pktType byte, data []byte, padLen in pkt := make([]byte, pktHdrLength, pktHdrLength+payloadLen+padLen) binary.BigEndian.PutUint16(pkt[0:], uint16(totalLen)) binary.BigEndian.PutUint16(pkt[2:], uint16(payloadLen)) - pkt[4] = pktType + pkt[4] = pktPayload pkt = append(pkt, data...) pkt = append(pkt, zeroPadBytes[:padLen]...) @@ -319,7 +327,7 @@ func (conn *ssConn) readPackets() error { // Authenticate the packet, by comparing the received MAC with the one // calculated over the ciphertext consumed off the network. cmpMAC := conn.rxCrypto.mac.Sum(nil)[:macLength] - if !hmac.Equal(cmpMAC, conn.receiveState.mac[:]) { + if !hmac.Equal(cmpMAC, conn.receiveState.mac) { return ErrInvalidPacket } @@ -426,7 +434,7 @@ handshakeUDH: // Attempt to process all the data seen so far as a response. var seed []byte n, seed, err = hs.parseServerHandshake(conn.receiveBuffer.Bytes()) - if err == errMarkNotFoundYet { + if errors.Is(err, errMarkNotFoundYet) { // No response found yet, keep trying. continue } else if err != nil { @@ -444,7 +452,12 @@ handshakeUDH: func (conn *ssConn) initCrypto(seed []byte) error { // Use HKDF-SHA256 (Expand only, no Extract) to generate session keys from // initial keying material. - okm := hkdfExpand(sha256.New, seed, nil, kdfSecretLength) + rd := hkdf.Expand(sha256.New, seed, nil) + okm := make([]byte, kdfSecretLength) + if _, err := io.ReadFull(rd, okm); err != nil { + return err + } + var err error conn.txCrypto, err = newCryptoState(okm[0:32], okm[32:40], okm[80:112]) if err != nil { @@ -463,7 +476,7 @@ func (conn *ssConn) padBurst(burst *bytes.Buffer, sampleLen int) error { // the ScrambleSuit MTU) is sampleLen bytes. dataLen := burst.Len() % maxSegmentLength - padLen := 0 + var padLen int if sampleLen >= dataLen { padLen = sampleLen - dataLen } else { @@ -481,12 +494,12 @@ func (conn *ssConn) padBurst(burst *bytes.Buffer, sampleLen int) error { if padLen > maxSegmentLength { // Note: packetmorpher.py: getPadding is slightly wrong and only // accounts for one of the two packet headers. - if err := conn.makePacket(burst, pktPayload, nil, 700-pktOverhead); err != nil { + if err := conn.makePayloadPacket(burst, nil, 700-pktOverhead); err != nil { return err } - return conn.makePacket(burst, pktPayload, nil, padLen-(700+2*pktOverhead)) + return conn.makePayloadPacket(burst, nil, padLen-(700+2*pktOverhead)) } - return conn.makePacket(burst, pktPayload, nil, padLen-pktOverhead) + return conn.makePayloadPacket(burst, nil, padLen-pktOverhead) } func newScrambleSuitClientConn(conn net.Conn, tStore *ssTicketStore, ca *ssClientArgs) (net.Conn, error) { diff --git a/transports/scramblesuit/handshake_ticket.go b/transports/scramblesuit/handshake_ticket.go index f415be0..a850dcd 100644 --- a/transports/scramblesuit/handshake_ticket.go +++ b/transports/scramblesuit/handshake_ticket.go @@ -34,7 +34,6 @@ import ( "errors" "fmt" "hash" - "io/ioutil" "net" "os" "path" @@ -56,9 +55,7 @@ const ( ticketMaxPadLength = 1388 ) -var ( - errInvalidTicket = errors.New("scramblesuit: invalid serialized ticket") -) +var errInvalidTicket = errors.New("scramblesuit: invalid serialized ticket") type ssTicketStore struct { sync.Mutex @@ -129,7 +126,7 @@ func (s *ssTicketStore) getTicket(addr net.Addr) (*ssTicket, error) { } // No ticket was found, that's fine. - return nil, nil + return nil, nil //nolint:nilnil } func (s *ssTicketStore) serialize() error { @@ -146,7 +143,7 @@ func (s *ssTicketStore) serialize() error { if err != nil { return err } - return ioutil.WriteFile(s.filePath, jsonStr, 0600) + return os.WriteFile(s.filePath, jsonStr, 0o600) } func loadTicketStore(stateDir string) (*ssTicketStore, error) { @@ -154,7 +151,7 @@ func loadTicketStore(stateDir string) (*ssTicketStore, error) { s := &ssTicketStore{filePath: fPath} s.store = make(map[string]*ssTicket) - f, err := ioutil.ReadFile(fPath) + f, err := os.ReadFile(fPath) if err != nil { // No ticket store is fine. if os.IsNotExist(err) { @@ -167,7 +164,7 @@ func loadTicketStore(stateDir string) (*ssTicketStore, error) { encMap := make(map[string]*ssTicketJSON) if err = json.Unmarshal(f, &encMap); err != nil { - return nil, fmt.Errorf("failed to load ticket store '%s': '%s'", fPath, err) + return nil, fmt.Errorf("failed to load ticket store '%s': %w", fPath, err) } for k, v := range encMap { raw, err := base32.StdEncoding.DecodeString(v.KeyTicket) diff --git a/transports/scramblesuit/hkdf_expand.go b/transports/scramblesuit/hkdf_expand.go deleted file mode 100644 index cbdcbb3..0000000 --- a/transports/scramblesuit/hkdf_expand.go +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Copyright (c) 2015, Yawning Angel - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * * Redistributions of source code must retain the above copyright notice, - * this list of conditions and the following disclaimer. - * - * * Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE - * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS - * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN - * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) - * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE - * POSSIBILITY OF SUCH DAMAGE. - */ - -package scramblesuit - -import ( - "crypto/hmac" - "hash" -) - -func hkdfExpand(hashFn func() hash.Hash, prk []byte, info []byte, l int) []byte { - // Why, yes. golang.org/x/crypto/hkdf exists, and is a fine - // implementation of HKDF. However it does both the extract - // and expand, while ScrambleSuit only does extract, with no - // way to separate the two steps. - - h := hmac.New(hashFn, prk) - digestSz := h.Size() - if l > 255*digestSz { - panic("hkdf: requested OKM length > 255*HashLen") - } - - var t []byte - okm := make([]byte, 0, l) - toAppend := l - ctr := byte(1) - for toAppend > 0 { - h.Reset() - _, _ = h.Write(t) - _, _ = h.Write(info) - _, _ = h.Write([]byte{ctr}) - t = h.Sum(nil) - ctr++ - - aLen := digestSz - if toAppend < digestSz { - aLen = toAppend - } - okm = append(okm, t[:aLen]...) - toAppend -= aLen - } - return okm -} diff --git a/transports/transports.go b/transports/transports.go index 2e83688..b4e6377 100644 --- a/transports/transports.go +++ b/transports/transports.go @@ -41,8 +41,10 @@ import ( "gitlab.com/yawning/obfs4.git/transports/scramblesuit" ) -var transportMapLock sync.Mutex -var transportMap map[string]base.Transport = make(map[string]base.Transport) +var ( + transportMapLock sync.Mutex + transportMap map[string]base.Transport = make(map[string]base.Transport) +) // Register registers a transport protocol. func Register(transport base.Transport) error { @@ -64,7 +66,7 @@ func Transports() []string { transportMapLock.Lock() defer transportMapLock.Unlock() - var ret []string + ret := make([]string, 0, len(transportMap)) for name := range transportMap { ret = append(ret, name) }