obfs4: Clean up and modernize the codebase

While the thought of dealing with this codebase makes me reach for the
Benzodiazepines, I might as well clean this up.
master
Yawning Angel 10 months ago
parent 645026c2ad
commit efdc692691

@ -0,0 +1,120 @@
linters:
disable-all: true
enable:
# Re-enable the default linters
- errcheck
- gosimple
- govet
- ineffassign
- staticcheck
- typecheck
- unused
# Enable the "always useful" linters as of 1.53.3
- asasalint
- asciicheck
- bidichk
- decorder
- dogsled
- dupl
- dupword
- errchkjson
- errname
- errorlint
- exhaustive
- exportloopref
- forbidigo
- forcetypeassert
- gci
- gocheckcompilerdirectives
- gochecknoinits
- goconst
- gocritic
- godot
- godox
- gofumpt
- gomoddirectives
- goprintffuncname
- gosec
- gosmopolitan
- importas
- interfacebloat
- makezero
- mirror
- misspell
- musttag
- nakedret
- nestif
- nilerr
- nilnil
- nolintlint
- nonamedreturns
- prealloc
- predeclared
- reassign
- revive
- tagalign
- tenv
- testableexamples
- unconvert
- unparam
- usestdlibvars
- wastedassign
- whitespace
# Disabled: Run periodically, but too many places to annotate
# - gomnd
# Disabled: Not how I do things
# - exhaustruct # Zero value is fine.
# - funlen # I'm not breaking up my math.
# - gochecknoglobals # How else am I supposed to declare constants.
# - lll # The 70s called and wants their ttys back.
# - paralleltest
# - varnamelen # The papers use short variable names.
# - tagliatelle # I want my tags to match the files.
# - thelper
# - tparallel
# - testpackage
# - wsl # Nice idea, not how I like to write code.
# - goerr113 # Nice idea, this package has too much legacy bs.
# - ireturn # By virtue of the PT API we are interface heavy.
# Disabled: Annoying/Useless
# - cyclop
# - gocognit
# - gocyclo
# - maintidx
# - wrapcheck
# Disabled: Irrelevant/redundant
# - bodyclose
# - containedctx
# - contextcheck
# - depguard
# - durationcheck
# - execinquery
# - ginkgolinter
# - gofmt
# - goheader
# - goimports
# - gomodguard
# - grouper
# - loggercheck
# - nlreturn
# - noctx
# - nosprintfhostport
# - promlinter
# - rowserrcheck
# - sqlclosecheck
# - stylecheck
# - zerologlint
linters-settings:
gci:
sections:
- standard
- default
- prefix(gitlab.com/yawning/obfs4.git)
skip-generated: true
custom-order: true

@ -45,7 +45,7 @@ var (
csRandSourceInstance csRandSource csRandSourceInstance csRandSource
// Rand is a math/rand instance backed by crypto/rand CSPRNG. // Rand is a math/rand instance backed by crypto/rand CSPRNG.
Rand = rand.New(csRandSourceInstance) Rand = rand.New(csRandSourceInstance) //nolint:gosec
) )
type csRandSource struct { type csRandSource struct {
@ -63,7 +63,7 @@ func (r csRandSource) Int63() int64 {
return int64(val) return int64(val)
} }
func (r csRandSource) Seed(seed int64) { func (r csRandSource) Seed(_ int64) {
// No-op. // No-op.
} }

@ -30,12 +30,14 @@
package drbg // import "gitlab.com/yawning/obfs4.git/common/drbg" package drbg // import "gitlab.com/yawning/obfs4.git/common/drbg"
import ( import (
"bytes"
"encoding/binary" "encoding/binary"
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"hash" "hash"
"github.com/dchest/siphash" "github.com/dchest/siphash"
"gitlab.com/yawning/obfs4.git/common/csrand" "gitlab.com/yawning/obfs4.git/common/csrand"
) )
@ -60,33 +62,33 @@ func (seed *Seed) Hex() string {
} }
// NewSeed returns a Seed initialized with the runtime CSPRNG. // NewSeed returns a Seed initialized with the runtime CSPRNG.
func NewSeed() (seed *Seed, err error) { func NewSeed() (*Seed, error) {
seed = new(Seed) seed := new(Seed)
if err = csrand.Bytes(seed.Bytes()[:]); err != nil { if err := csrand.Bytes(seed.Bytes()[:]); err != nil {
return nil, err return nil, err
} }
return return seed, nil
} }
// SeedFromBytes creates a Seed from the raw bytes, truncating to SeedLength as // SeedFromBytes creates a Seed from the raw bytes, truncating to SeedLength as
// appropriate. // appropriate.
func SeedFromBytes(src []byte) (seed *Seed, err error) { func SeedFromBytes(src []byte) (*Seed, error) {
if len(src) < SeedLength { if len(src) < SeedLength {
return nil, InvalidSeedLengthError(len(src)) return nil, InvalidSeedLengthError(len(src))
} }
seed = new(Seed) seed := new(Seed)
copy(seed.Bytes()[:], src) copy(seed.Bytes()[:], src)
return return seed, nil
} }
// SeedFromHex creates a Seed from the hexdecimal representation, truncating to // SeedFromHex creates a Seed from the hexdecimal representation, truncating to
// SeedLength as appropriate. // SeedLength as appropriate.
func SeedFromHex(encoded string) (seed *Seed, err error) { func SeedFromHex(encoded string) (*Seed, error) {
var raw []byte raw, err := hex.DecodeString(encoded)
if raw, err = hex.DecodeString(encoded); err != nil { if err != nil {
return nil, err return nil, err
} }
@ -133,7 +135,7 @@ func (drbg *HashDrbg) Int63() int64 {
} }
// Seed does nothing, call NewHashDrbg if you want to reseed. // Seed does nothing, call NewHashDrbg if you want to reseed.
func (drbg *HashDrbg) Seed(seed int64) { func (drbg *HashDrbg) Seed(_ int64) {
// No-op. // No-op.
} }
@ -142,7 +144,5 @@ func (drbg *HashDrbg) NextBlock() []byte {
_, _ = drbg.sip.Write(drbg.ofb[:]) _, _ = drbg.sip.Write(drbg.ofb[:])
copy(drbg.ofb[:], drbg.sip.Sum(nil)) copy(drbg.ofb[:], drbg.sip.Sum(nil))
ret := make([]byte, Size) return bytes.Clone(drbg.ofb[:])
copy(ret, drbg.ofb[:])
return ret
} }

@ -30,8 +30,9 @@
package log // import "gitlab.com/yawning/obfs4.git/common/log" package log // import "gitlab.com/yawning/obfs4.git/common/log"
import ( import (
"errors"
"fmt" "fmt"
"io/ioutil" "io"
"log" "log"
"net" "net"
"os" "os"
@ -54,20 +55,22 @@ const (
LevelDebug LevelDebug
) )
var logLevel = LevelInfo var (
var enableLogging bool logLevel = LevelInfo
var unsafeLogging bool enableLogging bool
unsafeLogging bool
)
// Init initializes logging with the given path, and log safety options. // Init initializes logging with the given path, and log safety options.
func Init(enable bool, logFilePath string, unsafe bool) error { func Init(enable bool, logFilePath string, unsafe bool) error {
if enable { if enable {
f, err := os.OpenFile(logFilePath, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0600) f, err := os.OpenFile(logFilePath, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o600)
if err != nil { if err != nil {
return err return err
} }
log.SetOutput(f) log.SetOutput(f)
} else { } else {
log.SetOutput(ioutil.Discard) log.SetOutput(io.Discard)
} }
enableLogging = enable enableLogging = enable
unsafeLogging = unsafe unsafeLogging = unsafe
@ -163,8 +166,8 @@ func ElideError(err error) string {
// If err is not a net.Error, just return the string representation, // If err is not a net.Error, just return the string representation,
// presumably transport authors know what they are doing. // presumably transport authors know what they are doing.
netErr, ok := err.(net.Error) var netErr net.Error
if !ok { if !errors.As(err, &netErr) {
return err.Error() return err.Error()
} }

@ -70,15 +70,17 @@ const (
// KeySeedLength is the length of the derived KEY_SEED. // KeySeedLength is the length of the derived KEY_SEED.
KeySeedLength = sha256.Size KeySeedLength = sha256.Size
// AuthLength is the lenght of the derived AUTH. // AuthLength is the length of the derived AUTH.
AuthLength = sha256.Size AuthLength = sha256.Size
) )
var protoID = []byte("ntor-curve25519-sha256-1") var (
var tMac = append(protoID, []byte(":mac")...) protoID = []byte("ntor-curve25519-sha256-1")
var tKey = append(protoID, []byte(":key_extract")...) tMac = append(protoID, []byte(":mac")...)
var tVerify = append(protoID, []byte(":key_verify")...) tKey = append(protoID, []byte(":key_extract")...)
var mExpand = append(protoID, []byte(":key_expand")...) tVerify = append(protoID, []byte(":key_verify")...)
mExpand = append(protoID, []byte(":key_expand")...)
)
// PublicKeyLengthError is the error returned when the public key being // PublicKeyLengthError is the error returned when the public key being
// imported is an invalid length. // imported is an invalid length.
@ -320,48 +322,43 @@ func KeypairFromHex(encoded string) (*Keypair, error) {
// ServerHandshake does the server side of a ntor handshake and returns status, // ServerHandshake does the server side of a ntor handshake and returns status,
// KEY_SEED, and AUTH. If status is not true, the handshake MUST be aborted. // KEY_SEED, and AUTH. If status is not true, the handshake MUST be aborted.
func ServerHandshake(clientPublic *PublicKey, serverKeypair *Keypair, idKeypair *Keypair, id *NodeID) (ok bool, keySeed *KeySeed, auth *Auth) { func ServerHandshake(clientPublic *PublicKey, serverKeypair *Keypair, idKeypair *Keypair, id *NodeID) (bool, *KeySeed, *Auth) {
var notOk int var notOk int
var secretInput bytes.Buffer var secretInput bytes.Buffer
// Server side uses EXP(X,y) | EXP(X,b) // Server side uses EXP(X,y) | EXP(X,b)
var exp [SharedSecretLength]byte var exp [SharedSecretLength]byte
curve25519.ScalarMult(&exp, serverKeypair.private.Bytes(), curve25519.ScalarMult(&exp, serverKeypair.private.Bytes(), clientPublic.Bytes()) //nolint:staticcheck
clientPublic.Bytes())
notOk |= constantTimeIsZero(exp[:]) notOk |= constantTimeIsZero(exp[:])
secretInput.Write(exp[:]) secretInput.Write(exp[:])
curve25519.ScalarMult(&exp, idKeypair.private.Bytes(), curve25519.ScalarMult(&exp, idKeypair.private.Bytes(), clientPublic.Bytes()) //nolint:staticcheck
clientPublic.Bytes())
notOk |= constantTimeIsZero(exp[:]) notOk |= constantTimeIsZero(exp[:])
secretInput.Write(exp[:]) secretInput.Write(exp[:])
keySeed, auth = ntorCommon(secretInput, id, idKeypair.public, keySeed, auth := ntorCommon(secretInput, id, idKeypair.public,
clientPublic, serverKeypair.public) clientPublic, serverKeypair.public)
return notOk == 0, keySeed, auth return notOk == 0, keySeed, auth
} }
// ClientHandshake does the client side of a ntor handshake and returnes // ClientHandshake does the client side of a ntor handshake and returnes
// status, KEY_SEED, and AUTH. If status is not true or AUTH does not match // status, KEY_SEED, and AUTH. If status is not true or AUTH does not match
// the value recieved from the server, the handshake MUST be aborted. // the value received from the server, the handshake MUST be aborted.
func ClientHandshake(clientKeypair *Keypair, serverPublic *PublicKey, idPublic *PublicKey, id *NodeID) (ok bool, keySeed *KeySeed, auth *Auth) { func ClientHandshake(clientKeypair *Keypair, serverPublic *PublicKey, idPublic *PublicKey, id *NodeID) (bool, *KeySeed, *Auth) {
var notOk int var notOk int
var secretInput bytes.Buffer var secretInput bytes.Buffer
// Client side uses EXP(Y,x) | EXP(B,x) // Client side uses EXP(Y,x) | EXP(B,x)
var exp [SharedSecretLength]byte var exp [SharedSecretLength]byte
curve25519.ScalarMult(&exp, clientKeypair.private.Bytes(), curve25519.ScalarMult(&exp, clientKeypair.private.Bytes(), serverPublic.Bytes()) //nolint:staticcheck
serverPublic.Bytes())
notOk |= constantTimeIsZero(exp[:]) notOk |= constantTimeIsZero(exp[:])
secretInput.Write(exp[:]) secretInput.Write(exp[:])
curve25519.ScalarMult(&exp, clientKeypair.private.Bytes(), curve25519.ScalarMult(&exp, clientKeypair.private.Bytes(), idPublic.Bytes()) //nolint:staticcheck
idPublic.Bytes())
notOk |= constantTimeIsZero(exp[:]) notOk |= constantTimeIsZero(exp[:])
secretInput.Write(exp[:]) secretInput.Write(exp[:])
keySeed, auth = ntorCommon(secretInput, id, idPublic, keySeed, auth := ntorCommon(secretInput, id, idPublic, clientKeypair.public, serverPublic)
clientKeypair.public, serverPublic)
return notOk == 0, keySeed, auth return notOk == 0, keySeed, auth
} }
@ -402,7 +399,7 @@ func ntorCommon(secretInput bytes.Buffer, id *NodeID, b *PublicKey, x *PublicKey
// auth_input = verify | ID | B | Y | X | PROTOID | "Server" // auth_input = verify | ID | B | Y | X | PROTOID | "Server"
authInput := bytes.NewBuffer(verify) authInput := bytes.NewBuffer(verify)
_, _ = authInput.Write(suffix.Bytes()) _, _ = authInput.Write(suffix.Bytes())
_, _ = authInput.Write([]byte("Server")) _, _ = authInput.WriteString("Server")
h = hmac.New(sha256.New, tMac) h = hmac.New(sha256.New, tMac)
_, _ = h.Write(authInput.Bytes()) _, _ = h.Write(authInput.Bytes())
tmp = h.Sum(nil) tmp = h.Sum(nil)

@ -64,8 +64,8 @@ type WeightedDist struct {
// based on a HashDrbg initialized with seed. Optionally, bias the weight // based on a HashDrbg initialized with seed. Optionally, bias the weight
// generation to match the ScrambleSuit non-uniform distribution from // generation to match the ScrambleSuit non-uniform distribution from
// obfsproxy. // obfsproxy.
func New(seed *drbg.Seed, min, max int, biased bool) (w *WeightedDist) { func New(seed *drbg.Seed, min, max int, biased bool) *WeightedDist {
w = &WeightedDist{minValue: min, maxValue: max, biased: biased} w := &WeightedDist{minValue: min, maxValue: max, biased: biased}
if max <= min { if max <= min {
panic(fmt.Sprintf("wDist.Reset(): min >= max (%d, %d)", min, max)) panic(fmt.Sprintf("wDist.Reset(): min >= max (%d, %d)", min, max))
@ -73,7 +73,7 @@ func New(seed *drbg.Seed, min, max int, biased bool) (w *WeightedDist) {
w.Reset(seed) w.Reset(seed)
return return w
} }
// genValues creates a slice containing a random number of random values // genValues creates a slice containing a random number of random values
@ -132,7 +132,7 @@ func (w *WeightedDist) genTables() {
scaled := make([]float64, n) scaled := make([]float64, n)
for i, weight := range w.weights { for i, weight := range w.weights {
// Multiply each probability by $n$. // Multiply each probability by $n$.
p_i := weight * float64(n) / sum p_i := weight * float64(n) / sum //nolint:revive
scaled[i] = p_i scaled[i] = p_i
// For each scaled probability $p_i$: // For each scaled probability $p_i$:
@ -148,9 +148,9 @@ func (w *WeightedDist) genTables() {
// While $Small$ and $Large$ are not empty: ($Large$ might be emptied first) // While $Small$ and $Large$ are not empty: ($Large$ might be emptied first)
for small.Len() > 0 && large.Len() > 0 { for small.Len() > 0 && large.Len() > 0 {
// Remove the first element from $Small$; call it $l$. // Remove the first element from $Small$; call it $l$.
l := small.Remove(small.Front()).(int) l, _ := small.Remove(small.Front()).(int)
// Remove the first element from $Large$; call it $g$. // Remove the first element from $Large$; call it $g$.
g := large.Remove(large.Front()).(int) g, _ := large.Remove(large.Front()).(int)
// Set $Prob[l] = p_l$. // Set $Prob[l] = p_l$.
prob[l] = scaled[l] prob[l] = scaled[l]
@ -172,7 +172,7 @@ func (w *WeightedDist) genTables() {
// While $Large$ is not empty: // While $Large$ is not empty:
for large.Len() > 0 { for large.Len() > 0 {
// Remove the first element from $Large$; call it $g$. // Remove the first element from $Large$; call it $g$.
g := large.Remove(large.Front()).(int) g, _ := large.Remove(large.Front()).(int)
// Set $Prob[g] = 1$. // Set $Prob[g] = 1$.
prob[g] = 1.0 prob[g] = 1.0
} }
@ -180,7 +180,7 @@ func (w *WeightedDist) genTables() {
// While $Small$ is not empty: This is only possible due to numerical instability. // While $Small$ is not empty: This is only possible due to numerical instability.
for small.Len() > 0 { for small.Len() > 0 {
// Remove the first element from $Small$; call it $l$. // Remove the first element from $Small$; call it $l$.
l := small.Remove(small.Front()).(int) l, _ := small.Remove(small.Front()).(int)
// Set $Prob[l] = 1$. // Set $Prob[l] = 1$.
prob[l] = 1.0 prob[l] = 1.0
} }
@ -194,7 +194,7 @@ func (w *WeightedDist) genTables() {
func (w *WeightedDist) Reset(seed *drbg.Seed) { func (w *WeightedDist) Reset(seed *drbg.Seed) {
// Initialize the deterministic random number generator. // Initialize the deterministic random number generator.
drbg, _ := drbg.NewHashDrbg(seed) drbg, _ := drbg.NewHashDrbg(seed)
rng := rand.New(drbg) rng := rand.New(drbg) //nolint:gosec
w.Lock() w.Lock()
defer w.Unlock() defer w.Unlock()

@ -28,7 +28,6 @@
package probdist package probdist
import ( import (
"fmt"
"testing" "testing"
"gitlab.com/yawning/obfs4.git/common/drbg" "gitlab.com/yawning/obfs4.git/common/drbg"
@ -49,7 +48,7 @@ func TestWeightedDist(t *testing.T) {
w := New(seed, 0, 999, true) w := New(seed, 0, 999, true)
if debug { if debug {
// Dump a string representation of the probability table. // Dump a string representation of the probability table.
fmt.Println("Table:") t.Logf("Table:")
var sum float64 var sum float64
for _, weight := range w.weights { for _, weight := range w.weights {
sum += weight sum += weight
@ -57,10 +56,9 @@ func TestWeightedDist(t *testing.T) {
for i, weight := range w.weights { for i, weight := range w.weights {
p := weight / sum p := weight / sum
if p > 0.000001 { // Filter out tiny values. if p > 0.000001 { // Filter out tiny values.
fmt.Printf(" [%d]: %f\n", w.minValue+w.values[i], p) t.Logf(" [%d]: %f", w.minValue+w.values[i], p)
} }
} }
fmt.Println()
} }
for i := 0; i < nrTrials; i++ { for i := 0; i < nrTrials; i++ {
@ -69,11 +67,11 @@ func TestWeightedDist(t *testing.T) {
} }
if debug { if debug {
fmt.Println("Generated:") t.Logf("Generated:")
for value, count := range hist { for value, count := range hist {
if count != 0 { if count != 0 {
p := float64(count) / float64(nrTrials) p := float64(count) / float64(nrTrials)
fmt.Printf(" [%d]: %f (%d)\n", value, p, count) t.Logf(" [%d]: %f (%d)", value, p, count)
} }
} }
} }

@ -39,6 +39,7 @@ import (
"time" "time"
"github.com/dchest/siphash" "github.com/dchest/siphash"
"gitlab.com/yawning/obfs4.git/common/csrand" "gitlab.com/yawning/obfs4.git/common/csrand"
) )
@ -67,21 +68,21 @@ type ReplayFilter struct {
} }
// New creates a new ReplayFilter instance. // New creates a new ReplayFilter instance.
func New(ttl time.Duration) (filter *ReplayFilter, err error) { func New(ttl time.Duration) (*ReplayFilter, error) {
// Initialize the SipHash-2-4 instance with a random key. // Initialize the SipHash-2-4 instance with a random key.
var key [16]byte var key [16]byte
if err = csrand.Bytes(key[:]); err != nil { if err := csrand.Bytes(key[:]); err != nil {
return return nil, err
} }
filter = new(ReplayFilter) filter := new(ReplayFilter)
filter.filter = make(map[uint64]*entry) filter.filter = make(map[uint64]*entry)
filter.fifo = list.New() filter.fifo = list.New()
filter.key[0] = binary.BigEndian.Uint64(key[0:8]) filter.key[0] = binary.BigEndian.Uint64(key[0:8])
filter.key[1] = binary.BigEndian.Uint64(key[8:16]) filter.key[1] = binary.BigEndian.Uint64(key[8:16])
filter.ttl = ttl filter.ttl = ttl
return return filter, nil
} }
// TestAndSet queries the filter for a given byte sequence, inserts the // TestAndSet queries the filter for a given byte sequence, inserts the

@ -29,7 +29,8 @@ package socks5
import ( import (
"fmt" "fmt"
"git.torproject.org/pluggable-transports/goptlib.git"
"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/goptlib"
) )
// parseClientParameters takes a client parameter string formatted according to // parseClientParameters takes a client parameter string formatted according to
@ -37,14 +38,14 @@ import (
// specification, and returns it as a goptlib Args structure. // specification, and returns it as a goptlib Args structure.
// //
// This is functionally identical to the equivalently named goptlib routine. // This is functionally identical to the equivalently named goptlib routine.
func parseClientParameters(argStr string) (args pt.Args, err error) { func parseClientParameters(argStr string) (pt.Args, error) {
args = make(pt.Args) args := make(pt.Args)
if len(argStr) == 0 { if len(argStr) == 0 {
return return args, nil
} }
var key string var key string
var acc []byte acc := make([]byte, 0, len(argStr))
prevIsEscape := false prevIsEscape := false
for idx, ch := range []byte(argStr) { for idx, ch := range []byte(argStr) {
switch ch { switch ch {

@ -5,7 +5,7 @@ package socks5
import ( import (
"testing" "testing"
"git.torproject.org/pluggable-transports/goptlib.git" "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/goptlib"
) )
func stringSlicesEqual(a, b []string) bool { func stringSlicesEqual(a, b []string) bool {

@ -35,12 +35,12 @@ const (
authRFC1929Fail = 0x01 authRFC1929Fail = 0x01
) )
func (req *Request) authRFC1929() (err error) { func (req *Request) authRFC1929() error {
sendErrResp := func() { sendErrResp := func(err error) error {
// Swallow write/flush errors, the auth failure is the relevant error. // Swallow write/flush errors, the auth failure is the relevant error.
resp := []byte{authRFC1929Ver, authRFC1929Fail} _, _ = req.rw.Write([]byte{authRFC1929Ver, authRFC1929Fail})
_, _ = req.rw.Write(resp[:])
_ = req.flushBuffers() _ = req.flushBuffers()
return err // Pass this through from the arg.
} }
// The client sends a Username/Password request. // The client sends a Username/Password request.
@ -50,39 +50,35 @@ func (req *Request) authRFC1929() (err error) {
// uint8_t plen (>= 1) // uint8_t plen (>= 1)
// uint8_t passwd[plen] // uint8_t passwd[plen]
if err = req.readByteVerify("auth version", authRFC1929Ver); err != nil { if err := req.readByteVerify("auth version", authRFC1929Ver); err != nil {
sendErrResp() return sendErrResp(err)
return
} }
// Read the username. // Read the username.
var ulen byte var (
ulen byte
err error
)
if ulen, err = req.readByte(); err != nil { if ulen, err = req.readByte(); err != nil {
sendErrResp() return sendErrResp(err)
return
} else if ulen < 1 { } else if ulen < 1 {
sendErrResp() return sendErrResp(fmt.Errorf("username with 0 length"))
return fmt.Errorf("username with 0 length")
} }
var uname []byte var uname []byte
if uname, err = req.readBytes(int(ulen)); err != nil { if uname, err = req.readBytes(int(ulen)); err != nil {
sendErrResp() return sendErrResp(err)
return
} }
// Read the password. // Read the password.
var plen byte var plen byte
if plen, err = req.readByte(); err != nil { if plen, err = req.readByte(); err != nil {
sendErrResp() return sendErrResp(err)
return
} else if plen < 1 { } else if plen < 1 {
sendErrResp() return sendErrResp(fmt.Errorf("password with 0 length"))
return fmt.Errorf("password with 0 length")
} }
var passwd []byte var passwd []byte
if passwd, err = req.readBytes(int(plen)); err != nil { if passwd, err = req.readBytes(int(plen)); err != nil {
sendErrResp() return sendErrResp(err)
return
} }
// Pluggable transports use the username/password field to pass // Pluggable transports use the username/password field to pass
@ -95,11 +91,10 @@ func (req *Request) authRFC1929() (err error) {
argStr += string(passwd) argStr += string(passwd)
} }
if req.Args, err = parseClientParameters(argStr); err != nil { if req.Args, err = parseClientParameters(argStr); err != nil {
sendErrResp() return sendErrResp(err)
return
} }
resp := []byte{authRFC1929Ver, authRFC1929Success} resp := []byte{authRFC1929Ver, authRFC1929Success}
_, err = req.rw.Write(resp[:]) _, err = req.rw.Write(resp)
return return err
} }

@ -30,23 +30,24 @@
// 1929. // 1929.
// //
// Notes: // Notes:
// * GSSAPI authentication, is NOT supported. // - GSSAPI authentication, is NOT supported.
// * Only the CONNECT command is supported. // - Only the CONNECT command is supported.
// * The authentication provided by the client is always accepted as it is // - The authentication provided by the client is always accepted as it is
// used as a channel to pass information rather than for authentication for // used as a channel to pass information rather than for authentication for
// pluggable transports. // pluggable transports.
package socks5 // import "gitlab.com/yawning/obfs4.git/common/socks5" package socks5 // import "gitlab.com/yawning/obfs4.git/common/socks5"
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"errors"
"fmt" "fmt"
"io" "io"
"net" "net"
"syscall" "syscall"
"time" "time"
"git.torproject.org/pluggable-transports/goptlib.git" "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/goptlib"
) )
const ( const (
@ -89,16 +90,16 @@ func Version() string {
// ErrorToReplyCode converts an error to the "best" reply code. // ErrorToReplyCode converts an error to the "best" reply code.
func ErrorToReplyCode(err error) ReplyCode { func ErrorToReplyCode(err error) ReplyCode {
opErr, ok := err.(*net.OpError) var opErr *net.OpError
if !ok { if !errors.As(err, &opErr) {
return ReplyGeneralFailure return ReplyGeneralFailure
} }
errno, ok := opErr.Err.(syscall.Errno) var errno syscall.Errno
if !ok { if !errors.As(opErr.Err, &errno) {
return ReplyGeneralFailure return ReplyGeneralFailure
} }
switch errno { switch errno { //nolint:exhaustive
case syscall.EADDRNOTAVAIL: case syscall.EADDRNOTAVAIL:
return ReplyAddressNotSupported return ReplyAddressNotSupported
case syscall.ETIMEDOUT: case syscall.ETIMEDOUT:
@ -307,7 +308,7 @@ func (req *Request) readCommand() error {
return err return err
} }
addr := make(net.IP, net.IPv6len) addr := make(net.IP, net.IPv6len)
copy(addr[:], rawAddr[:]) copy(addr[:], rawAddr)
host = fmt.Sprintf("[%s]", addr.String()) host = fmt.Sprintf("[%s]", addr.String())
default: default:
_ = req.Reply(ReplyAddressNotSupported) _ = req.Reply(ReplyAddressNotSupported)

@ -48,11 +48,11 @@ type testReadWriter struct {
writeBuf bytes.Buffer writeBuf bytes.Buffer
} }
func (c *testReadWriter) Read(buf []byte) (n int, err error) { func (c *testReadWriter) Read(buf []byte) (int, error) {
return c.readBuf.Read(buf) return c.readBuf.Read(buf)
} }
func (c *testReadWriter) Write(buf []byte) (n int, err error) { func (c *testReadWriter) Write(buf []byte) (int, error) {
return c.writeBuf.Write(buf) return c.writeBuf.Write(buf)
} }
@ -96,11 +96,11 @@ func TestAuthInvalidVersion(t *testing.T) {
// VER = 03, NMETHODS = 01, METHODS = [00] // VER = 03, NMETHODS = 01, METHODS = [00]
c.writeHex("030100") c.writeHex("030100")
if _, err := req.negotiateAuth(); err == nil { if _, err := req.negotiateAuth(); err == nil {
t.Error("negotiateAuth(InvalidVersion) succeded") t.Error("negotiateAuth(InvalidVersion) succeeded")
} }
} }
// TestAuthInvalidNMethods tests auth negotiaton with no methods. // TestAuthInvalidNMethods tests auth negotiation with no methods.
func TestAuthInvalidNMethods(t *testing.T) { func TestAuthInvalidNMethods(t *testing.T) {
c := new(testReadWriter) c := new(testReadWriter)
req := c.toRequest() req := c.toRequest()
@ -120,7 +120,7 @@ func TestAuthInvalidNMethods(t *testing.T) {
} }
} }
// TestAuthNoneRequired tests auth negotiaton with NO AUTHENTICATION REQUIRED. // TestAuthNoneRequired tests auth negotiation with NO AUTHENTICATION REQUIRED.
func TestAuthNoneRequired(t *testing.T) { func TestAuthNoneRequired(t *testing.T) {
c := new(testReadWriter) c := new(testReadWriter)
req := c.toRequest() req := c.toRequest()
@ -230,7 +230,7 @@ func TestRFC1929InvalidVersion(t *testing.T) {
// VER = 03, ULEN = 5, UNAME = "ABCDE", PLEN = 5, PASSWD = "abcde" // VER = 03, ULEN = 5, UNAME = "ABCDE", PLEN = 5, PASSWD = "abcde"
c.writeHex("03054142434445056162636465") c.writeHex("03054142434445056162636465")
if err := req.authenticate(authUsernamePassword); err == nil { if err := req.authenticate(authUsernamePassword); err == nil {
t.Error("authenticate(InvalidVersion) succeded") t.Error("authenticate(InvalidVersion) succeeded")
} }
if msg := c.readHex(); msg != "0101" { if msg := c.readHex(); msg != "0101" {
t.Error("authenticate(InvalidVersion) invalid response:", msg) t.Error("authenticate(InvalidVersion) invalid response:", msg)
@ -245,7 +245,7 @@ func TestRFC1929InvalidUlen(t *testing.T) {
// VER = 01, ULEN = 0, UNAME = "", PLEN = 5, PASSWD = "abcde" // VER = 01, ULEN = 0, UNAME = "", PLEN = 5, PASSWD = "abcde"
c.writeHex("0100056162636465") c.writeHex("0100056162636465")
if err := req.authenticate(authUsernamePassword); err == nil { if err := req.authenticate(authUsernamePassword); err == nil {
t.Error("authenticate(InvalidUlen) succeded") t.Error("authenticate(InvalidUlen) succeeded")
} }
if msg := c.readHex(); msg != "0101" { if msg := c.readHex(); msg != "0101" {
t.Error("authenticate(InvalidUlen) invalid response:", msg) t.Error("authenticate(InvalidUlen) invalid response:", msg)
@ -260,7 +260,7 @@ func TestRFC1929InvalidPlen(t *testing.T) {
// VER = 01, ULEN = 5, UNAME = "ABCDE", PLEN = 0, PASSWD = "" // VER = 01, ULEN = 5, UNAME = "ABCDE", PLEN = 0, PASSWD = ""
c.writeHex("0105414243444500") c.writeHex("0105414243444500")
if err := req.authenticate(authUsernamePassword); err == nil { if err := req.authenticate(authUsernamePassword); err == nil {
t.Error("authenticate(InvalidPlen) succeded") t.Error("authenticate(InvalidPlen) succeeded")
} }
if msg := c.readHex(); msg != "0101" { if msg := c.readHex(); msg != "0101" {
t.Error("authenticate(InvalidPlen) invalid response:", msg) t.Error("authenticate(InvalidPlen) invalid response:", msg)
@ -275,7 +275,7 @@ func TestRFC1929InvalidPTArgs(t *testing.T) {
// VER = 01, ULEN = 5, UNAME = "ABCDE", PLEN = 5, PASSWD = "abcde" // VER = 01, ULEN = 5, UNAME = "ABCDE", PLEN = 5, PASSWD = "abcde"
c.writeHex("01054142434445056162636465") c.writeHex("01054142434445056162636465")
if err := req.authenticate(authUsernamePassword); err == nil { if err := req.authenticate(authUsernamePassword); err == nil {
t.Error("authenticate(InvalidArgs) succeded") t.Error("authenticate(InvalidArgs) succeeded")
} }
if msg := c.readHex(); msg != "0101" { if msg := c.readHex(); msg != "0101" {
t.Error("authenticate(InvalidArgs) invalid response:", msg) t.Error("authenticate(InvalidArgs) invalid response:", msg)
@ -301,7 +301,7 @@ func TestRFC1929Success(t *testing.T) {
} }
} }
// TestRequestInvalidHdr tests SOCKS5 requests with invalid VER/CMD/RSV/ATYPE // TestRequestInvalidHdr tests SOCKS5 requests with invalid VER/CMD/RSV/ATYPE.
func TestRequestInvalidHdr(t *testing.T) { func TestRequestInvalidHdr(t *testing.T) {
c := new(testReadWriter) c := new(testReadWriter)
req := c.toRequest() req := c.toRequest()
@ -309,7 +309,7 @@ func TestRequestInvalidHdr(t *testing.T) {
// VER = 03, CMD = 01, RSV = 00, ATYPE = 01, DST.ADDR = 127.0.0.1, DST.PORT = 9050 // VER = 03, CMD = 01, RSV = 00, ATYPE = 01, DST.ADDR = 127.0.0.1, DST.PORT = 9050
c.writeHex("030100017f000001235a") c.writeHex("030100017f000001235a")
if err := req.readCommand(); err == nil { if err := req.readCommand(); err == nil {
t.Error("readCommand(InvalidVer) succeded") t.Error("readCommand(InvalidVer) succeeded")
} }
if msg := c.readHex(); msg != "05010001000000000000" { if msg := c.readHex(); msg != "05010001000000000000" {
t.Error("readCommand(InvalidVer) invalid response:", msg) t.Error("readCommand(InvalidVer) invalid response:", msg)
@ -319,7 +319,7 @@ func TestRequestInvalidHdr(t *testing.T) {
// VER = 05, CMD = 05, RSV = 00, ATYPE = 01, DST.ADDR = 127.0.0.1, DST.PORT = 9050 // VER = 05, CMD = 05, RSV = 00, ATYPE = 01, DST.ADDR = 127.0.0.1, DST.PORT = 9050
c.writeHex("050500017f000001235a") c.writeHex("050500017f000001235a")
if err := req.readCommand(); err == nil { if err := req.readCommand(); err == nil {
t.Error("readCommand(InvalidCmd) succeded") t.Error("readCommand(InvalidCmd) succeeded")
} }
if msg := c.readHex(); msg != "05070001000000000000" { if msg := c.readHex(); msg != "05070001000000000000" {
t.Error("readCommand(InvalidCmd) invalid response:", msg) t.Error("readCommand(InvalidCmd) invalid response:", msg)
@ -329,7 +329,7 @@ func TestRequestInvalidHdr(t *testing.T) {
// VER = 05, CMD = 01, RSV = 30, ATYPE = 01, DST.ADDR = 127.0.0.1, DST.PORT = 9050 // VER = 05, CMD = 01, RSV = 30, ATYPE = 01, DST.ADDR = 127.0.0.1, DST.PORT = 9050
c.writeHex("050130017f000001235a") c.writeHex("050130017f000001235a")
if err := req.readCommand(); err == nil { if err := req.readCommand(); err == nil {
t.Error("readCommand(InvalidRsv) succeded") t.Error("readCommand(InvalidRsv) succeeded")
} }
if msg := c.readHex(); msg != "05010001000000000000" { if msg := c.readHex(); msg != "05010001000000000000" {
t.Error("readCommand(InvalidRsv) invalid response:", msg) t.Error("readCommand(InvalidRsv) invalid response:", msg)
@ -339,7 +339,7 @@ func TestRequestInvalidHdr(t *testing.T) {
// VER = 05, CMD = 01, RSV = 01, ATYPE = 05, DST.ADDR = 127.0.0.1, DST.PORT = 9050 // VER = 05, CMD = 01, RSV = 01, ATYPE = 05, DST.ADDR = 127.0.0.1, DST.PORT = 9050
c.writeHex("050100057f000001235a") c.writeHex("050100057f000001235a")
if err := req.readCommand(); err == nil { if err := req.readCommand(); err == nil {
t.Error("readCommand(InvalidAtype) succeded") t.Error("readCommand(InvalidAtype) succeeded")
} }
if msg := c.readHex(); msg != "05080001000000000000" { if msg := c.readHex(); msg != "05080001000000000000" {
t.Error("readCommand(InvalidAtype) invalid response:", msg) t.Error("readCommand(InvalidAtype) invalid response:", msg)

@ -32,6 +32,7 @@
package uniformdh // import "gitlab.com/yawning/obfs4.git/common/uniformdh" package uniformdh // import "gitlab.com/yawning/obfs4.git/common/uniformdh"
import ( import (
"bytes"
"fmt" "fmt"
"io" "io"
"math/big" "math/big"
@ -54,8 +55,16 @@ const (
g = 2 g = 2
) )
var modpGroup *big.Int var (
var gen *big.Int modpGroup = func() *big.Int {
n, ok := new(big.Int).SetString(modpStr, 16)
if !ok {
panic("Failed to load the RFC3526 MODP Group")
}
return n
}()
gen = big.NewInt(g)
)
// A PrivateKey represents a UniformDH private key. // A PrivateKey represents a UniformDH private key.
type PrivateKey struct { type PrivateKey struct {
@ -70,14 +79,11 @@ type PublicKey struct {
} }
// Bytes returns the byte representation of a PublicKey. // Bytes returns the byte representation of a PublicKey.
func (pub *PublicKey) Bytes() (pubBytes []byte, err error) { func (pub *PublicKey) Bytes() ([]byte, error) {
if len(pub.bytes) != Size || pub.bytes == nil { if len(pub.bytes) != Size || pub.bytes == nil {
return nil, fmt.Errorf("public key is not initialized") return nil, fmt.Errorf("public key is not initialized")
} }
pubBytes = make([]byte, Size) return bytes.Clone(pub.bytes), nil
copy(pubBytes, pub.bytes)
return
} }
// SetBytes sets the PublicKey from a byte slice. // SetBytes sets the PublicKey from a byte slice.
@ -85,25 +91,22 @@ func (pub *PublicKey) SetBytes(pubBytes []byte) error {
if len(pubBytes) != Size { if len(pubBytes) != Size {
return fmt.Errorf("public key length %d is not %d", len(pubBytes), Size) return fmt.Errorf("public key length %d is not %d", len(pubBytes), Size)
} }
pub.bytes = make([]byte, Size) pub.bytes = bytes.Clone(pubBytes)
copy(pub.bytes, pubBytes)
pub.publicKey = new(big.Int).SetBytes(pub.bytes) pub.publicKey = new(big.Int).SetBytes(pub.bytes)
return nil return nil
} }
// GenerateKey generates a UniformDH keypair using the random source random. // GenerateKey generates a UniformDH keypair using the random source random.
func GenerateKey(random io.Reader) (priv *PrivateKey, err error) { func GenerateKey(random io.Reader) (*PrivateKey, error) {
privBytes := make([]byte, Size) var privBytes [Size]byte
if _, err = io.ReadFull(random, privBytes); err != nil { if _, err := io.ReadFull(random, privBytes[:]); err != nil {
return return nil, err
} }
priv, err = generateKey(privBytes) return generateKey(privBytes[:])
return
} }
func generateKey(privBytes []byte) (priv *PrivateKey, err error) { func generateKey(privBytes []byte) (*PrivateKey, error) {
// This function does all of the actual heavy lifting of creating a public // This function does all of the actual heavy lifting of creating a public
// key from a raw 192 byte private key. It is split so that the KAT tests // key from a raw 192 byte private key. It is split so that the KAT tests
// can be written easily, and not exposed since non-ephemeral keys are a // can be written easily, and not exposed since non-ephemeral keys are a
@ -132,52 +135,26 @@ func generateKey(privBytes []byte) (priv *PrivateKey, err error) {
// to the key so that it is always exactly Size bytes. // to the key so that it is always exactly Size bytes.
pubBytes := make([]byte, Size) pubBytes := make([]byte, Size)
if wasEven { if wasEven {
err = prependZeroBytes(pubBytes, pubBn.Bytes()) pubBn.FillBytes(pubBytes)
} else { } else {
err = prependZeroBytes(pubBytes, pubAlt.Bytes()) pubAlt.FillBytes(pubBytes)
}
if err != nil {
return
} }
priv = new(PrivateKey) priv := new(PrivateKey)
priv.PublicKey.bytes = pubBytes priv.PublicKey.bytes = pubBytes
priv.PublicKey.publicKey = pubBn priv.PublicKey.publicKey = pubBn
priv.privateKey = privBn priv.privateKey = privBn
return return priv, nil
} }
// Handshake generates a shared secret given a PrivateKey and PublicKey. // Handshake generates a shared secret given a PrivateKey and PublicKey.
func Handshake(privateKey *PrivateKey, publicKey *PublicKey) (sharedSecret []byte, err error) { func Handshake(privateKey *PrivateKey, publicKey *PublicKey) ([]byte, error) {
// When a party wants to calculate the shared secret, she raises the // When a party wants to calculate the shared secret, she raises the
// foreign public key to her private key. // foreign public key to her private key.
secretBn := new(big.Int).Exp(publicKey.publicKey, privateKey.privateKey, modpGroup) secretBn := new(big.Int).Exp(publicKey.publicKey, privateKey.privateKey, modpGroup)
sharedSecret = make([]byte, Size) sharedSecret := make([]byte, Size)
err = prependZeroBytes(sharedSecret, secretBn.Bytes()) secretBn.FillBytes(sharedSecret)
return
}
func prependZeroBytes(dst, src []byte) error { return sharedSecret, nil
zeros := len(dst) - len(src)
if zeros < 0 {
return fmt.Errorf("src length is greater than destination: %d", zeros)
}
for i := 0; i < zeros; i++ {
dst[i] = 0
}
copy(dst[zeros:], src)
return nil
}
func init() {
// Load the MODP group and the generator.
var ok bool
modpGroup, ok = new(big.Int).SetString(modpStr, 16)
if !ok {
panic("Failed to load the RFC3526 MODP Group")
}
gen = big.NewInt(g)
} }

@ -101,7 +101,14 @@ const (
"a81359543e77e4a4cfa7598a4152e4c0" "a81359543e77e4a4cfa7598a4152e4c0"
) )
var xPriv, xPub, yPriv, yPub, ss []byte var (
// Load the test vectors into byte slices.
xPriv, _ = hex.DecodeString(xPrivStr)
xPub, _ = hex.DecodeString(xPubStr)
yPriv, _ = hex.DecodeString(yPrivStr)
yPub, _ = hex.DecodeString(yPubStr)
ss, _ = hex.DecodeString(ssStr)
)
// TestGenerateKeyOdd tests creating a UniformDH keypair with a odd private // TestGenerateKeyOdd tests creating a UniformDH keypair with a odd private
// key. // key.
@ -137,7 +144,7 @@ func TestGenerateKeyEven(t *testing.T) {
} }
} }
// TestHandshake tests conductiong a UniformDH handshake with know values. // TestHandshake tests conducting a UniformDH handshake with know values.
func TestHandshake(t *testing.T) { func TestHandshake(t *testing.T) {
xX, err := generateKey(xPriv) xX, err := generateKey(xPriv)
if err != nil { if err != nil {
@ -193,28 +200,3 @@ func BenchmarkHandshake(b *testing.B) {
_ = yX _ = yX
} }
} }
func init() {
// Load the test vectors into byte slices.
var err error
xPriv, err = hex.DecodeString(xPrivStr)
if err != nil {
panic("hex.DecodeString(xPrivStr) failed")
}
xPub, err = hex.DecodeString(xPubStr)
if err != nil {
panic("hex.DecodeString(xPubStr) failed")
}
yPriv, err = hex.DecodeString(yPrivStr)
if err != nil {
panic("hex.DecodeString(yPrivStr) failed")
}
yPub, err = hex.DecodeString(yPubStr)
if err != nil {
panic("hex.DecodeString(yPubStr) failed")
}
ss, err = hex.DecodeString(ssStr)
if err != nil {
panic("hex.DecodeString(ssStr) failed")
}
}

@ -2,11 +2,13 @@ module gitlab.com/yawning/obfs4.git
require ( require (
filippo.io/edwards25519 v1.0.0 filippo.io/edwards25519 v1.0.0
git.torproject.org/pluggable-transports/goptlib.git v1.3.0
github.com/dchest/siphash v1.2.3 github.com/dchest/siphash v1.2.3
gitlab.com/yawning/edwards25519-extra.git v0.0.0-20211229043746-2f91fcc9fbdb gitlab.com/yawning/edwards25519-extra.git v0.0.0-20220726154925-def713fd18e4
golang.org/x/crypto v0.9.0 gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/goptlib v1.4.0
golang.org/x/net v0.10.0 golang.org/x/crypto v0.11.0
golang.org/x/net v0.12.0
) )
require golang.org/x/sys v0.10.0 // indirect
go 1.20 go 1.20

@ -1,48 +1,22 @@
filippo.io/edwards25519 v1.0.0-rc.1.0.20210721174708-390f27c3be20/go.mod h1:N1IkdkCkiLB6tki+MYJoSx2JTY9NUlxZE7eHn5EwJns= filippo.io/edwards25519 v1.0.0-rc.1.0.20210721174708-390f27c3be20/go.mod h1:N1IkdkCkiLB6tki+MYJoSx2JTY9NUlxZE7eHn5EwJns=
filippo.io/edwards25519 v1.0.0 h1:0wAIcmJUqRdI8IJ/3eGi5/HwXZWPujYXXlkrQogz0Ek= filippo.io/edwards25519 v1.0.0 h1:0wAIcmJUqRdI8IJ/3eGi5/HwXZWPujYXXlkrQogz0Ek=
filippo.io/edwards25519 v1.0.0/go.mod h1:N1IkdkCkiLB6tki+MYJoSx2JTY9NUlxZE7eHn5EwJns= filippo.io/edwards25519 v1.0.0/go.mod h1:N1IkdkCkiLB6tki+MYJoSx2JTY9NUlxZE7eHn5EwJns=
git.torproject.org/pluggable-transports/goptlib.git v1.3.0 h1:G+iuRUblCCC2xnO+0ag1/4+aaM98D5mjWP1M0v9s8a0=
git.torproject.org/pluggable-transports/goptlib.git v1.3.0/go.mod h1:4PBMl1dg7/3vMWSoWb46eGWlrxkUyn/CAJmxhDLAlDs=
github.com/dchest/siphash v1.2.3 h1:QXwFc8cFOR2dSa/gE6o/HokBMWtLUaNDVd+22aKHeEA= github.com/dchest/siphash v1.2.3 h1:QXwFc8cFOR2dSa/gE6o/HokBMWtLUaNDVd+22aKHeEA=
github.com/dchest/siphash v1.2.3/go.mod h1:0NvQU092bT0ipiFN++/rXm69QG9tVxLAlQHIXMPAkHc= github.com/dchest/siphash v1.2.3/go.mod h1:0NvQU092bT0ipiFN++/rXm69QG9tVxLAlQHIXMPAkHc=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= gitlab.com/yawning/edwards25519-extra.git v0.0.0-20220726154925-def713fd18e4 h1:LeXiZggivkDGgmkl7+r+m/2xj3rd+K/30/0obRKayAU=
gitlab.com/yawning/edwards25519-extra.git v0.0.0-20211229043746-2f91fcc9fbdb h1:qRSZHsODmAP5qDvb3YsO7Qnf3TRiVbGxNG/WYnlM4/o= gitlab.com/yawning/edwards25519-extra.git v0.0.0-20220726154925-def713fd18e4/go.mod h1:gvdJuZuO/tPZyhEV8K3Hmoxv/DWud5L4qEQxfYjEUTo=
gitlab.com/yawning/edwards25519-extra.git v0.0.0-20211229043746-2f91fcc9fbdb/go.mod h1:gvdJuZuO/tPZyhEV8K3Hmoxv/DWud5L4qEQxfYjEUTo= gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/goptlib v1.4.0 h1:Y7fHDMy11yyjM+YlHfcM3svaujdL+m5DqS444wbj8o4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/goptlib v1.4.0/go.mod h1:70bhd4JKW/+1HLfm+TMrgHJsUHG4coelMWwiVEJ2gAg=
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.11.0 h1:6Ewdq3tDic1mg5xRO4milcWCfMVQhI4NkqWWvqejpuA=
golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g= golang.org/x/crypto v0.11.0/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio=
golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.12.0 h1:cfawfvKITfUsFCeJIHJrbSxpeu/E81khclypR0GVT50=
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.12.0/go.mod h1:zEVYFnQC7m/vmpQFELhcD1EWkZlX69l4oqgmer6hfKA=
golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=

@ -22,7 +22,6 @@ import (
"filippo.io/edwards25519" "filippo.io/edwards25519"
"filippo.io/edwards25519/field" "filippo.io/edwards25519/field"
"gitlab.com/yawning/edwards25519-extra.git/elligator2" "gitlab.com/yawning/edwards25519-extra.git/elligator2"
) )
@ -52,7 +51,7 @@ var (
0xbb, 0x4a, 0xde, 0x38, 0x32, 0x99, 0x33, 0xe9, 0x28, 0x4a, 0x39, 0x06, 0xa0, 0xb9, 0xd5, 0x1f, 0xbb, 0x4a, 0xde, 0x38, 0x32, 0x99, 0x33, 0xe9, 0x28, 0x4a, 0x39, 0x06, 0xa0, 0xb9, 0xd5, 0x1f,
}) })
// Low order point Edwards y-coordinate `-lop_x * sqrtm1` // Low order point Edwards y-coordinate `-lop_x * sqrtm1`.
feLopY = mustFeFromBytes([]byte{ feLopY = mustFeFromBytes([]byte{
0x26, 0xe8, 0x95, 0x8f, 0xc2, 0xb2, 0x27, 0xb0, 0x45, 0xc3, 0xf4, 0x89, 0xf2, 0xef, 0x98, 0xf0, 0x26, 0xe8, 0x95, 0x8f, 0xc2, 0xb2, 0x27, 0xb0, 0x45, 0xc3, 0xf4, 0x89, 0xf2, 0xef, 0x98, 0xf0,
0xd5, 0xdf, 0xac, 0x05, 0xd3, 0xc6, 0x33, 0x39, 0xb1, 0x38, 0x02, 0x88, 0x6d, 0x53, 0xfc, 0x05, 0xd5, 0xdf, 0xac, 0x05, 0xd3, 0xc6, 0x33, 0x39, 0xb1, 0x38, 0x02, 0x88, 0x6d, 0x53, 0xfc, 0x05,

@ -41,12 +41,13 @@ import (
"sync" "sync"
"syscall" "syscall"
"git.torproject.org/pluggable-transports/goptlib.git" "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/goptlib"
"golang.org/x/net/proxy"
"gitlab.com/yawning/obfs4.git/common/log" "gitlab.com/yawning/obfs4.git/common/log"
"gitlab.com/yawning/obfs4.git/common/socks5" "gitlab.com/yawning/obfs4.git/common/socks5"
"gitlab.com/yawning/obfs4.git/transports" "gitlab.com/yawning/obfs4.git/transports"
"gitlab.com/yawning/obfs4.git/transports/base" "gitlab.com/yawning/obfs4.git/transports/base"
"golang.org/x/net/proxy"
) )
const ( const (
@ -55,23 +56,27 @@ const (
socksAddr = "127.0.0.1:0" socksAddr = "127.0.0.1:0"
) )
var stateDir string var (
var termMon *termMonitor stateDir string
termMon *termMonitor
)
func clientSetup() (launched bool, listeners []net.Listener) { func clientSetup() (bool, []net.Listener) {
ptClientInfo, err := pt.ClientSetup(transports.Transports()) ptClientInfo, err := pt.ClientSetup(transports.Transports())
if err != nil { if err != nil {
golog.Fatal(err) golog.Fatal(err)
} }
ptClientProxy, err := ptGetProxy() ptClientProxy, err := ptGetProxy(&ptClientInfo)
if err != nil { if err != nil {
golog.Fatal(err) golog.Fatal(err)
} else if ptClientProxy != nil { } else if ptClientProxy != nil {
ptProxyDone() pt.ProxyDone()
} }
// Launch each of the client listeners. // Launch each of the client listeners.
var launched bool
listeners := make([]net.Listener, 0, len(ptClientInfo.MethodNames))
for _, name := range ptClientInfo.MethodNames { for _, name := range ptClientInfo.MethodNames {
t := transports.Get(name) t := transports.Get(name)
if t == nil { if t == nil {
@ -103,7 +108,7 @@ func clientSetup() (launched bool, listeners []net.Listener) {
} }
pt.CmethodsDone() pt.CmethodsDone()
return return launched, listeners
} }
func clientAcceptLoop(f base.ClientFactory, ln net.Listener, proxyURI *url.URL) error { func clientAcceptLoop(f base.ClientFactory, ln net.Listener, proxyURI *url.URL) error {
@ -111,10 +116,7 @@ func clientAcceptLoop(f base.ClientFactory, ln net.Listener, proxyURI *url.URL)
for { for {
conn, err := ln.Accept() conn, err := ln.Accept()
if err != nil { if err != nil {
if e, ok := err.(net.Error); ok && !e.Temporary() { return err
return err
}
continue
} }
go clientHandler(f, conn, proxyURI) go clientHandler(f, conn, proxyURI)
} }
@ -176,12 +178,14 @@ func clientHandler(f base.ClientFactory, conn net.Conn, proxyURI *url.URL) {
} }
} }
func serverSetup() (launched bool, listeners []net.Listener) { func serverSetup() (bool, []net.Listener) {
ptServerInfo, err := pt.ServerSetup(transports.Transports()) ptServerInfo, err := pt.ServerSetup(transports.Transports())
if err != nil { if err != nil {
golog.Fatal(err) golog.Fatal(err)
} }
var launched bool
listeners := make([]net.Listener, 0, len(ptServerInfo.Bindaddrs))
for _, bindaddr := range ptServerInfo.Bindaddrs { for _, bindaddr := range ptServerInfo.Bindaddrs {
name := bindaddr.MethodName name := bindaddr.MethodName
t := transports.Get(name) t := transports.Get(name)
@ -218,7 +222,7 @@ func serverSetup() (launched bool, listeners []net.Listener) {
} }
pt.SmethodsDone() pt.SmethodsDone()
return return launched, listeners
} }
func serverAcceptLoop(f base.ServerFactory, ln net.Listener, info *pt.ServerInfo) error { func serverAcceptLoop(f base.ServerFactory, ln net.Listener, info *pt.ServerInfo) error {
@ -226,10 +230,7 @@ func serverAcceptLoop(f base.ServerFactory, ln net.Listener, info *pt.ServerInfo
for { for {
conn, err := ln.Accept() conn, err := ln.Accept()
if err != nil { if err != nil {
if e, ok := err.(net.Error); ok && !e.Temporary() { return err
return err
}
continue
} }
go serverHandler(f, conn, info) go serverHandler(f, conn, info)
} }
@ -317,7 +318,7 @@ func main() {
flag.Parse() flag.Parse()
if *showVer { if *showVer {
fmt.Printf("%s\n", getVersion()) fmt.Printf("%s\n", getVersion()) //nolint:forbidigo
os.Exit(0) os.Exit(0)
} }
if err := log.SetLogLevel(*logLevelStr); err != nil { if err := log.SetLogLevel(*logLevelStr); err != nil {

@ -30,6 +30,7 @@ package main
import ( import (
"bufio" "bufio"
"encoding/base64" "encoding/base64"
"errors"
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
@ -69,14 +70,14 @@ func (s *httpProxy) Dial(network, addr string) (net.Conn, error) {
return nil, err return nil, err
} }
conn := new(httpConn) conn := new(httpConn)
conn.httpConn = httputil.NewClientConn(c, nil) // nolint: staticcheck conn.httpConn = httputil.NewClientConn(c, nil) //nolint:staticcheck
conn.remoteAddr, err = net.ResolveTCPAddr(network, addr) conn.remoteAddr, err = net.ResolveTCPAddr(network, addr)
if err != nil { if err != nil {
conn.httpConn.Close() conn.httpConn.Close()
return nil, err return nil, err
} }
// HACK HACK HACK HACK. http.ReadRequest also does this. // HACK: http.ReadRequest also does this.
reqURL, err := url.Parse("http://" + addr) reqURL, err := url.Parse("http://" + addr)
if err != nil { if err != nil {
conn.httpConn.Close() conn.httpConn.Close()
@ -84,7 +85,7 @@ func (s *httpProxy) Dial(network, addr string) (net.Conn, error) {
} }
reqURL.Scheme = "" reqURL.Scheme = ""
req, err := http.NewRequest("CONNECT", reqURL.String(), nil) req, err := http.NewRequest(http.MethodConnect, reqURL.String(), nil)
if err != nil { if err != nil {
conn.httpConn.Close() conn.httpConn.Close()
return nil, err return nil, err
@ -93,16 +94,16 @@ func (s *httpProxy) Dial(network, addr string) (net.Conn, error) {
if s.haveAuth { if s.haveAuth {
// SetBasicAuth doesn't quite do what is appropriate, because // SetBasicAuth doesn't quite do what is appropriate, because
// the correct header is `Proxy-Authorization`. // the correct header is `Proxy-Authorization`.
req.Header.Set("Proxy-Authorization", "Basic " + base64.StdEncoding.EncodeToString([]byte(s.username+":"+s.password))) req.Header.Set("Proxy-Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(s.username+":"+s.password)))
} }
req.Header.Set("User-Agent", "") req.Header.Set("User-Agent", "")
resp, err := conn.httpConn.Do(req) resp, err := conn.httpConn.Do(req)
if err != nil && err != httputil.ErrPersistEOF { // nolint: staticcheck if err != nil && !errors.Is(err, httputil.ErrPersistEOF) { //nolint:staticcheck
conn.httpConn.Close() conn.httpConn.Close()
return nil, err return nil, err
} }
if resp.StatusCode != 200 { if resp.StatusCode != http.StatusOK {
conn.httpConn.Close() conn.httpConn.Close()
return nil, fmt.Errorf("proxy error: %s", resp.Status) return nil, fmt.Errorf("proxy error: %s", resp.Status)
} }
@ -113,7 +114,7 @@ func (s *httpProxy) Dial(network, addr string) (net.Conn, error) {
type httpConn struct { type httpConn struct {
remoteAddr *net.TCPAddr remoteAddr *net.TCPAddr
httpConn *httputil.ClientConn // nolint: staticcheck httpConn *httputil.ClientConn //nolint:staticcheck
hijackedConn net.Conn hijackedConn net.Conn
staleReader *bufio.Reader staleReader *bufio.Reader
} }
@ -156,6 +157,6 @@ func (c *httpConn) SetWriteDeadline(t time.Time) error {
return c.hijackedConn.SetWriteDeadline(t) return c.hijackedConn.SetWriteDeadline(t)
} }
func init() { func init() { //nolint:gochecknoinits
proxy.RegisterDialerType("http", newHTTP) proxy.RegisterDialerType("http", newHTTP)
} }

@ -150,7 +150,7 @@ func socks4ErrorToString(code byte) string {
case socks4Rejected: case socks4Rejected:
return "request rejected or failed" return "request rejected or failed"
case socks4RejectedIdentdFailed: case socks4RejectedIdentdFailed:
return "request rejected becasue SOCKS server cannot connect to identd on the client" return "request rejected because SOCKS server cannot connect to identd on the client"
case socks4RejectedIdentdMismatch: case socks4RejectedIdentdMismatch:
return "request rejected because the client program and identd report different user-ids" return "request rejected because the client program and identd report different user-ids"
default: default:
@ -158,7 +158,7 @@ func socks4ErrorToString(code byte) string {
} }
} }
func init() { func init() { //nolint:gochecknoinits
// Despite the scheme name, this really is SOCKS4. // Despite the scheme name, this really is SOCKS4.
proxy.RegisterDialerType("socks4a", newSOCKS4) proxy.RegisterDialerType("socks4a", newSOCKS4)
} }

@ -35,11 +35,11 @@ import (
"os" "os"
"strconv" "strconv"
"git.torproject.org/pluggable-transports/goptlib.git" "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/goptlib"
) )
// This file contains things that probably should be in goptlib but are not // This file contains things that probably should be in goptlib but are not
// yet or are not finalized. // yet or not exposed.
func ptEnvError(msg string) error { func ptEnvError(msg string) error {
line := []byte(fmt.Sprintf("ENV-ERROR %s\n", msg)) line := []byte(fmt.Sprintf("ENV-ERROR %s\n", msg))
@ -47,89 +47,61 @@ func ptEnvError(msg string) error {
return errors.New(msg) return errors.New(msg)
} }
func ptProxyError(msg string) error {
line := []byte(fmt.Sprintf("PROXY-ERROR %s\n", msg))
_, _ = pt.Stdout.Write(line)
return errors.New(msg)
}
func ptProxyDone() {
line := []byte("PROXY DONE\n")
_, _ = pt.Stdout.Write(line)
}
func ptIsClient() (bool, error) { func ptIsClient() (bool, error) {
clientEnv := os.Getenv("TOR_PT_CLIENT_TRANSPORTS") clientEnv := os.Getenv("TOR_PT_CLIENT_TRANSPORTS")
serverEnv := os.Getenv("TOR_PT_SERVER_TRANSPORTS") serverEnv := os.Getenv("TOR_PT_SERVER_TRANSPORTS")
if clientEnv != "" && serverEnv != "" { switch {
case clientEnv != "" && serverEnv != "":
return false, ptEnvError("TOR_PT_[CLIENT,SERVER]_TRANSPORTS both set") return false, ptEnvError("TOR_PT_[CLIENT,SERVER]_TRANSPORTS both set")
} else if clientEnv != "" { case clientEnv != "":
return true, nil return true, nil
} else if serverEnv != "" { case serverEnv != "":
return false, nil return false, nil
} }
return false, errors.New("not launched as a managed transport") return false, errors.New("not launched as a managed transport")
} }
func ptGetProxy() (*url.URL, error) { func ptGetProxy(info *pt.ClientInfo) (*url.URL, error) {
specString := os.Getenv("TOR_PT_PROXY") proxyURL := info.ProxyURL
if specString == "" { if proxyURL == nil {
return nil, nil return nil, nil //nolint:nilnil
}
spec, err := url.Parse(specString)
if err != nil {
return nil, ptProxyError(fmt.Sprintf("failed to parse proxy config: %s", err))
} }
// Validate the TOR_PT_PROXY uri. // Validate the arguments.
if !spec.IsAbs() { switch proxyURL.Scheme {
return nil, ptProxyError("proxy URI is relative, must be absolute")
}
if spec.Path != "" {
return nil, ptProxyError("proxy URI has a path defined")
}
if spec.RawQuery != "" {
return nil, ptProxyError("proxy URI has a query defined")
}
if spec.Fragment != "" {
return nil, ptProxyError("proxy URI has a fragment defined")
}
switch spec.Scheme {
case "http": case "http":
// The most forgiving of proxies. // The most forgiving of proxies.
case "socks4a": case "socks4a":
if spec.User != nil { if proxyURL.User != nil {
_, isSet := spec.User.Password() _, isSet := proxyURL.User.Password()
if isSet { if isSet {
return nil, ptProxyError("proxy URI specified SOCKS4a and a password") return nil, pt.ProxyError("proxy URI proxyURLified SOCKS4a and a password")
} }
} }
case "socks5": case "socks5":
if spec.User != nil { if proxyURL.User != nil {
// UNAME/PASSWD both must be between 1 and 255 bytes long. (RFC1929) // UNAME/PASSWD both must be between 1 and 255 bytes long. (RFC1929)
user := spec.User.Username() user := proxyURL.User.Username()
passwd, isSet := spec.User.Password() passwd, isSet := proxyURL.User.Password()
if len(user) < 1 || len(user) > 255 { if len(user) < 1 || len(user) > 255 {
return nil, ptProxyError("proxy URI specified a invalid SOCKS5 username") return nil, pt.ProxyError("proxy URI proxyURLified a invalid SOCKS5 username")
} }
if !isSet || len(passwd) < 1 || len(passwd) > 255 { if !isSet || len(passwd) < 1 || len(passwd) > 255 {
return nil, ptProxyError("proxy URI specified a invalid SOCKS5 password") return nil, pt.ProxyError("proxy URI proxyURLified a invalid SOCKS5 password")
} }
} }
default: default:
return nil, ptProxyError(fmt.Sprintf("proxy URI has invalid scheme: %s", spec.Scheme)) return nil, pt.ProxyError(fmt.Sprintf("proxy URI has invalid scheme: %s", proxyURL.Scheme))
} }
_, err = resolveAddrStr(spec.Host) if _, err := resolveAddrStr(proxyURL.Host); err != nil {
if err != nil { return nil, pt.ProxyError(fmt.Sprintf("proxy URI has invalid host: %s", err))
return nil, ptProxyError(fmt.Sprintf("proxy URI has invalid host: %s", err))
} }
return spec, nil return proxyURL, nil
} }
// Sigh, pt.resolveAddr() isn't exported. Include our own getto version that // Sigh, pt.resolveAddr() isn't exported. Include our own getto version that

@ -29,7 +29,6 @@ package main
import ( import (
"io" "io"
"io/ioutil"
"os" "os"
"os/signal" "os/signal"
"runtime" "runtime"
@ -73,7 +72,7 @@ func (m *termMonitor) wait(termOnNoHandlers bool) os.Signal {
} }
func (m *termMonitor) termOnStdinClose() { func (m *termMonitor) termOnStdinClose() {
_, err := io.Copy(ioutil.Discard, os.Stdin) _, err := io.Copy(io.Discard, os.Stdin)
// io.Copy() will return a nil on EOF, since reaching EOF is // io.Copy() will return a nil on EOF, since reaching EOF is
// expected behavior. No matter what, if this unblocks, assume // expected behavior. No matter what, if this unblocks, assume
@ -103,9 +102,9 @@ func (m *termMonitor) termOnPPIDChange(ppid int) {
m.sigChan <- syscall.SIGTERM m.sigChan <- syscall.SIGTERM
} }
func newTermMonitor() (m *termMonitor) { func newTermMonitor() *termMonitor {
ppid := os.Getppid() ppid := os.Getppid()
m = new(termMonitor) m := new(termMonitor)
m.sigChan = make(chan os.Signal) m.sigChan = make(chan os.Signal)
m.handlerChan = make(chan int) m.handlerChan = make(chan int)
signal.Notify(m.sigChan, syscall.SIGINT, syscall.SIGTERM) signal.Notify(m.sigChan, syscall.SIGINT, syscall.SIGTERM)
@ -113,7 +112,7 @@ func newTermMonitor() (m *termMonitor) {
// If tor supports feature #15435, we can use Stdin being closed as an // If tor supports feature #15435, we can use Stdin being closed as an
// indication that tor has died, or wants the PT to shutdown for any // indication that tor has died, or wants the PT to shutdown for any
// reason. // reason.
if ptShouldExitOnStdinClose() { if ptShouldExitOnStdinClose() { //nolint:nestif
go m.termOnStdinClose() go m.termOnStdinClose()
} else { } else {
// Instead of feature #15435, use various kludges and hacks: // Instead of feature #15435, use various kludges and hacks:
@ -124,12 +123,12 @@ func newTermMonitor() (m *termMonitor) {
// Errors here are non-fatal, since it might still be // Errors here are non-fatal, since it might still be
// possible to fall back to a generic implementation. // possible to fall back to a generic implementation.
if err := termMonitorOSInit(m); err == nil { if err := termMonitorOSInit(m); err == nil {
return return m
} }
} }
if runtime.GOOS != "windows" { if runtime.GOOS != "windows" {
go m.termOnPPIDChange(ppid) go m.termOnPPIDChange(ppid)
} }
} }
return return m
} }

@ -32,18 +32,18 @@ import (
"syscall" "syscall"
) )
func termMonitorInitLinux(m *termMonitor) error { func termMonitorInitLinux(_ *termMonitor) error {
// Use prctl() to have the kernel deliver a SIGTERM if the parent // Use prctl() to have the kernel deliver a SIGTERM if the parent
// process dies. This beats anything else that can be done before // process dies. This beats anything else that can be done before
// #15435 is implemented. // #15435 is implemented.
_, _, errno := syscall.Syscall(syscall.SYS_PRCTL, syscall.PR_SET_PDEATHSIG, uintptr(syscall.SIGTERM), 0) _, _, errno := syscall.Syscall(syscall.SYS_PRCTL, syscall.PR_SET_PDEATHSIG, uintptr(syscall.SIGTERM), 0)
if errno != 0 { if errno != 0 {
var err error = errno var err error = errno
return fmt.Errorf("prctl(PR_SET_PDEATHSIG, SIGTERM) returned: %s", err) return fmt.Errorf("prctl(PR_SET_PDEATHSIG, SIGTERM) returned: %w", err)
} }
return nil return nil
} }
func init() { func init() { //nolint:gochecknoinits
termMonitorOSInit = termMonitorInitLinux termMonitorOSInit = termMonitorInitLinux
} }

@ -32,7 +32,7 @@ package base // import "gitlab.com/yawning/obfs4.git/transports/base"
import ( import (
"net" "net"
"git.torproject.org/pluggable-transports/goptlib.git" "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/goptlib"
) )
type DialFunc func(string, string) (net.Conn, error) type DialFunc func(string, string) (net.Conn, error)
@ -48,12 +48,12 @@ type ClientFactory interface {
// for use with WrapConn. This routine is called before the outgoing // for use with WrapConn. This routine is called before the outgoing
// TCP/IP connection is created to allow doing things (like keypair // TCP/IP connection is created to allow doing things (like keypair
// generation) to be hidden from third parties. // generation) to be hidden from third parties.
ParseArgs(args *pt.Args) (interface{}, error) ParseArgs(args *pt.Args) (any, error)
// Dial creates an outbound net.Conn, and does whatever is required // Dial creates an outbound net.Conn, and does whatever is required
// (eg: handshaking) to get the connection to the point where it is // (eg: handshaking) to get the connection to the point where it is
// ready to relay data. // ready to relay data.
Dial(network, address string, dialFn DialFunc, args interface{}) (net.Conn, error) Dial(network, address string, dialFn DialFunc, args any) (net.Conn, error)
} }
// ServerFactory is the interface that defines the factory for creating // ServerFactory is the interface that defines the factory for creating

@ -36,7 +36,8 @@ import (
"fmt" "fmt"
"net" "net"
"git.torproject.org/pluggable-transports/goptlib.git" "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/goptlib"
"gitlab.com/yawning/obfs4.git/transports/base" "gitlab.com/yawning/obfs4.git/transports/base"
) )
@ -51,15 +52,13 @@ func (t *Transport) Name() string {
} }
// ClientFactory returns a new meekClientFactory instance. // ClientFactory returns a new meekClientFactory instance.
func (t *Transport) ClientFactory(stateDir string) (base.ClientFactory, error) { func (t *Transport) ClientFactory(_ string) (base.ClientFactory, error) {
cf := &meekClientFactory{transport: t} cf := &meekClientFactory{transport: t}
return cf, nil return cf, nil
} }
// ServerFactory will one day return a new meekServerFactory instance. // ServerFactory will one day return a new meekServerFactory instance.
func (t *Transport) ServerFactory(stateDir string, args *pt.Args) (base.ServerFactory, error) { func (t *Transport) ServerFactory(_ string, _ *pt.Args) (base.ServerFactory, error) {
// TODO: Fill this in eventually, though for servers people should
// just use the real thing.
return nil, fmt.Errorf("server not supported") return nil, fmt.Errorf("server not supported")
} }
@ -71,18 +70,18 @@ func (cf *meekClientFactory) Transport() base.Transport {
return cf.transport return cf.transport
} }
func (cf *meekClientFactory) ParseArgs(args *pt.Args) (interface{}, error) { func (cf *meekClientFactory) ParseArgs(args *pt.Args) (any, error) {
return newClientArgs(args) return newClientArgs(args)
} }
func (cf *meekClientFactory) Dial(network, addr string, dialFn base.DialFunc, args interface{}) (net.Conn, error) { func (cf *meekClientFactory) Dial(_, _ string, dialFn base.DialFunc, args any) (net.Conn, error) {
// Validate args before opening outgoing connection. // Validate args before opening outgoing connection.
ca, ok := args.(*meekClientArgs) ca, ok := args.(*meekClientArgs)
if !ok { if !ok {
return nil, fmt.Errorf("invalid argument type for args") return nil, fmt.Errorf("invalid argument type for args")
} }
return newMeekConn(network, addr, dialFn, ca) return newMeekConn(dialFn, ca)
} }
var ( var (

@ -35,7 +35,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"net" "net"
"net/http" "net/http"
gourl "net/url" gourl "net/url"
@ -44,7 +43,8 @@ import (
"sync" "sync"
"time" "time"
"git.torproject.org/pluggable-transports/goptlib.git" "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/goptlib"
"gitlab.com/yawning/obfs4.git/transports/base" "gitlab.com/yawning/obfs4.git/transports/base"
) )
@ -83,8 +83,11 @@ func (ca *meekClientArgs) String() string {
return transportName + ":" + ca.front + ":" + ca.url.String() return transportName + ":" + ca.front + ":" + ca.url.String()
} }
func newClientArgs(args *pt.Args) (ca *meekClientArgs, err error) { func newClientArgs(args *pt.Args) (*meekClientArgs, error) {
ca = &meekClientArgs{} var (
ca meekClientArgs
err error
)
// Parse the URL argument. // Parse the URL argument.
str, ok := args.Get(urlArg) str, ok := args.Get(urlArg)
@ -104,7 +107,7 @@ func newClientArgs(args *pt.Args) (ca *meekClientArgs, err error) {
// Parse the (optional) front argument. // Parse the (optional) front argument.
ca.front, _ = args.Get(frontArg) ca.front, _ = args.Get(frontArg)
return ca, nil return &ca, nil
} }
type meekConn struct { type meekConn struct {
@ -119,18 +122,18 @@ type meekConn struct {
rdBuf *bytes.Buffer rdBuf *bytes.Buffer
} }
func (c *meekConn) Read(p []byte) (n int, err error) { func (c *meekConn) Read(p []byte) (int, error) {
// If there is data left over from the previous read, // If there is data left over from the previous read,
// service the request using the buffered data. // service the request using the buffered data.
if c.rdBuf != nil { if c.rdBuf != nil {
if c.rdBuf.Len() == 0 { if c.rdBuf.Len() == 0 {
panic("empty read buffer") panic("empty read buffer")
} }
n, err = c.rdBuf.Read(p) n, err := c.rdBuf.Read(p)
if c.rdBuf.Len() == 0 { if c.rdBuf.Len() == 0 {
c.rdBuf = nil c.rdBuf = nil
} }
return return n, err
} }
// Wait for the worker to enqueue more incoming data. // Wait for the worker to enqueue more incoming data.
@ -142,16 +145,16 @@ func (c *meekConn) Read(p []byte) (n int, err error) {
// Ew, an extra copy, but who am I kidding, it's meek. // Ew, an extra copy, but who am I kidding, it's meek.
buf := bytes.NewBuffer(b) buf := bytes.NewBuffer(b)
n, err = buf.Read(p) n, err := buf.Read(p)
if buf.Len() > 0 { if buf.Len() > 0 {
// If there's data pending, stash the buffer so the next // If there's data pending, stash the buffer so the next
// Read() call will use it to fulfuill the Read(). // Read() call will use it to fulfuill the Read().
c.rdBuf = buf c.rdBuf = buf
} }
return return n, err
} }
func (c *meekConn) Write(b []byte) (n int, err error) { func (c *meekConn) Write(b []byte) (int, error) {
// Check to see if the connection is actually open. // Check to see if the connection is actually open.
select { select {
case <-c.workerCloseChan: case <-c.workerCloseChan:
@ -196,19 +199,19 @@ func (c *meekConn) RemoteAddr() net.Addr {
return c.args return c.args
} }
func (c *meekConn) SetDeadline(t time.Time) error { func (c *meekConn) SetDeadline(_ time.Time) error {
return ErrNotSupported return ErrNotSupported
} }
func (c *meekConn) SetReadDeadline(t time.Time) error { func (c *meekConn) SetReadDeadline(_ time.Time) error {
return ErrNotSupported return ErrNotSupported
} }
func (c *meekConn) SetWriteDeadline(t time.Time) error { func (c *meekConn) SetWriteDeadline(_ time.Time) error {
return ErrNotSupported return ErrNotSupported
} }
func (c *meekConn) enqueueWrite(b []byte) (ok bool) { func (c *meekConn) enqueueWrite(b []byte) (ok bool) { //nolint:nonamedreturns
defer func() { defer func() {
if err := recover(); err != nil { if err := recover(); err != nil {
ok = false ok = false
@ -218,21 +221,26 @@ func (c *meekConn) enqueueWrite(b []byte) (ok bool) {
return true return true
} }
func (c *meekConn) roundTrip(sndBuf []byte) (recvBuf []byte, err error) { func (c *meekConn) roundTrip(sndBuf []byte) ([]byte, error) {
var req *http.Request var (
var resp *http.Response req *http.Request
resp *http.Response
err error
)
url := *c.args.url
host := url.Host
if c.args.front != "" {
url.Host = c.args.front
}
urlStr := url.String()
for retries := 0; retries < maxRetries; retries++ { for retries := 0; retries < maxRetries; retries++ {
url := *c.args.url
host := url.Host
if c.args.front != "" {
url.Host = c.args.front
}
var body io.Reader var body io.Reader
if len(sndBuf) > 0 { if len(sndBuf) > 0 {
body = bytes.NewReader(sndBuf) body = bytes.NewReader(sndBuf)
} }
req, err = http.NewRequest("POST", url.String(), body) req, err = http.NewRequest(http.MethodPost, urlStr, body)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -248,16 +256,17 @@ func (c *meekConn) roundTrip(sndBuf []byte) (recvBuf []byte, err error) {
} }
if resp.StatusCode == http.StatusOK { if resp.StatusCode == http.StatusOK {
recvBuf, err = ioutil.ReadAll(io.LimitReader(resp.Body, maxPayloadLength)) var recvBuf []byte
recvBuf, err = io.ReadAll(io.LimitReader(resp.Body, maxPayloadLength))
resp.Body.Close() resp.Body.Close()
return return recvBuf, err
} }
resp.Body.Close() resp.Body.Close()
err = fmt.Errorf("status code was %d, not %d", resp.StatusCode, http.StatusOK) err = fmt.Errorf("status code was %d, not %d", resp.StatusCode, http.StatusOK)
time.Sleep(retryDelay) time.Sleep(retryDelay)
} }
return return nil, err
} }
func (c *meekConn) ioWorker() { func (c *meekConn) ioWorker() {
@ -305,19 +314,20 @@ loop:
} }
// Determine the next poll interval. // Determine the next poll interval.
if len(rdBuf) > 0 { switch {
case len(rdBuf) > 0:
// Received data, enqueue the read. // Received data, enqueue the read.
c.workerRdChan <- rdBuf c.workerRdChan <- rdBuf
// And poll immediately. // And poll immediately.
interval = 0 interval = 0
} else if wrSz > 0 { case wrSz > 0:
// Sent data, poll immediately. // Sent data, poll immediately.
interval = 0 interval = 0
} else if interval == 0 { case interval == 0:
// Neither sent nor received data after a poll, re-initialize the delay. // Neither sent nor received data after a poll, re-initialize the delay.
interval = initPollInterval interval = initPollInterval
} else { default:
// Apply a multiplicative backoff. // Apply a multiplicative backoff.
interval = time.Duration(float64(interval) * pollIntervalMultiplier) interval = time.Duration(float64(interval) * pollIntervalMultiplier)
if interval > maxPollInterval { if interval > maxPollInterval {
@ -337,7 +347,7 @@ loop:
_ = c.Close() _ = c.Close()
} }
func newMeekConn(network, addr string, dialFn base.DialFunc, ca *meekClientArgs) (net.Conn, error) { func newMeekConn(dialFn base.DialFunc, ca *meekClientArgs) (net.Conn, error) {
id, err := newSessionID() id, err := newSessionID()
if err != nil { if err != nil {
return nil, err return nil, err

@ -40,7 +40,8 @@ import (
"net" "net"
"time" "time"
"git.torproject.org/pluggable-transports/goptlib.git" "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/goptlib"
"gitlab.com/yawning/obfs4.git/common/csrand" "gitlab.com/yawning/obfs4.git/common/csrand"
"gitlab.com/yawning/obfs4.git/transports/base" "gitlab.com/yawning/obfs4.git/transports/base"
) )
@ -81,13 +82,13 @@ func (t *Transport) Name() string {
} }
// ClientFactory returns a new obfs2ClientFactory instance. // ClientFactory returns a new obfs2ClientFactory instance.
func (t *Transport) ClientFactory(stateDir string) (base.ClientFactory, error) { func (t *Transport) ClientFactory(_ string) (base.ClientFactory, error) {
cf := &obfs2ClientFactory{transport: t} cf := &obfs2ClientFactory{transport: t}
return cf, nil return cf, nil
} }
// ServerFactory returns a new obfs2ServerFactory instance. // ServerFactory returns a new obfs2ServerFactory instance.
func (t *Transport) ServerFactory(stateDir string, args *pt.Args) (base.ServerFactory, error) { func (t *Transport) ServerFactory(_ string, args *pt.Args) (base.ServerFactory, error) {
if err := validateArgs(args); err != nil { if err := validateArgs(args); err != nil {
return nil, err return nil, err
} }
@ -104,11 +105,11 @@ func (cf *obfs2ClientFactory) Transport() base.Transport {
return cf.transport return cf.transport
} }
func (cf *obfs2ClientFactory) ParseArgs(args *pt.Args) (interface{}, error) { func (cf *obfs2ClientFactory) ParseArgs(args *pt.Args) (any, error) {
return nil, validateArgs(args) return nil, validateArgs(args)
} }
func (cf *obfs2ClientFactory) Dial(network, addr string, dialFn base.DialFunc, args interface{}) (net.Conn, error) { func (cf *obfs2ClientFactory) Dial(network, addr string, dialFn base.DialFunc, _ any) (net.Conn, error) {
conn, err := dialFn(network, addr) conn, err := dialFn(network, addr)
if err != nil { if err != nil {
return nil, err return nil, err
@ -154,46 +155,46 @@ func (conn *obfs2Conn) Write(b []byte) (int, error) {
return conn.tx.Write(b) return conn.tx.Write(b)
} }
func newObfs2ClientConn(conn net.Conn) (c *obfs2Conn, err error) { func newObfs2ClientConn(conn net.Conn) (*obfs2Conn, error) {
// Initialize a client connection, and start the handshake timeout. // Initialize a client connection, and start the handshake timeout.
c = &obfs2Conn{conn, true, nil, nil} c := &obfs2Conn{conn, true, nil, nil}
deadline := time.Now().Add(clientHandshakeTimeout) deadline := time.Now().Add(clientHandshakeTimeout)
if err = c.SetDeadline(deadline); err != nil { if err := c.SetDeadline(deadline); err != nil {
return nil, err return nil, err
} }
// Handshake. // Handshake.
if err = c.handshake(); err != nil { if err := c.handshake(); err != nil {
return nil, err return nil, err
} }
// Disarm the handshake timer. // Disarm the handshake timer.
if err = c.SetDeadline(time.Time{}); err != nil { if err := c.SetDeadline(time.Time{}); err != nil {
return nil, err return nil, err
} }
return return c, nil
} }
func newObfs2ServerConn(conn net.Conn) (c *obfs2Conn, err error) { func newObfs2ServerConn(conn net.Conn) (*obfs2Conn, error) {
// Initialize a server connection, and start the handshake timeout. // Initialize a server connection, and start the handshake timeout.
c = &obfs2Conn{conn, false, nil, nil} c := &obfs2Conn{conn, false, nil, nil}
deadline := time.Now().Add(serverHandshakeTimeout) deadline := time.Now().Add(serverHandshakeTimeout)
if err = c.SetDeadline(deadline); err != nil { if err := c.SetDeadline(deadline); err != nil {
return nil, err return nil, err
} }
// Handshake. // Handshake.
if err = c.handshake(); err != nil { if err := c.handshake(); err != nil {
return nil, err return nil, err
} }
// Disarm the handshake timer. // Disarm the handshake timer.
if err = c.SetDeadline(time.Time{}); err != nil { if err := c.SetDeadline(time.Time{}); err != nil {
return nil, err return nil, err
} }
return return c, nil
} }
func (conn *obfs2Conn) handshake() error { func (conn *obfs2Conn) handshake() error {
@ -220,7 +221,7 @@ func (conn *obfs2Conn) handshake() error {
} else { } else {
padMagic = []byte(responderPadString) padMagic = []byte(responderPadString)
} }
padKey, padIV := hsKdf(padMagic, seed[:], conn.isInitiator) padKey, padIV := hsKdf(padMagic, seed[:])
padLen := uint32(csrand.IntRange(0, maxPadding)) padLen := uint32(csrand.IntRange(0, maxPadding))
hsBlob := make([]byte, hsLen+padLen) hsBlob := make([]byte, hsLen+padLen)
@ -265,7 +266,7 @@ func (conn *obfs2Conn) handshake() error {
} else { } else {
peerPadMagic = []byte(initiatorPadString) peerPadMagic = []byte(initiatorPadString)
} }
peerKey, peerIV := hsKdf(peerPadMagic, peerSeed[:], !conn.isInitiator) peerKey, peerIV := hsKdf(peerPadMagic, peerSeed[:])
rxBlock, err := aes.NewCipher(peerKey) rxBlock, err := aes.NewCipher(peerKey)
if err != nil { if err != nil {
return err return err
@ -273,7 +274,7 @@ func (conn *obfs2Conn) handshake() error {
rxStream := cipher.NewCTR(rxBlock, peerIV) rxStream := cipher.NewCTR(rxBlock, peerIV)
conn.rx = &cipher.StreamReader{S: rxStream, R: conn.Conn} conn.rx = &cipher.StreamReader{S: rxStream, R: conn.Conn}
hsHdr := make([]byte, hsLen) hsHdr := make([]byte, hsLen)
if _, err := io.ReadFull(conn, hsHdr[:]); err != nil { if _, err := io.ReadFull(conn, hsHdr); err != nil {
return err return err
} }
@ -296,11 +297,7 @@ func (conn *obfs2Conn) handshake() error {
} }
// Derive the actual keys. // Derive the actual keys.
if err := conn.kdf(seed[:], peerSeed[:]); err != nil { return conn.kdf(seed[:], peerSeed[:])
return err
}
return nil
} }
func (conn *obfs2Conn) kdf(seed, peerSeed []byte) error { func (conn *obfs2Conn) kdf(seed, peerSeed []byte) error {
@ -321,14 +318,14 @@ func (conn *obfs2Conn) kdf(seed, peerSeed []byte) error {
combSeed = append(combSeed, seed...) combSeed = append(combSeed, seed...)
} }
initKey, initIV := hsKdf([]byte(initiatorKdfString), combSeed, true) initKey, initIV := hsKdf([]byte(initiatorKdfString), combSeed)
initBlock, err := aes.NewCipher(initKey) initBlock, err := aes.NewCipher(initKey)
if err != nil { if err != nil {
return err return err
} }
initStream := cipher.NewCTR(initBlock, initIV) initStream := cipher.NewCTR(initBlock, initIV)
respKey, respIV := hsKdf([]byte(responderKdfString), combSeed, false) respKey, respIV := hsKdf([]byte(responderKdfString), combSeed)
respBlock, err := aes.NewCipher(respKey) respBlock, err := aes.NewCipher(respKey)
if err != nil { if err != nil {
return err return err
@ -346,16 +343,16 @@ func (conn *obfs2Conn) kdf(seed, peerSeed []byte) error {
return nil return nil
} }
func hsKdf(magic, seed []byte, isInitiator bool) (padKey, padIV []byte) { func hsKdf(magic, seed []byte) ([]byte, []byte) {
// The actual key/IV is derived in the form of: // The actual key/IV is derived in the form of:
// m = MAC(magic, seed) // m = MAC(magic, seed)
// KEY = m[:KEYLEN] // KEY = m[:KEYLEN]
// IV = m[KEYLEN:] // IV = m[KEYLEN:]
m := mac(magic, seed) m := mac(magic, seed)
padKey = m[:keyLen] padKey := m[:keyLen]
padIV = m[keyLen:] padIV := m[keyLen:]
return return padKey, padIV
} }
func mac(s, x []byte) []byte { func mac(s, x []byte) []byte {
@ -368,7 +365,9 @@ func mac(s, x []byte) []byte {
return h.Sum(nil) return h.Sum(nil)
} }
var _ base.ClientFactory = (*obfs2ClientFactory)(nil) var (
var _ base.ServerFactory = (*obfs2ServerFactory)(nil) _ base.ClientFactory = (*obfs2ClientFactory)(nil)
var _ base.Transport = (*Transport)(nil) _ base.ServerFactory = (*obfs2ServerFactory)(nil)
var _ net.Conn = (*obfs2Conn)(nil) _ base.Transport = (*Transport)(nil)
_ net.Conn = (*obfs2Conn)(nil)
)

@ -40,7 +40,8 @@ import (
"net" "net"
"time" "time"
"git.torproject.org/pluggable-transports/goptlib.git" "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/goptlib"
"gitlab.com/yawning/obfs4.git/common/csrand" "gitlab.com/yawning/obfs4.git/common/csrand"
"gitlab.com/yawning/obfs4.git/common/uniformdh" "gitlab.com/yawning/obfs4.git/common/uniformdh"
"gitlab.com/yawning/obfs4.git/transports/base" "gitlab.com/yawning/obfs4.git/transports/base"
@ -69,13 +70,13 @@ func (t *Transport) Name() string {
} }
// ClientFactory returns a new obfs3ClientFactory instance. // ClientFactory returns a new obfs3ClientFactory instance.
func (t *Transport) ClientFactory(stateDir string) (base.ClientFactory, error) { func (t *Transport) ClientFactory(_ string) (base.ClientFactory, error) {
cf := &obfs3ClientFactory{transport: t} cf := &obfs3ClientFactory{transport: t}
return cf, nil return cf, nil
} }
// ServerFactory returns a new obfs3ServerFactory instance. // ServerFactory returns a new obfs3ServerFactory instance.
func (t *Transport) ServerFactory(stateDir string, args *pt.Args) (base.ServerFactory, error) { func (t *Transport) ServerFactory(_ string, _ *pt.Args) (base.ServerFactory, error) {
sf := &obfs3ServerFactory{transport: t} sf := &obfs3ServerFactory{transport: t}
return sf, nil return sf, nil
} }
@ -88,11 +89,11 @@ func (cf *obfs3ClientFactory) Transport() base.Transport {
return cf.transport return cf.transport
} }
func (cf *obfs3ClientFactory) ParseArgs(args *pt.Args) (interface{}, error) { func (cf *obfs3ClientFactory) ParseArgs(_ *pt.Args) (any, error) {
return nil, nil return nil, nil //nolint:nilnil
} }
func (cf *obfs3ClientFactory) Dial(network, addr string, dialFn base.DialFunc, args interface{}) (net.Conn, error) { func (cf *obfs3ClientFactory) Dial(network, addr string, dialFn base.DialFunc, _ any) (net.Conn, error) {
conn, err := dialFn(network, addr) conn, err := dialFn(network, addr)
if err != nil { if err != nil {
return nil, err return nil, err
@ -133,46 +134,46 @@ type obfs3Conn struct {
tx *cipher.StreamWriter tx *cipher.StreamWriter
} }
func newObfs3ClientConn(conn net.Conn) (c *obfs3Conn, err error) { func newObfs3ClientConn(conn net.Conn) (*obfs3Conn, error) {
// Initialize a client connection, and start the handshake timeout. // Initialize a client connection, and start the handshake timeout.
c = &obfs3Conn{conn, true, nil, nil, new(bytes.Buffer), nil, nil} c := &obfs3Conn{conn, true, nil, nil, new(bytes.Buffer), nil, nil}
deadline := time.Now().Add(clientHandshakeTimeout) deadline := time.Now().Add(clientHandshakeTimeout)
if err = c.SetDeadline(deadline); err != nil { if err := c.SetDeadline(deadline); err != nil {
return nil, err return nil, err
} }
// Handshake. // Handshake.
if err = c.handshake(); err != nil { if err := c.handshake(); err != nil {
return nil, err return nil, err
} }
// Disarm the handshake timer. // Disarm the handshake timer.
if err = c.SetDeadline(time.Time{}); err != nil { if err := c.SetDeadline(time.Time{}); err != nil {
return nil, err return nil, err
} }
return return c, nil
} }
func newObfs3ServerConn(conn net.Conn) (c *obfs3Conn, err error) { func newObfs3ServerConn(conn net.Conn) (*obfs3Conn, error) {
// Initialize a server connection, and start the handshake timeout. // Initialize a server connection, and start the handshake timeout.
c = &obfs3Conn{conn, false, nil, nil, new(bytes.Buffer), nil, nil} c := &obfs3Conn{conn, false, nil, nil, new(bytes.Buffer), nil, nil}
deadline := time.Now().Add(serverHandshakeTimeout) deadline := time.Now().Add(serverHandshakeTimeout)
if err = c.SetDeadline(deadline); err != nil { if err := c.SetDeadline(deadline); err != nil {
return nil, err return nil, err
} }
// Handshake. // Handshake.
if err = c.handshake(); err != nil { if err := c.handshake(); err != nil {
return nil, err return nil, err
} }
// Disarm the handshake timer. // Disarm the handshake timer.
if err = c.SetDeadline(time.Time{}); err != nil { if err := c.SetDeadline(time.Time{}); err != nil {
return nil, err return nil, err
} }
return return c, nil
} }
func (conn *obfs3Conn) handshake() error { func (conn *obfs3Conn) handshake() error {
@ -217,11 +218,7 @@ func (conn *obfs3Conn) handshake() error {
if err != nil { if err != nil {
return err return err
} }
if err := conn.kdf(sharedSecret); err != nil { return conn.kdf(sharedSecret)
return err
}
return nil
} }
func (conn *obfs3Conn) kdf(sharedSecret []byte) error { func (conn *obfs3Conn) kdf(sharedSecret []byte) error {
@ -313,13 +310,13 @@ func (conn *obfs3Conn) findPeerMagic() error {
} }
} }
func (conn *obfs3Conn) Read(b []byte) (n int, err error) { func (conn *obfs3Conn) Read(b []byte) (int, error) {
// If this is the first time we read data post handshake, scan for the // If this is the first time we read data post handshake, scan for the
// magic value. // magic value.
if conn.rxMagic != nil { if conn.rxMagic != nil {
if err = conn.findPeerMagic(); err != nil { if err := conn.findPeerMagic(); err != nil {
conn.Close() conn.Close()
return return 0, err
} }
conn.rxMagic = nil conn.rxMagic = nil
} }
@ -339,20 +336,20 @@ func (conn *obfs3Conn) Read(b []byte) (n int, err error) {
return conn.rx.Read(b) return conn.rx.Read(b)
} }
func (conn *obfs3Conn) Write(b []byte) (n int, err error) { func (conn *obfs3Conn) Write(b []byte) (int, error) {
// If this is the first time we write data post handshake, send the // If this is the first time we write data post handshake, send the
// padding/magic value. // padding/magic value.
if conn.txMagic != nil { if conn.txMagic != nil {
padLen := csrand.IntRange(0, maxPadding/2) padLen := csrand.IntRange(0, maxPadding/2)
blob := make([]byte, padLen+len(conn.txMagic)) blob := make([]byte, padLen+len(conn.txMagic))
if err = csrand.Bytes(blob[:padLen]); err != nil { if err := csrand.Bytes(blob[:padLen]); err != nil {
conn.Close() conn.Close()
return return 0, err
} }
copy(blob[padLen:], conn.txMagic) copy(blob[padLen:], conn.txMagic)
if _, err = conn.Conn.Write(blob); err != nil { if _, err := conn.Conn.Write(blob); err != nil {
conn.Close() conn.Close()
return return 0, err
} }
conn.txMagic = nil conn.txMagic = nil
} }
@ -360,7 +357,9 @@ func (conn *obfs3Conn) Write(b []byte) (n int, err error) {
return conn.tx.Write(b) return conn.tx.Write(b)
} }
var _ base.ClientFactory = (*obfs3ClientFactory)(nil) var (
var _ base.ServerFactory = (*obfs3ServerFactory)(nil) _ base.ClientFactory = (*obfs3ClientFactory)(nil)
var _ base.Transport = (*Transport)(nil) _ base.ServerFactory = (*obfs3ServerFactory)(nil)
var _ net.Conn = (*obfs3Conn)(nil) _ base.Transport = (*Transport)(nil)
_ net.Conn = (*obfs3Conn)(nil)
)

@ -25,39 +25,40 @@
* POSSIBILITY OF SUCH DAMAGE. * POSSIBILITY OF SUCH DAMAGE.
*/ */
//
// Package framing implements the obfs4 link framing and cryptography. // Package framing implements the obfs4 link framing and cryptography.
// //
// The Encoder/Decoder shared secret format is: // The Encoder/Decoder shared secret format is:
// uint8_t[32] NaCl secretbox key //
// uint8_t[16] NaCl Nonce prefix // uint8_t[32] NaCl secretbox key
// uint8_t[16] SipHash-2-4 key (used to obfsucate length) // uint8_t[16] NaCl Nonce prefix
// uint8_t[8] SipHash-2-4 IV // uint8_t[16] SipHash-2-4 key (used to obfsucate length)
// uint8_t[8] SipHash-2-4 IV
// //
// The frame format is: // The frame format is:
// uint16_t length (obfsucated, big endian) //
// NaCl secretbox (Poly1305/XSalsa20) containing: // uint16_t length (obfsucated, big endian)
// uint8_t[16] tag (Part of the secretbox construct) // NaCl secretbox (Poly1305/XSalsa20) containing:
// uint8_t[] payload // uint8_t[16] tag (Part of the secretbox construct)
// uint8_t[] payload
// //
// The length field is length of the NaCl secretbox XORed with the truncated // The length field is length of the NaCl secretbox XORed with the truncated
// SipHash-2-4 digest ran in OFB mode. // SipHash-2-4 digest ran in OFB mode.
// //
// Initialize K, IV[0] with values from the shared secret. // Initialize K, IV[0] with values from the shared secret.
// On each packet, IV[n] = H(K, IV[n - 1]) // On each packet, IV[n] = H(K, IV[n - 1])
// mask[n] = IV[n][0:2] // mask[n] = IV[n][0:2]
// obfsLen = length ^ mask[n] // obfsLen = length ^ mask[n]
// //
// The NaCl secretbox (Poly1305/XSalsa20) nonce format is: // The NaCl secretbox (Poly1305/XSalsa20) nonce format is:
// uint8_t[24] prefix (Fixed) //
// uint64_t counter (Big endian) // uint8_t[24] prefix (Fixed)
// uint64_t counter (Big endian)
// //
// The counter is initialized to 1, and is incremented on each frame. Since // The counter is initialized to 1, and is incremented on each frame. Since
// the protocol is designed to be used over a reliable medium, the nonce is not // the protocol is designed to be used over a reliable medium, the nonce is not
// transmitted over the wire as both sides of the conversation know the prefix // transmitted over the wire as both sides of the conversation know the prefix
// and the initial counter value. It is imperative that the counter does not // and the initial counter value. It is imperative that the counter does not
// wrap, and sessions MUST terminate before 2^64 frames are sent. // wrap, and sessions MUST terminate before 2^64 frames are sent.
//
package framing // import "gitlab.com/yawning/obfs4.git/transports/obfs4/framing" package framing // import "gitlab.com/yawning/obfs4.git/transports/obfs4/framing"
import ( import (
@ -67,9 +68,10 @@ import (
"fmt" "fmt"
"io" "io"
"golang.org/x/crypto/nacl/secretbox"
"gitlab.com/yawning/obfs4.git/common/csrand" "gitlab.com/yawning/obfs4.git/common/csrand"
"gitlab.com/yawning/obfs4.git/common/drbg" "gitlab.com/yawning/obfs4.git/common/drbg"
"golang.org/x/crypto/nacl/secretbox"
) )
const ( const (
@ -175,7 +177,7 @@ func NewEncoder(key []byte) *Encoder {
// Encode encodes a single frame worth of payload and returns the encoded // Encode encodes a single frame worth of payload and returns the encoded
// length. InvalidPayloadLengthError is recoverable, all other errors MUST be // length. InvalidPayloadLengthError is recoverable, all other errors MUST be
// treated as fatal and the session aborted. // treated as fatal and the session aborted.
func (encoder *Encoder) Encode(frame, payload []byte) (n int, err error) { func (encoder *Encoder) Encode(frame, payload []byte) (int, error) {
payloadLen := len(payload) payloadLen := len(payload)
if MaximumFramePayloadLength < payloadLen { if MaximumFramePayloadLength < payloadLen {
return 0, InvalidPayloadLengthError(payloadLen) return 0, InvalidPayloadLengthError(payloadLen)
@ -186,7 +188,7 @@ func (encoder *Encoder) Encode(frame, payload []byte) (n int, err error) {
// Generate a new nonce. // Generate a new nonce.
var nonce [nonceLength]byte var nonce [nonceLength]byte
if err = encoder.nonce.bytes(&nonce); err != nil { if err := encoder.nonce.bytes(&nonce); err != nil {
return 0, err return 0, err
} }
encoder.nonce.counter++ encoder.nonce.counter++

@ -30,6 +30,7 @@ package framing
import ( import (
"bytes" "bytes"
"crypto/rand" "crypto/rand"
"errors"
"testing" "testing"
) )
@ -89,7 +90,9 @@ func TestEncoder_Encode_Oversize(t *testing.T) {
var buf [MaximumFramePayloadLength + 1]byte var buf [MaximumFramePayloadLength + 1]byte
_, _ = rand.Read(buf[:]) // YOLO _, _ = rand.Read(buf[:]) // YOLO
_, err := encoder.Encode(frame[:], buf[:]) _, err := encoder.Encode(frame[:], buf[:])
if _, ok := err.(InvalidPayloadLengthError); !ok {
var payloadErr InvalidPayloadLengthError
if !errors.As(err, &payloadErr) {
t.Error("Encoder.encode() returned unexpected error:", err) t.Error("Encoder.encode() returned unexpected error:", err)
} }
} }
@ -150,7 +153,7 @@ func BenchmarkEncoder_Encode(b *testing.B) {
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
transfered := 0 var xfered int
buffer := bytes.NewBuffer(payload) buffer := bytes.NewBuffer(payload)
for 0 < buffer.Len() { for 0 < buffer.Len() {
n, err := buffer.Read(chopBuf[:]) n, err := buffer.Read(chopBuf[:])
@ -159,11 +162,10 @@ func BenchmarkEncoder_Encode(b *testing.B) {
} }
n, _ = encoder.Encode(frame[:], chopBuf[:n]) n, _ = encoder.Encode(frame[:], chopBuf[:n])
transfered += n - FrameOverhead xfered += n - FrameOverhead
} }
if transfered != len(payload) { if xfered != len(payload) {
b.Fatalf("Transfered length mismatch: %d != %d", transfered, b.Fatalf("Xfered length mismatch: %d != %d", xfered, len(payload))
len(payload))
} }
} }
} }

@ -280,7 +280,7 @@ func (hs *serverHandshake) parseClientHandshake(filter *replayfilter.ReplayFilte
macFound := false macFound := false
for _, off := range []int64{0, -1, 1} { for _, off := range []int64{0, -1, 1} {
// Allow epoch to be off by up to a hour in either direction. // Allow epoch to be off by up to a hour in either direction.
epochHour := []byte(strconv.FormatInt(getEpochHour()+int64(off), 10)) epochHour := []byte(strconv.FormatInt(getEpochHour()+off, 10))
hs.mac.Reset() hs.mac.Reset()
_, _ = hs.mac.Write(resp[:pos+markLength]) _, _ = hs.mac.Write(resp[:pos+markLength])
_, _ = hs.mac.Write(epochHour) _, _ = hs.mac.Write(epochHour)
@ -367,7 +367,7 @@ func getEpochHour() int64 {
return time.Now().Unix() / 3600 return time.Now().Unix() / 3600
} }
func findMarkMac(mark, buf []byte, startPos, maxPos int, fromTail bool) (pos int) { func findMarkMac(mark, buf []byte, startPos, maxPos int, fromTail bool) int {
if len(mark) != markLength { if len(mark) != markLength {
panic(fmt.Sprintf("BUG: Invalid mark length: %d", len(mark))) panic(fmt.Sprintf("BUG: Invalid mark length: %d", len(mark)))
} }
@ -387,19 +387,19 @@ func findMarkMac(mark, buf []byte, startPos, maxPos int, fromTail bool) (pos int
// The server can optimize the search process by only examining the // The server can optimize the search process by only examining the
// tail of the buffer. The client can't send valid data past M_C | // tail of the buffer. The client can't send valid data past M_C |
// MAC_C as it does not have the server's public key yet. // MAC_C as it does not have the server's public key yet.
pos = endPos - (markLength + macLength) pos := endPos - (markLength + macLength)
if !hmac.Equal(buf[pos:pos+markLength], mark) { if !hmac.Equal(buf[pos:pos+markLength], mark) {
return -1 return -1
} }
return return pos
} }
// The client has to actually do a substring search since the server can // The client has to actually do a substring search since the server can
// and will send payload trailing the response. // and will send payload trailing the response.
// //
// XXX: bytes.Index() uses a naive search, which kind of sucks. // XXX: bytes.Index() uses a naive search, which kind of sucks.
pos = bytes.Index(buf[startPos:endPos], mark) pos := bytes.Index(buf[startPos:endPos], mark)
if pos == -1 { if pos == -1 {
return -1 return -1
} }
@ -411,7 +411,7 @@ func findMarkMac(mark, buf []byte, startPos, maxPos int, fromTail bool) (pos int
// Return the index relative to the start of the slice. // Return the index relative to the start of the slice.
pos += startPos pos += startPos
return return pos
} }
func makePad(padLen int) ([]byte, error) { func makePad(padLen int) ([]byte, error) {

@ -115,7 +115,7 @@ func TestHandshakeNtorClient(t *testing.T) {
serverHs := newServerHandshake(nodeID, idKeypair, serverKeypair) serverHs := newServerHandshake(nodeID, idKeypair, serverKeypair)
_, err = serverHs.parseClientHandshake(serverFilter, clientBlob) _, err = serverHs.parseClientHandshake(serverFilter, clientBlob)
if err == nil { if err == nil {
t.Fatalf("serverHandshake.parseClientHandshake() succeded (oversized)") t.Fatalf("serverHandshake.parseClientHandshake() succeeded (oversized)")
} }
// Test undersized client padding. // Test undersized client padding.
@ -127,7 +127,7 @@ func TestHandshakeNtorClient(t *testing.T) {
serverHs = newServerHandshake(nodeID, idKeypair, serverKeypair) serverHs = newServerHandshake(nodeID, idKeypair, serverKeypair)
_, err = serverHs.parseClientHandshake(serverFilter, clientBlob) _, err = serverHs.parseClientHandshake(serverFilter, clientBlob)
if err == nil { if err == nil {
t.Fatalf("serverHandshake.parseClientHandshake() succeded (undersized)") t.Fatalf("serverHandshake.parseClientHandshake() succeeded (undersized)")
} }
} }
@ -204,7 +204,7 @@ func TestHandshakeNtorServer(t *testing.T) {
serverHs := newServerHandshake(nodeID, idKeypair, serverKeypair) serverHs := newServerHandshake(nodeID, idKeypair, serverKeypair)
_, err = serverHs.parseClientHandshake(serverFilter, clientBlob) _, err = serverHs.parseClientHandshake(serverFilter, clientBlob)
if err == nil { if err == nil {
t.Fatalf("serverHandshake.parseClientHandshake() succeded (oversized)") t.Fatalf("serverHandshake.parseClientHandshake() succeeded (oversized)")
} }
// Test undersized client padding. // Test undersized client padding.
@ -216,7 +216,7 @@ func TestHandshakeNtorServer(t *testing.T) {
serverHs = newServerHandshake(nodeID, idKeypair, serverKeypair) serverHs = newServerHandshake(nodeID, idKeypair, serverKeypair)
_, err = serverHs.parseClientHandshake(serverFilter, clientBlob) _, err = serverHs.parseClientHandshake(serverFilter, clientBlob)
if err == nil { if err == nil {
t.Fatalf("serverHandshake.parseClientHandshake() succeded (undersized)") t.Fatalf("serverHandshake.parseClientHandshake() succeeded (undersized)")
} }
// Test oversized server padding. // Test oversized server padding.
@ -243,6 +243,6 @@ func TestHandshakeNtorServer(t *testing.T) {
} }
_, _, err = clientHs.parseServerHandshake(serverBlob) _, _, err = clientHs.parseServerHandshake(serverBlob)
if err == nil { if err == nil {
t.Fatalf("clientHandshake.parseServerHandshake() succeded (oversized)") t.Fatalf("clientHandshake.parseServerHandshake() succeeded (oversized)")
} }
} }

@ -32,17 +32,18 @@ package obfs4 // import "gitlab.com/yawning/obfs4.git/transports/obfs4"
import ( import (
"bytes" "bytes"
"crypto/sha256" "crypto/sha256"
"errors"
"flag" "flag"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"math/rand" "math/rand"
"net" "net"
"strconv" "strconv"
"syscall" "syscall"
"time" "time"
"git.torproject.org/pluggable-transports/goptlib.git" "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/goptlib"
"gitlab.com/yawning/obfs4.git/common/drbg" "gitlab.com/yawning/obfs4.git/common/drbg"
"gitlab.com/yawning/obfs4.git/common/ntor" "gitlab.com/yawning/obfs4.git/common/ntor"
"gitlab.com/yawning/obfs4.git/common/probdist" "gitlab.com/yawning/obfs4.git/common/probdist"
@ -81,7 +82,7 @@ const (
// biasedDist controls if the probability table will be ScrambleSuit style or // biasedDist controls if the probability table will be ScrambleSuit style or
// uniformly distributed. // uniformly distributed.
var biasedDist bool var biasedDist = flag.Bool(biasCmdArg, false, "Enable obfs4 using ScrambleSuit style table generation")
type obfs4ClientArgs struct { type obfs4ClientArgs struct {
nodeID *ntor.NodeID nodeID *ntor.NodeID
@ -99,7 +100,7 @@ func (t *Transport) Name() string {
} }
// ClientFactory returns a new obfs4ClientFactory instance. // ClientFactory returns a new obfs4ClientFactory instance.
func (t *Transport) ClientFactory(stateDir string) (base.ClientFactory, error) { func (t *Transport) ClientFactory(_ string) (base.ClientFactory, error) {
cf := &obfs4ClientFactory{transport: t} cf := &obfs4ClientFactory{transport: t}
return cf, nil return cf, nil
} }
@ -137,7 +138,7 @@ func (t *Transport) ServerFactory(stateDir string, args *pt.Args) (base.ServerFa
if err != nil { if err != nil {
return nil, err return nil, err
} }
rng := rand.New(drbg) rng := rand.New(drbg) //nolint:gosec
sf := &obfs4ServerFactory{t, &ptArgs, st.nodeID, st.identityKey, st.drbgSeed, iatSeed, st.iatMode, filter, rng.Intn(maxCloseDelay)} sf := &obfs4ServerFactory{t, &ptArgs, st.nodeID, st.identityKey, st.drbgSeed, iatSeed, st.iatMode, filter, rng.Intn(maxCloseDelay)}
return sf, nil return sf, nil
@ -151,14 +152,14 @@ func (cf *obfs4ClientFactory) Transport() base.Transport {
return cf.transport return cf.transport
} }
func (cf *obfs4ClientFactory) ParseArgs(args *pt.Args) (interface{}, error) { func (cf *obfs4ClientFactory) ParseArgs(args *pt.Args) (any, error) {
var nodeID *ntor.NodeID var nodeID *ntor.NodeID
var publicKey *ntor.PublicKey var publicKey *ntor.PublicKey
// The "new" (version >= 0.0.3) bridge lines use a unified "cert" argument // The "new" (version >= 0.0.3) bridge lines use a unified "cert" argument
// for the Node ID and Public Key. // for the Node ID and Public Key.
certStr, ok := args.Get(certArg) certStr, ok := args.Get(certArg)
if ok { if ok { //nolint:nestif
cert, err := serverCertFromString(certStr) cert, err := serverCertFromString(certStr)
if err != nil { if err != nil {
return nil, err return nil, err
@ -195,7 +196,7 @@ func (cf *obfs4ClientFactory) ParseArgs(args *pt.Args) (interface{}, error) {
return nil, fmt.Errorf("invalid iat-mode '%d'", iatMode) return nil, fmt.Errorf("invalid iat-mode '%d'", iatMode)
} }
// Generate the session key pair before connectiong to hide the Elligator2 // Generate the session key pair before connecting to hide the Elligator2
// rejection sampling from network observers. // rejection sampling from network observers.
sessionKey, err := ntor.NewKeypair(true) sessionKey, err := ntor.NewKeypair(true)
if err != nil { if err != nil {
@ -205,7 +206,7 @@ func (cf *obfs4ClientFactory) ParseArgs(args *pt.Args) (interface{}, error) {
return &obfs4ClientArgs{nodeID, publicKey, sessionKey, iatMode}, nil return &obfs4ClientArgs{nodeID, publicKey, sessionKey, iatMode}, nil
} }
func (cf *obfs4ClientFactory) Dial(network, addr string, dialFn base.DialFunc, args interface{}) (net.Conn, error) { func (cf *obfs4ClientFactory) Dial(network, addr string, dialFn base.DialFunc, args any) (net.Conn, error) {
// Validate args before bothering to open connection. // Validate args before bothering to open connection.
ca, ok := args.(*obfs4ClientArgs) ca, ok := args.(*obfs4ClientArgs)
if !ok { if !ok {
@ -259,10 +260,10 @@ func (sf *obfs4ServerFactory) WrapConn(conn net.Conn) (net.Conn, error) {
return nil, err return nil, err
} }
lenDist := probdist.New(sf.lenSeed, 0, framing.MaximumSegmentLength, biasedDist) lenDist := probdist.New(sf.lenSeed, 0, framing.MaximumSegmentLength, *biasedDist)
var iatDist *probdist.WeightedDist var iatDist *probdist.WeightedDist
if sf.iatSeed != nil { if sf.iatSeed != nil {
iatDist = probdist.New(sf.iatSeed, 0, maxIATDelay, biasedDist) iatDist = probdist.New(sf.iatSeed, 0, maxIATDelay, *biasedDist)
} }
c := &obfs4Conn{conn, true, lenDist, iatDist, sf.iatMode, bytes.NewBuffer(nil), bytes.NewBuffer(nil), make([]byte, consumeReadSize), nil, nil} c := &obfs4Conn{conn, true, lenDist, iatDist, sf.iatMode, bytes.NewBuffer(nil), bytes.NewBuffer(nil), make([]byte, consumeReadSize), nil, nil}
@ -294,25 +295,28 @@ type obfs4Conn struct {
decoder *framing.Decoder decoder *framing.Decoder
} }
func newObfs4ClientConn(conn net.Conn, args *obfs4ClientArgs) (c *obfs4Conn, err error) { func newObfs4ClientConn(conn net.Conn, args *obfs4ClientArgs) (*obfs4Conn, error) {
// Generate the initial protocol polymorphism distribution(s). // Generate the initial protocol polymorphism distribution(s).
var seed *drbg.Seed var (
seed *drbg.Seed
err error
)
if seed, err = drbg.NewSeed(); err != nil { if seed, err = drbg.NewSeed(); err != nil {
return return nil, err
} }
lenDist := probdist.New(seed, 0, framing.MaximumSegmentLength, biasedDist) lenDist := probdist.New(seed, 0, framing.MaximumSegmentLength, *biasedDist)
var iatDist *probdist.WeightedDist var iatDist *probdist.WeightedDist
if args.iatMode != iatNone { if args.iatMode != iatNone {
var iatSeed *drbg.Seed var iatSeed *drbg.Seed
iatSeedSrc := sha256.Sum256(seed.Bytes()[:]) iatSeedSrc := sha256.Sum256(seed.Bytes()[:])
if iatSeed, err = drbg.SeedFromBytes(iatSeedSrc[:]); err != nil { if iatSeed, err = drbg.SeedFromBytes(iatSeedSrc[:]); err != nil {
return return nil, err
} }
iatDist = probdist.New(iatSeed, 0, maxIATDelay, biasedDist) iatDist = probdist.New(iatSeed, 0, maxIATDelay, *biasedDist)
} }
// Allocate the client structure. // Allocate the client structure.
c = &obfs4Conn{conn, false, lenDist, iatDist, args.iatMode, bytes.NewBuffer(nil), bytes.NewBuffer(nil), make([]byte, consumeReadSize), nil, nil} c := &obfs4Conn{conn, false, lenDist, iatDist, args.iatMode, bytes.NewBuffer(nil), bytes.NewBuffer(nil), make([]byte, consumeReadSize), nil, nil}
// Start the handshake timeout. // Start the handshake timeout.
deadline := time.Now().Add(clientHandshakeTimeout) deadline := time.Now().Add(clientHandshakeTimeout)
@ -329,7 +333,7 @@ func newObfs4ClientConn(conn net.Conn, args *obfs4ClientArgs) (c *obfs4Conn, err
return nil, err return nil, err
} }
return return c, nil
} }
func (conn *obfs4Conn) clientHandshake(nodeID *ntor.NodeID, peerIdentityKey *ntor.PublicKey, sessionKey *ntor.Keypair) error { func (conn *obfs4Conn) clientHandshake(nodeID *ntor.NodeID, peerIdentityKey *ntor.PublicKey, sessionKey *ntor.Keypair) error {
@ -359,14 +363,14 @@ func (conn *obfs4Conn) clientHandshake(nodeID *ntor.NodeID, peerIdentityKey *nto
conn.receiveBuffer.Write(hsBuf[:n]) conn.receiveBuffer.Write(hsBuf[:n])
n, seed, err := hs.parseServerHandshake(conn.receiveBuffer.Bytes()) n, seed, err := hs.parseServerHandshake(conn.receiveBuffer.Bytes())
if err == ErrMarkNotFoundYet { if errors.Is(err, ErrMarkNotFoundYet) {
continue continue
} else if err != nil { } else if err != nil {
return err return err
} }
_ = conn.receiveBuffer.Next(n) _ = conn.receiveBuffer.Next(n)
// Use the derived key material to intialize the link crypto. // Use the derived key material to initialize the link crypto.
okm := ntor.Kdf(seed, framing.KeyLength*2) okm := ntor.Kdf(seed, framing.KeyLength*2)
conn.encoder = framing.NewEncoder(okm[:framing.KeyLength]) conn.encoder = framing.NewEncoder(okm[:framing.KeyLength])
conn.decoder = framing.NewDecoder(okm[framing.KeyLength:]) conn.decoder = framing.NewDecoder(okm[framing.KeyLength:])
@ -398,7 +402,7 @@ func (conn *obfs4Conn) serverHandshake(sf *obfs4ServerFactory, sessionKey *ntor.
conn.receiveBuffer.Write(hsBuf[:n]) conn.receiveBuffer.Write(hsBuf[:n])
seed, err := hs.parseClientHandshake(sf.replayFilter, conn.receiveBuffer.Bytes()) seed, err := hs.parseClientHandshake(sf.replayFilter, conn.receiveBuffer.Bytes())
if err == ErrMarkNotFoundYet { if errors.Is(err, ErrMarkNotFoundYet) {
continue continue
} else if err != nil { } else if err != nil {
return err return err
@ -406,10 +410,10 @@ func (conn *obfs4Conn) serverHandshake(sf *obfs4ServerFactory, sessionKey *ntor.
conn.receiveBuffer.Reset() conn.receiveBuffer.Reset()
if err := conn.Conn.SetDeadline(time.Time{}); err != nil { if err := conn.Conn.SetDeadline(time.Time{}); err != nil {
return nil return err
} }
// Use the derived key material to intialize the link crypto. // Use the derived key material to initialize the link crypto.
okm := ntor.Kdf(seed, framing.KeyLength*2) okm := ntor.Kdf(seed, framing.KeyLength*2)
conn.encoder = framing.NewEncoder(okm[framing.KeyLength:]) conn.encoder = framing.NewEncoder(okm[framing.KeyLength:])
conn.decoder = framing.NewDecoder(okm[:framing.KeyLength]) conn.decoder = framing.NewDecoder(okm[:framing.KeyLength])
@ -421,7 +425,7 @@ func (conn *obfs4Conn) serverHandshake(sf *obfs4ServerFactory, sessionKey *ntor.
// the length obfuscation, this makes the amount of data received from the // the length obfuscation, this makes the amount of data received from the
// server inconsistent with the length sent from the client. // server inconsistent with the length sent from the client.
// //
// Rebalance this by tweaking the client mimimum padding/server maximum // Rebalance this by tweaking the client minimum padding/server maximum
// padding, and sending the PRNG seed unpadded (As in, treat the PRNG seed // padding, and sending the PRNG seed unpadded (As in, treat the PRNG seed
// as part of the server response). See inlineSeedFrameLength in // as part of the server response). See inlineSeedFrameLength in
// handshake_ntor.go. // handshake_ntor.go.
@ -447,13 +451,14 @@ func (conn *obfs4Conn) serverHandshake(sf *obfs4ServerFactory, sessionKey *ntor.
return nil return nil
} }
func (conn *obfs4Conn) Read(b []byte) (n int, err error) { func (conn *obfs4Conn) Read(b []byte) (int, error) {
// If there is no payload from the previous Read() calls, consume data off // If there is no payload from the previous Read() calls, consume data off
// the network. Not all data received is guaranteed to be usable payload, // the network. Not all data received is guaranteed to be usable payload,
// so do this in a loop till data is present or an error occurs. // so do this in a loop till data is present or an error occurs.
var err error
for conn.receiveDecodedBuffer.Len() == 0 { for conn.receiveDecodedBuffer.Len() == 0 {
err = conn.readPackets() err = conn.readPackets()
if err == framing.ErrAgain { if errors.Is(err, framing.ErrAgain) {
// Don't proagate this back up the call stack if we happen to break // Don't proagate this back up the call stack if we happen to break
// out of the loop. // out of the loop.
err = nil err = nil
@ -465,6 +470,7 @@ func (conn *obfs4Conn) Read(b []byte) (n int, err error) {
// Even if err is set, attempt to do the read anyway so that all decoded // Even if err is set, attempt to do the read anyway so that all decoded
// data gets relayed before the connection is torn down. // data gets relayed before the connection is torn down.
var n int
if conn.receiveDecodedBuffer.Len() > 0 { if conn.receiveDecodedBuffer.Len() > 0 {
var berr error var berr error
n, berr = conn.receiveDecodedBuffer.Read(b) n, berr = conn.receiveDecodedBuffer.Read(b)
@ -475,28 +481,29 @@ func (conn *obfs4Conn) Read(b []byte) (n int, err error) {
} }
} }
return return n, err
} }
func (conn *obfs4Conn) Write(b []byte) (n int, err error) { func (conn *obfs4Conn) Write(b []byte) (int, error) {
chopBuf := bytes.NewBuffer(b) chopBuf := bytes.NewBuffer(b)
var payload [maxPacketPayloadLength]byte var (
var frameBuf bytes.Buffer payload [maxPacketPayloadLength]byte
frameBuf bytes.Buffer
n int
)
// Chop the pending data into payload frames. // Chop the pending data into payload frames.
for chopBuf.Len() > 0 { for chopBuf.Len() > 0 {
// Send maximum sized frames. // Send maximum sized frames.
rdLen := 0 rdLen, err := chopBuf.Read(payload[:])
rdLen, err = chopBuf.Read(payload[:])
if err != nil { if err != nil {
return 0, err return 0, err
} else if rdLen == 0 { } else if rdLen == 0 {
panic(fmt.Sprintf("BUG: Write(), chopping length was 0")) panic("BUG: Write(), chopping length was 0")
} }
n += rdLen n += rdLen
err = conn.makePacket(&frameBuf, packetTypePayload, payload[:rdLen], 0) if err = conn.makePacket(&frameBuf, packetTypePayload, payload[:rdLen], 0); err != nil {
if err != nil {
return 0, err return 0, err
} }
} }
@ -504,7 +511,7 @@ func (conn *obfs4Conn) Write(b []byte) (n int, err error) {
if conn.iatMode != iatParanoid { if conn.iatMode != iatParanoid {
// For non-paranoid IAT, pad once per burst. Paranoid IAT handles // For non-paranoid IAT, pad once per burst. Paranoid IAT handles
// things differently. // things differently.
if err = conn.padBurst(&frameBuf, conn.lenDist.Sample()); err != nil { if err := conn.padBurst(&frameBuf, conn.lenDist.Sample()); err != nil {
return 0, err return 0, err
} }
} }
@ -513,10 +520,11 @@ func (conn *obfs4Conn) Write(b []byte) (n int, err error) {
// because the frame encoder state is advanced, and the code doesn't keep // because the frame encoder state is advanced, and the code doesn't keep
// frameBuf around. In theory, write timeouts and whatnot could be // frameBuf around. In theory, write timeouts and whatnot could be
// supported if this wasn't the case, but that complicates the code. // supported if this wasn't the case, but that complicates the code.
if conn.iatMode != iatNone { var err error
if conn.iatMode != iatNone { //nolint:nestif
var iatFrame [framing.MaximumSegmentLength]byte var iatFrame [framing.MaximumSegmentLength]byte
for frameBuf.Len() > 0 { for frameBuf.Len() > 0 {
iatWrLen := 0 var iatWrLen int
switch conn.iatMode { switch conn.iatMode {
case iatEnabled: case iatEnabled:
@ -549,7 +557,7 @@ func (conn *obfs4Conn) Write(b []byte) (n int, err error) {
if err != nil { if err != nil {
return 0, err return 0, err
} else if iatWrLen == 0 { } else if iatWrLen == 0 {
panic(fmt.Sprintf("BUG: Write(), iat length was 0")) panic("BUG: Write(), iat length was 0")
} }
// Calculate the delay. The delay resolution is 100 usec, leading // Calculate the delay. The delay resolution is 100 usec, leading
@ -557,8 +565,7 @@ func (conn *obfs4Conn) Write(b []byte) (n int, err error) {
iatDelta := time.Duration(conn.iatDist.Sample() * 100) iatDelta := time.Duration(conn.iatDist.Sample() * 100)
// Write then sleep. // Write then sleep.
_, err = conn.Conn.Write(iatFrame[:iatWrLen]) if _, err = conn.Conn.Write(iatFrame[:iatWrLen]); err != nil {
if err != nil {
return 0, err return 0, err
} }
time.Sleep(iatDelta * time.Microsecond) time.Sleep(iatDelta * time.Microsecond)
@ -567,14 +574,14 @@ func (conn *obfs4Conn) Write(b []byte) (n int, err error) {
_, err = conn.Conn.Write(frameBuf.Bytes()) _, err = conn.Conn.Write(frameBuf.Bytes())
} }
return return n, err
} }
func (conn *obfs4Conn) SetDeadline(t time.Time) error { func (conn *obfs4Conn) SetDeadline(_ time.Time) error {
return syscall.ENOTSUP return syscall.ENOTSUP
} }
func (conn *obfs4Conn) SetWriteDeadline(t time.Time) error { func (conn *obfs4Conn) SetWriteDeadline(_ time.Time) error {
return syscall.ENOTSUP return syscall.ENOTSUP
} }
@ -594,13 +601,13 @@ func (conn *obfs4Conn) closeAfterDelay(sf *obfs4ServerFactory, startTime time.Ti
// Consume and discard data on this connection until the specified interval // Consume and discard data on this connection until the specified interval
// passes. // passes.
_, _ = io.Copy(ioutil.Discard, conn.Conn) _, _ = io.Copy(io.Discard, conn.Conn)
} }
func (conn *obfs4Conn) padBurst(burst *bytes.Buffer, toPadTo int) (err error) { func (conn *obfs4Conn) padBurst(burst *bytes.Buffer, toPadTo int) error {
tailLen := burst.Len() % framing.MaximumSegmentLength tailLen := burst.Len() % framing.MaximumSegmentLength
padLen := 0 var padLen int
if toPadTo >= tailLen { if toPadTo >= tailLen {
padLen = toPadTo - tailLen padLen = toPadTo - tailLen
} else { } else {
@ -608,32 +615,24 @@ func (conn *obfs4Conn) padBurst(burst *bytes.Buffer, toPadTo int) (err error) {
} }
if padLen > headerLength { if padLen > headerLength {
err = conn.makePacket(burst, packetTypePayload, []byte{}, if err := conn.makePacket(burst, packetTypePayload, []byte{}, uint16(padLen-headerLength)); err != nil {
uint16(padLen-headerLength)) return err
if err != nil {
return
} }
} else if padLen > 0 { } else if padLen > 0 {
err = conn.makePacket(burst, packetTypePayload, []byte{}, if err := conn.makePacket(burst, packetTypePayload, []byte{}, maxPacketPayloadLength); err != nil {
maxPacketPayloadLength) return err
if err != nil {
return
} }
err = conn.makePacket(burst, packetTypePayload, []byte{}, if err := conn.makePacket(burst, packetTypePayload, []byte{}, uint16(padLen)); err != nil {
uint16(padLen)) return err
if err != nil {
return
} }
} }
return return nil
}
func init() {
flag.BoolVar(&biasedDist, biasCmdArg, false, "Enable obfs4 using ScrambleSuit style table generation")
} }
var _ base.ClientFactory = (*obfs4ClientFactory)(nil) var (
var _ base.ServerFactory = (*obfs4ServerFactory)(nil) _ base.ClientFactory = (*obfs4ClientFactory)(nil)
var _ base.Transport = (*Transport)(nil) _ base.ServerFactory = (*obfs4ServerFactory)(nil)
var _ net.Conn = (*obfs4Conn)(nil) _ base.Transport = (*Transport)(nil)
_ net.Conn = (*obfs4Conn)(nil)
)

@ -30,6 +30,7 @@ package obfs4
import ( import (
"crypto/sha256" "crypto/sha256"
"encoding/binary" "encoding/binary"
"errors"
"fmt" "fmt"
"io" "io"
@ -52,7 +53,7 @@ const (
) )
// InvalidPacketLengthError is the error returned when decodePacket detects a // InvalidPacketLengthError is the error returned when decodePacket detects a
// invalid packet length/ // invalid packet length.
type InvalidPacketLengthError int type InvalidPacketLengthError int
func (e InvalidPacketLengthError) Error() string { func (e InvalidPacketLengthError) Error() string {
@ -85,7 +86,7 @@ func (conn *obfs4Conn) makePacket(w io.Writer, pktType uint8, data []byte, padLe
pkt[0] = pktType pkt[0] = pktType
binary.BigEndian.PutUint16(pkt[1:], uint16(len(data))) binary.BigEndian.PutUint16(pkt[1:], uint16(len(data)))
if len(data) > 0 { if len(data) > 0 {
copy(pkt[3:], data[:]) copy(pkt[3:], data)
} }
copy(pkt[3+len(data):], zeroPadBytes[:padLen]) copy(pkt[3+len(data):], zeroPadBytes[:padLen])
@ -108,23 +109,28 @@ func (conn *obfs4Conn) makePacket(w io.Writer, pktType uint8, data []byte, padLe
return nil return nil
} }
func (conn *obfs4Conn) readPackets() (err error) { func (conn *obfs4Conn) readPackets() error {
// Attempt to read off the network. // Attempt to read off the network.
rdLen, rdErr := conn.Conn.Read(conn.readBuffer) rdLen, rdErr := conn.Conn.Read(conn.readBuffer)
conn.receiveBuffer.Write(conn.readBuffer[:rdLen]) conn.receiveBuffer.Write(conn.readBuffer[:rdLen])
var decoded [framing.MaximumFramePayloadLength]byte var (
decoded [framing.MaximumFramePayloadLength]byte
err error
)
bufferLoop:
for conn.receiveBuffer.Len() > 0 { for conn.receiveBuffer.Len() > 0 {
// Decrypt an AEAD frame. // Decrypt an AEAD frame.
decLen := 0 var decLen int
decLen, err = conn.decoder.Decode(decoded[:], conn.receiveBuffer) decLen, err = conn.decoder.Decode(decoded[:], conn.receiveBuffer)
if err == framing.ErrAgain { switch {
break case errors.Is(err, framing.ErrAgain):
} else if err != nil { break bufferLoop
break case err != nil:
} else if decLen < packetOverhead { break bufferLoop
case decLen < packetOverhead:
err = InvalidPacketLengthError(decLen) err = InvalidPacketLengthError(decLen)
break break bufferLoop
} }
// Decode the packet. // Decode the packet.
@ -171,5 +177,5 @@ func (conn *obfs4Conn) readPackets() (err error) {
return rdErr return rdErr
} }
return return err
} }

@ -28,16 +28,17 @@
package obfs4 package obfs4
import ( import (
"bytes"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io/ioutil"
"os" "os"
"path" "path"
"strconv" "strconv"
"strings" "strings"
"git.torproject.org/pluggable-transports/goptlib.git" "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/goptlib"
"gitlab.com/yawning/obfs4.git/common/csrand" "gitlab.com/yawning/obfs4.git/common/csrand"
"gitlab.com/yawning/obfs4.git/common/drbg" "gitlab.com/yawning/obfs4.git/common/drbg"
"gitlab.com/yawning/obfs4.git/common/ntor" "gitlab.com/yawning/obfs4.git/common/ntor"
@ -81,7 +82,7 @@ func (cert *obfs4ServerCert) unpack() (*ntor.NodeID, *ntor.PublicKey) {
func serverCertFromString(encoded string) (*obfs4ServerCert, error) { func serverCertFromString(encoded string) (*obfs4ServerCert, error) {
decoded, err := base64.StdEncoding.DecodeString(encoded + certSuffix) decoded, err := base64.StdEncoding.DecodeString(encoded + certSuffix)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to decode cert: %s", err) return nil, fmt.Errorf("failed to decode cert: %w", err)
} }
if len(decoded) != certLength { if len(decoded) != certLength {
@ -93,7 +94,10 @@ func serverCertFromString(encoded string) (*obfs4ServerCert, error) {
func serverCertFromState(st *obfs4ServerState) *obfs4ServerCert { func serverCertFromState(st *obfs4ServerState) *obfs4ServerCert {
cert := new(obfs4ServerCert) cert := new(obfs4ServerCert)
cert.raw = append(st.nodeID.Bytes()[:], st.identityKey.Public().Bytes()[:]...)
cert.raw = bytes.Clone(st.nodeID.Bytes()[:])
cert.raw = append(cert.raw, st.identityKey.Public().Bytes()[:]...)
return cert return cert
} }
@ -121,15 +125,16 @@ func serverStateFromArgs(stateDir string, args *pt.Args) (*obfs4ServerState, err
// Either a private key, node id, and seed are ALL specified, or // Either a private key, node id, and seed are ALL specified, or
// they should be loaded from the state file. // they should be loaded from the state file.
if !privKeyOk && !nodeIDOk && !seedOk { switch {
case !privKeyOk && !nodeIDOk && !seedOk:
if err := jsonServerStateFromFile(stateDir, &js); err != nil { if err := jsonServerStateFromFile(stateDir, &js); err != nil {
return nil, err return nil, err
} }
} else if !privKeyOk { case !privKeyOk:
return nil, fmt.Errorf("missing argument '%s'", privateKeyArg) return nil, fmt.Errorf("missing argument '%s'", privateKeyArg)
} else if !nodeIDOk { case !nodeIDOk:
return nil, fmt.Errorf("missing argument '%s'", nodeIDArg) return nil, fmt.Errorf("missing argument '%s'", nodeIDArg)
} else if !seedOk { case !seedOk:
return nil, fmt.Errorf("missing argument '%s'", seedArg) return nil, fmt.Errorf("missing argument '%s'", seedArg)
} }
@ -177,7 +182,7 @@ func serverStateFromJSONServerState(stateDir string, js *jsonServerState) (*obfs
func jsonServerStateFromFile(stateDir string, js *jsonServerState) error { func jsonServerStateFromFile(stateDir string, js *jsonServerState) error {
fPath := path.Join(stateDir, stateFile) fPath := path.Join(stateDir, stateFile)
f, err := ioutil.ReadFile(fPath) f, err := os.ReadFile(fPath)
if err != nil { if err != nil {
if os.IsNotExist(err) { if os.IsNotExist(err) {
if err = newJSONServerState(stateDir, js); err == nil { if err = newJSONServerState(stateDir, js); err == nil {
@ -188,27 +193,29 @@ func jsonServerStateFromFile(stateDir string, js *jsonServerState) error {
} }
if err := json.Unmarshal(f, js); err != nil { if err := json.Unmarshal(f, js); err != nil {
return fmt.Errorf("failed to load statefile '%s': %s", fPath, err) return fmt.Errorf("failed to load statefile '%s': %w", fPath, err)
} }
return nil return nil
} }
func newJSONServerState(stateDir string, js *jsonServerState) (err error) { func newJSONServerState(stateDir string, js *jsonServerState) error {
// Generate everything a server needs, using the cryptographic PRNG. // Generate everything a server needs, using the cryptographic PRNG.
var st obfs4ServerState var st obfs4ServerState
rawID := make([]byte, ntor.NodeIDLength) rawID := make([]byte, ntor.NodeIDLength)
if err = csrand.Bytes(rawID); err != nil { if err := csrand.Bytes(rawID); err != nil {
return return err
} }
var err error
if st.nodeID, err = ntor.NewNodeID(rawID); err != nil { if st.nodeID, err = ntor.NewNodeID(rawID); err != nil {
return return err
} }
if st.identityKey, err = ntor.NewKeypair(false); err != nil { if st.identityKey, err = ntor.NewKeypair(false); err != nil {
return return err
} }
if st.drbgSeed, err = drbg.NewSeed(); err != nil { if st.drbgSeed, err = drbg.NewSeed(); err != nil {
return return err
} }
st.iatMode = iatNone st.iatMode = iatNone
@ -228,11 +235,7 @@ func writeJSONServerState(stateDir string, js *jsonServerState) error {
if encoded, err = json.Marshal(js); err != nil { if encoded, err = json.Marshal(js); err != nil {
return err return err
} }
if err = ioutil.WriteFile(path.Join(stateDir, stateFile), encoded, 0600); err != nil { return os.WriteFile(path.Join(stateDir, stateFile), encoded, 0o600)
return err
}
return nil
} }
func newBridgeFile(stateDir string, st *obfs4ServerState) error { func newBridgeFile(stateDir string, st *obfs4ServerState) error {
@ -252,9 +255,5 @@ func newBridgeFile(stateDir string, st *obfs4ServerState) error {
st.clientString()) st.clientString())
tmp := []byte(prefix + bridgeLine) tmp := []byte(prefix + bridgeLine)
if err := ioutil.WriteFile(path.Join(stateDir, bridgeFile), tmp, 0600); err != nil { return os.WriteFile(path.Join(stateDir, bridgeFile), tmp, 0o600)
return err
}
return nil
} }

@ -33,7 +33,8 @@ import (
"fmt" "fmt"
"net" "net"
"git.torproject.org/pluggable-transports/goptlib.git" "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/goptlib"
"gitlab.com/yawning/obfs4.git/transports/base" "gitlab.com/yawning/obfs4.git/transports/base"
) )
@ -58,8 +59,7 @@ func (t *Transport) ClientFactory(stateDir string) (base.ClientFactory, error) {
} }
// ServerFactory will one day return a new ssServerFactory instance. // ServerFactory will one day return a new ssServerFactory instance.
func (t *Transport) ServerFactory(stateDir string, args *pt.Args) (base.ServerFactory, error) { func (t *Transport) ServerFactory(_ string, _ *pt.Args) (base.ServerFactory, error) {
// TODO: Fill this in eventually, though obfs4 is better.
return nil, fmt.Errorf("server not supported") return nil, fmt.Errorf("server not supported")
} }
@ -72,11 +72,11 @@ func (cf *ssClientFactory) Transport() base.Transport {
return cf.transport return cf.transport
} }
func (cf *ssClientFactory) ParseArgs(args *pt.Args) (interface{}, error) { func (cf *ssClientFactory) ParseArgs(args *pt.Args) (any, error) {
return newClientArgs(args) return newClientArgs(args)
} }
func (cf *ssClientFactory) Dial(network, addr string, dialFn base.DialFunc, args interface{}) (net.Conn, error) { func (cf *ssClientFactory) Dial(network, addr string, dialFn base.DialFunc, args any) (net.Conn, error) {
// Validate args before opening outgoing connection. // Validate args before opening outgoing connection.
ca, ok := args.(*ssClientArgs) ca, ok := args.(*ssClientArgs)
if !ok { if !ok {
@ -95,5 +95,7 @@ func (cf *ssClientFactory) Dial(network, addr string, dialFn base.DialFunc, args
return conn, nil return conn, nil
} }
var _ base.ClientFactory = (*ssClientFactory)(nil) var (
var _ base.Transport = (*Transport)(nil) _ base.ClientFactory = (*ssClientFactory)(nil)
_ base.Transport = (*Transport)(nil)
)

@ -42,7 +42,9 @@ import (
"net" "net"
"time" "time"
"git.torproject.org/pluggable-transports/goptlib.git" "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/goptlib"
"golang.org/x/crypto/hkdf"
"gitlab.com/yawning/obfs4.git/common/csrand" "gitlab.com/yawning/obfs4.git/common/csrand"
"gitlab.com/yawning/obfs4.git/common/drbg" "gitlab.com/yawning/obfs4.git/common/drbg"
"gitlab.com/yawning/obfs4.git/common/probdist" "gitlab.com/yawning/obfs4.git/common/probdist"
@ -87,8 +89,11 @@ type ssClientArgs struct {
sessionKey *uniformdh.PrivateKey sessionKey *uniformdh.PrivateKey
} }
func newClientArgs(args *pt.Args) (ca *ssClientArgs, err error) { func newClientArgs(args *pt.Args) (*ssClientArgs, error) {
ca = &ssClientArgs{} var (
ca ssClientArgs
err error
)
if ca.kB, err = parsePasswordArg(args); err != nil { if ca.kB, err = parsePasswordArg(args); err != nil {
return nil, err return nil, err
} }
@ -99,7 +104,7 @@ func newClientArgs(args *pt.Args) (ca *ssClientArgs, err error) {
if ca.sessionKey, err = uniformdh.GenerateKey(csrand.Reader); err != nil { if ca.sessionKey, err = uniformdh.GenerateKey(csrand.Reader); err != nil {
return nil, err return nil, err
} }
return return &ca, nil
} }
func parsePasswordArg(args *pt.Args) (*ssSharedSecret, error) { func parsePasswordArg(args *pt.Args) (*ssSharedSecret, error) {
@ -112,7 +117,7 @@ func parsePasswordArg(args *pt.Args) (*ssSharedSecret, error) {
// shared secret (k_B) used for handshaking. // shared secret (k_B) used for handshaking.
decoded, err := base32.StdEncoding.DecodeString(str) decoded, err := base32.StdEncoding.DecodeString(str)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to decode password: %s", err) return nil, fmt.Errorf("failed to decode password: %w", err)
} }
if len(decoded) != sharedSecretLength { if len(decoded) != sharedSecretLength {
return nil, fmt.Errorf("password length %d is invalid", len(decoded)) return nil, fmt.Errorf("password length %d is invalid", len(decoded))
@ -131,7 +136,7 @@ func newCryptoState(aesKey []byte, ivPrefix []byte, macKey []byte) (*ssCryptoSta
// The ScrambleSuit CTR-AES256 link crypto uses an 8 byte prefix from the // The ScrambleSuit CTR-AES256 link crypto uses an 8 byte prefix from the
// KDF, and a 64 bit counter initialized to 1 as the IV. The initial value // KDF, and a 64 bit counter initialized to 1 as the IV. The initial value
// of the counter isn't documented in the spec either. // of the counter isn't documented in the spec either.
var initialCtr = []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01} initialCtr := []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01}
iv := make([]byte, 0, aes.BlockSize) iv := make([]byte, 0, aes.BlockSize)
iv = append(iv, ivPrefix...) iv = append(iv, ivPrefix...)
iv = append(iv, initialCtr...) iv = append(iv, initialCtr...)
@ -168,7 +173,8 @@ type ssRxState struct {
payloadLen int payloadLen int
} }
func (conn *ssConn) Read(b []byte) (n int, err error) { func (conn *ssConn) Read(b []byte) (int, error) {
var err error
// If the receive payload buffer is empty, consume data off the network. // If the receive payload buffer is empty, consume data off the network.
for conn.receiveDecodedBuffer.Len() == 0 { for conn.receiveDecodedBuffer.Len() == 0 {
if err = conn.readPackets(); err != nil { if err = conn.readPackets(); err != nil {
@ -177,17 +183,19 @@ func (conn *ssConn) Read(b []byte) (n int, err error) {
} }
// Service the read request using buffered payload. // Service the read request using buffered payload.
var n int
if conn.receiveDecodedBuffer.Len() > 0 { if conn.receiveDecodedBuffer.Len() > 0 {
n, _ = conn.receiveDecodedBuffer.Read(b) n, _ = conn.receiveDecodedBuffer.Read(b)
} }
return return n, err
} }
func (conn *ssConn) Write(b []byte) (n int, err error) { func (conn *ssConn) Write(b []byte) (int, error) {
var frameBuf bytes.Buffer var frameBuf bytes.Buffer
p := b p := b
toSend := len(p) toSend := len(p)
var n int
for toSend > 0 { for toSend > 0 {
// Send as much payload as will fit into each frame as possible. // Send as much payload as will fit into each frame as possible.
wrLen := len(p) wrLen := len(p)
@ -195,7 +203,7 @@ func (conn *ssConn) Write(b []byte) (n int, err error) {
wrLen = maxPayloadLength wrLen = maxPayloadLength
} }
payload := p[:wrLen] payload := p[:wrLen]
if err = conn.makePacket(&frameBuf, pktPayload, payload, 0); err != nil { if err := conn.makePayloadPacket(&frameBuf, payload, 0); err != nil {
return 0, err return 0, err
} }
@ -205,28 +213,28 @@ func (conn *ssConn) Write(b []byte) (n int, err error) {
} }
// Pad out the burst as appropriate. // Pad out the burst as appropriate.
if err = conn.padBurst(&frameBuf, conn.lenDist.Sample()); err != nil { if err := conn.padBurst(&frameBuf, conn.lenDist.Sample()); err != nil {
return 0, err return 0, err
} }
// Write and return. // Write and return.
_, err = conn.Conn.Write(frameBuf.Bytes()) _, err := conn.Conn.Write(frameBuf.Bytes())
return return n, err
} }
func (conn *ssConn) SetDeadline(t time.Time) error { func (conn *ssConn) SetDeadline(_ time.Time) error {
return ErrNotSupported return ErrNotSupported
} }
func (conn *ssConn) SetReadDeadline(t time.Time) error { func (conn *ssConn) SetReadDeadline(_ time.Time) error {
return ErrNotSupported return ErrNotSupported
} }
func (conn *ssConn) SetWriteDeadline(t time.Time) error { func (conn *ssConn) SetWriteDeadline(_ time.Time) error {
return ErrNotSupported return ErrNotSupported
} }
func (conn *ssConn) makePacket(w io.Writer, pktType byte, data []byte, padLen int) error { func (conn *ssConn) makePayloadPacket(w io.Writer, data []byte, padLen int) error {
payloadLen := len(data) payloadLen := len(data)
totalLen := payloadLen + padLen totalLen := payloadLen + padLen
if totalLen > maxPayloadLength { if totalLen > maxPayloadLength {
@ -238,7 +246,7 @@ func (conn *ssConn) makePacket(w io.Writer, pktType byte, data []byte, padLen in
pkt := make([]byte, pktHdrLength, pktHdrLength+payloadLen+padLen) pkt := make([]byte, pktHdrLength, pktHdrLength+payloadLen+padLen)
binary.BigEndian.PutUint16(pkt[0:], uint16(totalLen)) binary.BigEndian.PutUint16(pkt[0:], uint16(totalLen))
binary.BigEndian.PutUint16(pkt[2:], uint16(payloadLen)) binary.BigEndian.PutUint16(pkt[2:], uint16(payloadLen))
pkt[4] = pktType pkt[4] = pktPayload
pkt = append(pkt, data...) pkt = append(pkt, data...)
pkt = append(pkt, zeroPadBytes[:padLen]...) pkt = append(pkt, zeroPadBytes[:padLen]...)
@ -319,7 +327,7 @@ func (conn *ssConn) readPackets() error {
// Authenticate the packet, by comparing the received MAC with the one // Authenticate the packet, by comparing the received MAC with the one
// calculated over the ciphertext consumed off the network. // calculated over the ciphertext consumed off the network.
cmpMAC := conn.rxCrypto.mac.Sum(nil)[:macLength] cmpMAC := conn.rxCrypto.mac.Sum(nil)[:macLength]
if !hmac.Equal(cmpMAC, conn.receiveState.mac[:]) { if !hmac.Equal(cmpMAC, conn.receiveState.mac) {
return ErrInvalidPacket return ErrInvalidPacket
} }
@ -426,7 +434,7 @@ handshakeUDH:
// Attempt to process all the data seen so far as a response. // Attempt to process all the data seen so far as a response.
var seed []byte var seed []byte
n, seed, err = hs.parseServerHandshake(conn.receiveBuffer.Bytes()) n, seed, err = hs.parseServerHandshake(conn.receiveBuffer.Bytes())
if err == errMarkNotFoundYet { if errors.Is(err, errMarkNotFoundYet) {
// No response found yet, keep trying. // No response found yet, keep trying.
continue continue
} else if err != nil { } else if err != nil {
@ -444,7 +452,12 @@ handshakeUDH:
func (conn *ssConn) initCrypto(seed []byte) error { func (conn *ssConn) initCrypto(seed []byte) error {
// Use HKDF-SHA256 (Expand only, no Extract) to generate session keys from // Use HKDF-SHA256 (Expand only, no Extract) to generate session keys from
// initial keying material. // initial keying material.
okm := hkdfExpand(sha256.New, seed, nil, kdfSecretLength) rd := hkdf.Expand(sha256.New, seed, nil)
okm := make([]byte, kdfSecretLength)
if _, err := io.ReadFull(rd, okm); err != nil {
return err
}
var err error var err error
conn.txCrypto, err = newCryptoState(okm[0:32], okm[32:40], okm[80:112]) conn.txCrypto, err = newCryptoState(okm[0:32], okm[32:40], okm[80:112])
if err != nil { if err != nil {
@ -463,7 +476,7 @@ func (conn *ssConn) padBurst(burst *bytes.Buffer, sampleLen int) error {
// the ScrambleSuit MTU) is sampleLen bytes. // the ScrambleSuit MTU) is sampleLen bytes.
dataLen := burst.Len() % maxSegmentLength dataLen := burst.Len() % maxSegmentLength
padLen := 0 var padLen int
if sampleLen >= dataLen { if sampleLen >= dataLen {
padLen = sampleLen - dataLen padLen = sampleLen - dataLen
} else { } else {
@ -481,12 +494,12 @@ func (conn *ssConn) padBurst(burst *bytes.Buffer, sampleLen int) error {
if padLen > maxSegmentLength { if padLen > maxSegmentLength {
// Note: packetmorpher.py: getPadding is slightly wrong and only // Note: packetmorpher.py: getPadding is slightly wrong and only
// accounts for one of the two packet headers. // accounts for one of the two packet headers.
if err := conn.makePacket(burst, pktPayload, nil, 700-pktOverhead); err != nil { if err := conn.makePayloadPacket(burst, nil, 700-pktOverhead); err != nil {
return err return err
} }
return conn.makePacket(burst, pktPayload, nil, padLen-(700+2*pktOverhead)) return conn.makePayloadPacket(burst, nil, padLen-(700+2*pktOverhead))
} }
return conn.makePacket(burst, pktPayload, nil, padLen-pktOverhead) return conn.makePayloadPacket(burst, nil, padLen-pktOverhead)
} }
func newScrambleSuitClientConn(conn net.Conn, tStore *ssTicketStore, ca *ssClientArgs) (net.Conn, error) { func newScrambleSuitClientConn(conn net.Conn, tStore *ssTicketStore, ca *ssClientArgs) (net.Conn, error) {

@ -34,7 +34,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"hash" "hash"
"io/ioutil"
"net" "net"
"os" "os"
"path" "path"
@ -56,9 +55,7 @@ const (
ticketMaxPadLength = 1388 ticketMaxPadLength = 1388
) )
var ( var errInvalidTicket = errors.New("scramblesuit: invalid serialized ticket")
errInvalidTicket = errors.New("scramblesuit: invalid serialized ticket")
)
type ssTicketStore struct { type ssTicketStore struct {
sync.Mutex sync.Mutex
@ -129,7 +126,7 @@ func (s *ssTicketStore) getTicket(addr net.Addr) (*ssTicket, error) {
} }
// No ticket was found, that's fine. // No ticket was found, that's fine.
return nil, nil return nil, nil //nolint:nilnil
} }
func (s *ssTicketStore) serialize() error { func (s *ssTicketStore) serialize() error {
@ -146,7 +143,7 @@ func (s *ssTicketStore) serialize() error {
if err != nil { if err != nil {
return err return err
} }
return ioutil.WriteFile(s.filePath, jsonStr, 0600) return os.WriteFile(s.filePath, jsonStr, 0o600)
} }
func loadTicketStore(stateDir string) (*ssTicketStore, error) { func loadTicketStore(stateDir string) (*ssTicketStore, error) {
@ -154,7 +151,7 @@ func loadTicketStore(stateDir string) (*ssTicketStore, error) {
s := &ssTicketStore{filePath: fPath} s := &ssTicketStore{filePath: fPath}
s.store = make(map[string]*ssTicket) s.store = make(map[string]*ssTicket)
f, err := ioutil.ReadFile(fPath) f, err := os.ReadFile(fPath)
if err != nil { if err != nil {
// No ticket store is fine. // No ticket store is fine.
if os.IsNotExist(err) { if os.IsNotExist(err) {
@ -167,7 +164,7 @@ func loadTicketStore(stateDir string) (*ssTicketStore, error) {
encMap := make(map[string]*ssTicketJSON) encMap := make(map[string]*ssTicketJSON)
if err = json.Unmarshal(f, &encMap); err != nil { if err = json.Unmarshal(f, &encMap); err != nil {
return nil, fmt.Errorf("failed to load ticket store '%s': '%s'", fPath, err) return nil, fmt.Errorf("failed to load ticket store '%s': %w", fPath, err)
} }
for k, v := range encMap { for k, v := range encMap {
raw, err := base32.StdEncoding.DecodeString(v.KeyTicket) raw, err := base32.StdEncoding.DecodeString(v.KeyTicket)

@ -1,67 +0,0 @@
/*
* Copyright (c) 2015, Yawning Angel <yawning at schwanenlied dot me>
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* * Redistributions of source code must retain the above copyright notice,
* this list of conditions and the following disclaimer.
*
* * Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
* POSSIBILITY OF SUCH DAMAGE.
*/
package scramblesuit
import (
"crypto/hmac"
"hash"
)
func hkdfExpand(hashFn func() hash.Hash, prk []byte, info []byte, l int) []byte {
// Why, yes. golang.org/x/crypto/hkdf exists, and is a fine
// implementation of HKDF. However it does both the extract
// and expand, while ScrambleSuit only does extract, with no
// way to separate the two steps.
h := hmac.New(hashFn, prk)
digestSz := h.Size()
if l > 255*digestSz {
panic("hkdf: requested OKM length > 255*HashLen")
}
var t []byte
okm := make([]byte, 0, l)
toAppend := l
ctr := byte(1)
for toAppend > 0 {
h.Reset()
_, _ = h.Write(t)
_, _ = h.Write(info)
_, _ = h.Write([]byte{ctr})
t = h.Sum(nil)
ctr++
aLen := digestSz
if toAppend < digestSz {
aLen = toAppend
}
okm = append(okm, t[:aLen]...)
toAppend -= aLen
}
return okm
}

@ -41,8 +41,10 @@ import (
"gitlab.com/yawning/obfs4.git/transports/scramblesuit" "gitlab.com/yawning/obfs4.git/transports/scramblesuit"
) )
var transportMapLock sync.Mutex var (
var transportMap map[string]base.Transport = make(map[string]base.Transport) transportMapLock sync.Mutex
transportMap map[string]base.Transport = make(map[string]base.Transport)
)
// Register registers a transport protocol. // Register registers a transport protocol.
func Register(transport base.Transport) error { func Register(transport base.Transport) error {
@ -64,7 +66,7 @@ func Transports() []string {
transportMapLock.Lock() transportMapLock.Lock()
defer transportMapLock.Unlock() defer transportMapLock.Unlock()
var ret []string ret := make([]string, 0, len(transportMap))
for name := range transportMap { for name := range transportMap {
ret = append(ret, name) ret = append(ret, name)
} }

Loading…
Cancel
Save