diff --git a/go.mod b/go.mod index d7864f3..58c82fd 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/beevik/etree v1.3.0 github.com/gabriel-vasile/mimetype v1.4.3 github.com/pborman/getopt/v2 v2.1.0 - github.com/xmppo/go-xmpp v0.1.5-0.20240402113945-0ae62a33a21d + github.com/xmppo/go-xmpp v0.1.5-0.20240402134834-6e5d6e449eec salsa.debian.org/mdosch/xmppsrv v0.2.6 ) diff --git a/go.sum b/go.sum index 4568668..229fb77 100644 --- a/go.sum +++ b/go.sum @@ -27,6 +27,8 @@ github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5Cc github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/xmppo/go-xmpp v0.1.5-0.20240402113945-0ae62a33a21d h1:r5QzwgZpnLYvkmrhJbvLlDEl2w/EjNz4JbsHuktkOu8= github.com/xmppo/go-xmpp v0.1.5-0.20240402113945-0ae62a33a21d/go.mod h1:yyTnJMs6I6KUKv3BjXc4i3NU/iWBxY3yBGiUvUcW0Qg= +github.com/xmppo/go-xmpp v0.1.5-0.20240402134834-6e5d6e449eec h1:gJbmNIfHDSIZyarVQh42vV/iqhEzD4xyHOCqWG2sFfs= +github.com/xmppo/go-xmpp v0.1.5-0.20240402134834-6e5d6e449eec/go.mod h1:yyTnJMs6I6KUKv3BjXc4i3NU/iWBxY3yBGiUvUcW0Qg= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= diff --git a/main.go b/main.go index e26a142..0b20eda 100644 --- a/main.go +++ b/main.go @@ -6,7 +6,6 @@ package main import ( "bufio" - "context" "crypto/tls" "errors" "fmt" @@ -33,8 +32,7 @@ type configuration struct { alias string } -func closeAndExit(client *xmpp.Client, cancel context.CancelFunc, err error) { - cancel() +func closeAndExit(client *xmpp.Client, err error) { client.Close() if err != nil { log.Fatal(err) @@ -353,8 +351,7 @@ func main() { iqc := make(chan xmpp.IQ, defaultBufferSize) msgc := make(chan xmpp.Chat, defaultBufferSize) - ctx, cancel := context.WithCancel(context.Background()) - go rcvStanzas(client, iqc, msgc, ctx, cancel) + go rcvStanzas(client, iqc, msgc) for _, r := range getopt.Args() { var re recipientsType re.Jid = r @@ -372,7 +369,7 @@ func main() { for i, recipient := range recipients { validatedJid, err := MarshalJID(recipient.Jid) if err != nil { - closeAndExit(client, cancel, err) + closeAndExit(client, err) } recipients[i].Jid = validatedJid } @@ -381,52 +378,52 @@ func main() { case *flagOxGenPrivKeyX25519: validatedOwnJid, err := MarshalJID(user) if err != nil { - closeAndExit(client, cancel, err) + closeAndExit(client, err) } err = oxGenPrivKey(validatedOwnJid, client, iqc, *flagOxPassphrase, "x25519") if err != nil { - closeAndExit(client, cancel, err) + closeAndExit(client, err) } os.Exit(0) case *flagOxGenPrivKeyRSA: validatedOwnJid, err := MarshalJID(user) if err != nil { - closeAndExit(client, cancel, err) + closeAndExit(client, err) } err = oxGenPrivKey(validatedOwnJid, client, iqc, *flagOxPassphrase, "rsa") if err != nil { - closeAndExit(client, cancel, err) + closeAndExit(client, err) } os.Exit(0) case *flagOxImportPrivKey != "": validatedOwnJid, err := MarshalJID(user) if err != nil { - closeAndExit(client, cancel, err) + closeAndExit(client, err) } err = oxImportPrivKey(validatedOwnJid, *flagOxImportPrivKey, client, iqc) if err != nil { - closeAndExit(client, cancel, err) + closeAndExit(client, err) } os.Exit(0) case *flagOxDeleteNodes: validatedOwnJid, err := MarshalJID(user) if err != nil { - closeAndExit(client, cancel, err) + closeAndExit(client, err) } err = oxDeleteNodes(validatedOwnJid, client, iqc) if err != nil { - closeAndExit(client, cancel, err) + closeAndExit(client, err) } os.Exit(0) case *flagOx: validatedOwnJid, err := MarshalJID(user) if err != nil { - closeAndExit(client, cancel, err) + closeAndExit(client, err) } oxPrivKey, err = oxGetPrivKey(validatedOwnJid, *flagOxPassphrase) if err != nil { - closeAndExit(client, cancel, err) + closeAndExit(client, err) } } @@ -434,7 +431,7 @@ func main() { message, err = httpUpload(client, iqc, tlsConfig.ServerName, *flagHTTPUpload, timeout) if err != nil { - closeAndExit(client, cancel, err) + closeAndExit(client, err) } } @@ -444,7 +441,7 @@ func main() { // Check if the URI is valid. uri, err := validURI(message) if err != nil { - closeAndExit(client, cancel, err) + closeAndExit(client, err) } message = uri.String() } @@ -467,7 +464,7 @@ func main() { _, err = client.JoinMUCNoHistory(recipient.Jid, alias) } if err != nil { - closeAndExit(client, cancel, err) + closeAndExit(client, err) } } } @@ -479,7 +476,7 @@ func main() { // Send raw XML _, err = client.SendOrg(message) if err != nil { - closeAndExit(client, cancel, err) + closeAndExit(client, err) } case *flagInteractive: // Send in endless loop (for usage with e.g. "tail -f"). @@ -489,18 +486,13 @@ func main() { signal.Notify(c, os.Interrupt) go func() { for range c { - closeAndExit(client, cancel, nil) + closeAndExit(client, nil) } }() for { message, err = reader.ReadString('\n') - select { - case <-ctx.Done(): - return - default: - if err != nil { - closeAndExit(client, cancel, errors.New("failed to read from stdin")) - } + if err != nil { + closeAndExit(client, errors.New("failed to read from stdin")) } message = strings.TrimSuffix(message, "\n") @@ -524,7 +516,7 @@ func main() { } _, err = client.SendOrg(oxMessage) if err != nil { - closeAndExit(client, cancel, err) + closeAndExit(client, err) } default: _, err = client.Send(xmpp.Chat{ @@ -532,7 +524,7 @@ func main() { Type: msgType, Text: message, }) if err != nil { - closeAndExit(client, cancel, err) + closeAndExit(client, err) } } } @@ -544,7 +536,7 @@ func main() { signal.Notify(c, os.Interrupt) go func() { for range c { - closeAndExit(client, cancel, nil) + closeAndExit(client, nil) } }() for { @@ -651,7 +643,7 @@ func main() { } _, err = client.SendOrg(oxMessage) if err != nil { - closeAndExit(client, cancel, err) + closeAndExit(client, err) } default: _, err = client.Send(xmpp.Chat{ @@ -659,10 +651,10 @@ func main() { Type: msgType, Text: message, }) if err != nil { - closeAndExit(client, cancel, err) + closeAndExit(client, err) } } } } - closeAndExit(client, cancel, nil) + closeAndExit(client, nil) } diff --git a/stanzahandling.go b/stanzahandling.go index 386724c..f0df6b2 100644 --- a/stanzahandling.go +++ b/stanzahandling.go @@ -5,9 +5,9 @@ package main import ( - "context" "errors" "fmt" + "io" "log" "runtime" "time" @@ -44,35 +44,11 @@ func getIQ(id string, c chan xmpp.IQ, iqc chan xmpp.IQ) { } } -func rcvStanzas(client *xmpp.Client, iqc chan xmpp.IQ, msgc chan xmpp.Chat, ctx context.Context, cancel context.CancelFunc) { - var err error - var received interface{} - r := make(chan interface{}) +func rcvStanzas(client *xmpp.Client, iqc chan xmpp.IQ, msgc chan xmpp.Chat) { for { - select { - case <-ctx.Done(): - return - default: - go func() { - received, err = client.Recv() - r <- received - }() - select { - case <-ctx.Done(): - return - case <-r: - } - } - // Don't print errors if the program is getting shut down, - // as the errors might be triggered from trying to read from - // a closed connection. - select { - case <-ctx.Done(): - return - default: - if err != nil { - closeAndExit(client, cancel, err) - } + received, err := client.Recv() + if err != nil && err != io.EOF { + closeAndExit(client, err) } switch v := received.(type) { case xmpp.Chat: