From 5eff2d762322e61b60bd0df2c7c35b8b42653e41 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?CORNIERE=20R=C3=A9mi?= Date: Thu, 5 Dec 2019 18:12:00 +0100 Subject: [PATCH 1/5] Added callback to process errors after connection. Added tests and refactored a bit. --- .../xmpp_chat_client/xmpp_chat_client.go | 95 +++++ _examples/xmpp_component/xmpp_component.go | 6 +- _examples/xmpp_jukebox/xmpp_jukebox.go | 5 +- _examples/xmpp_oauth2/xmpp_oauth2.go | 6 +- _examples/xmpp_websocket/xmpp_websocket.go | 6 +- client.go | 23 +- client_test.go | 374 +++++++++++------ component.go | 20 +- component_test.go | 387 +++++++++--------- tcp_server_mock.go | 207 ++++++++++ 10 files changed, 795 insertions(+), 334 deletions(-) create mode 100644 _examples/xmpp_chat_client/xmpp_chat_client.go diff --git a/_examples/xmpp_chat_client/xmpp_chat_client.go b/_examples/xmpp_chat_client/xmpp_chat_client.go new file mode 100644 index 0000000..2b2d2e7 --- /dev/null +++ b/_examples/xmpp_chat_client/xmpp_chat_client.go @@ -0,0 +1,95 @@ +package main + +/* +xmpp_chat_client is a demo client that connect on an XMPP server to chat with other members +Note that this example sends to a very specific user. User logic is not implemented here. +*/ + +import ( + . "bufio" + "fmt" + "os" + + "gosrc.io/xmpp" + "gosrc.io/xmpp/stanza" +) + +const ( + currentUserAddress = "localhost:5222" + currentUserJid = "testuser@localhost" + currentUserPass = "testpass" + correspondantJid = "testuser2@localhost" +) + +func main() { + config := xmpp.Config{ + TransportConfiguration: xmpp.TransportConfiguration{ + Address: currentUserAddress, + }, + Jid: currentUserJid, + Credential: xmpp.Password(currentUserPass), + Insecure: true} + + var client *xmpp.Client + var err error + router := xmpp.NewRouter() + router.HandleFunc("message", handleMessage) + if client, err = xmpp.NewClient(config, router, errorHandler); err != nil { + fmt.Println("Error new client") + } + + // Connecting client and handling messages + // To use a stream manager, just write something like this instead : + //cm := xmpp.NewStreamManager(client, startMessaging) + //log.Fatal(cm.Run()) //=> this will lock the calling goroutine + + if err = client.Connect(); err != nil { + fmt.Printf("XMPP connection failed: %s", err) + return + } + startMessaging(client) + +} + +func startMessaging(client xmpp.Sender) { + reader := NewReader(os.Stdin) + textChan := make(chan string) + var text string + for { + fmt.Print("Enter text: ") + go readInput(reader, textChan) + select { + case <-killChan: + return + case text = <-textChan: + reply := stanza.Message{Attrs: stanza.Attrs{To: correspondantJid}, Body: text} + err := client.Send(reply) + if err != nil { + fmt.Printf("There was a problem sending the message : %v", reply) + return + } + } + } +} + +func readInput(reader *Reader, textChan chan string) { + text, _ := reader.ReadString('\n') + textChan <- text +} + +var killChan = make(chan struct{}) + +// If an error occurs, this is used +func errorHandler(err error) { + fmt.Printf("%v", err) + killChan <- struct{}{} +} + +func handleMessage(s xmpp.Sender, p stanza.Packet) { + msg, ok := p.(stanza.Message) + if !ok { + _, _ = fmt.Fprintf(os.Stdout, "Ignoring packet: %T\n", p) + return + } + _, _ = fmt.Fprintf(os.Stdout, "Body = %s - from = %s\n", msg.Body, msg.From) +} diff --git a/_examples/xmpp_component/xmpp_component.go b/_examples/xmpp_component/xmpp_component.go index 0452888..7f676cb 100644 --- a/_examples/xmpp_component/xmpp_component.go +++ b/_examples/xmpp_component/xmpp_component.go @@ -35,7 +35,7 @@ func main() { IQNamespaces("jabber:iq:version"). HandlerFunc(handleVersion) - component, err := xmpp.NewComponent(opts, router) + component, err := xmpp.NewComponent(opts, router, handleError) if err != nil { log.Fatalf("%+v", err) } @@ -47,6 +47,10 @@ func main() { log.Fatal(cm.Run()) } +func handleError(err error) { + fmt.Println(err.Error()) +} + func handleMessage(_ xmpp.Sender, p stanza.Packet) { msg, ok := p.(stanza.Message) if !ok { diff --git a/_examples/xmpp_jukebox/xmpp_jukebox.go b/_examples/xmpp_jukebox/xmpp_jukebox.go index 91f453c..ce7ebc9 100644 --- a/_examples/xmpp_jukebox/xmpp_jukebox.go +++ b/_examples/xmpp_jukebox/xmpp_jukebox.go @@ -53,7 +53,7 @@ func main() { handleIQ(s, p, player) }) - client, err := xmpp.NewClient(config, router) + client, err := xmpp.NewClient(config, router, errorHandler) if err != nil { log.Fatalf("%+v", err) } @@ -61,6 +61,9 @@ func main() { cm := xmpp.NewStreamManager(client, nil) log.Fatal(cm.Run()) } +func errorHandler(err error) { + fmt.Println(err.Error()) +} func handleMessage(s xmpp.Sender, p stanza.Packet, player *mpg123.Player) { msg, ok := p.(stanza.Message) diff --git a/_examples/xmpp_oauth2/xmpp_oauth2.go b/_examples/xmpp_oauth2/xmpp_oauth2.go index f322447..89b2639 100644 --- a/_examples/xmpp_oauth2/xmpp_oauth2.go +++ b/_examples/xmpp_oauth2/xmpp_oauth2.go @@ -28,7 +28,7 @@ func main() { router := xmpp.NewRouter() router.HandleFunc("message", handleMessage) - client, err := xmpp.NewClient(config, router) + client, err := xmpp.NewClient(config, router, errorHandler) if err != nil { log.Fatalf("%+v", err) } @@ -39,6 +39,10 @@ func main() { log.Fatal(cm.Run()) } +func errorHandler(err error) { + fmt.Println(err.Error()) +} + func handleMessage(s xmpp.Sender, p stanza.Packet) { msg, ok := p.(stanza.Message) if !ok { diff --git a/_examples/xmpp_websocket/xmpp_websocket.go b/_examples/xmpp_websocket/xmpp_websocket.go index 428a1d1..c8c0620 100644 --- a/_examples/xmpp_websocket/xmpp_websocket.go +++ b/_examples/xmpp_websocket/xmpp_websocket.go @@ -26,7 +26,7 @@ func main() { router := xmpp.NewRouter() router.HandleFunc("message", handleMessage) - client, err := xmpp.NewClient(config, router) + client, err := xmpp.NewClient(config, router, errorHandler) if err != nil { log.Fatalf("%+v", err) } @@ -37,6 +37,10 @@ func main() { log.Fatal(cm.Run()) } +func errorHandler(err error) { + fmt.Println(err.Error()) +} + func handleMessage(s xmpp.Sender, p stanza.Packet) { msg, ok := p.(stanza.Message) if !ok { diff --git a/client.go b/client.go index 14537db..cc152f3 100644 --- a/client.go +++ b/client.go @@ -98,6 +98,8 @@ type Client struct { router *Router // Track and broadcast connection state EventManager + // Handle errors from client execution + ErrorHandler func(error) } /* @@ -107,7 +109,7 @@ Setting up the client / Checking the parameters // NewClient generates a new XMPP client, based on Config passed as parameters. // If host is not specified, the DNS SRV should be used to find the host from the domainpart of the JID. // Default the port to 5222. -func NewClient(config Config, r *Router) (c *Client, err error) { +func NewClient(config Config, r *Router, errorHandler func(error)) (c *Client, err error) { // Parse JID if config.parsedJid, err = NewJid(config.Jid); err != nil { err = errors.New("missing jid") @@ -140,6 +142,7 @@ func NewClient(config Config, r *Router) (c *Client, err error) { c = new(Client) c.config = config c.router = r + c.ErrorHandler = errorHandler if c.config.ConnectTimeout == 0 { c.config.ConnectTimeout = 15 // 15 second as default @@ -185,13 +188,10 @@ func (c *Client) Resume(state SMState) error { // Start the keepalive go routine keepaliveQuit := make(chan struct{}) - go keepalive(c.transport, keepaliveQuit) + go keepalive(c, keepaliveQuit) // Start the receiver go routine state = c.Session.SMState - // Leaving this channel here for later. Not used atm. We should return this instead of an error because right - // now the returned error is lost in limbo. - errChan := make(chan error) - go c.recv(state, keepaliveQuit, errChan) + go c.recv(state, keepaliveQuit) // We're connected and can now receive and send messages. //fmt.Fprintf(client.conn, "%s%s", "chat", "Online") @@ -270,11 +270,11 @@ func (c *Client) sendWithWriter(writer io.Writer, packet []byte) error { // Go routines // Loop: Receive data from server -func (c *Client) recv(state SMState, keepaliveQuit chan<- struct{}, errChan chan<- error) { +func (c *Client) recv(state SMState, keepaliveQuit chan<- struct{}) { for { val, err := stanza.NextPacket(c.transport.GetDecoder()) if err != nil { - errChan <- err + c.ErrorHandler(err) close(keepaliveQuit) c.disconnected(state) return @@ -286,7 +286,7 @@ func (c *Client) recv(state SMState, keepaliveQuit chan<- struct{}, errChan chan c.router.route(c, val) close(keepaliveQuit) c.streamError(packet.Error.Local, packet.Text) - errChan <- errors.New("stream error: " + packet.Error.Local) + c.ErrorHandler(errors.New("stream error: " + packet.Error.Local)) return // Process Stream management nonzas case stanza.SMRequest: @@ -296,7 +296,7 @@ func (c *Client) recv(state SMState, keepaliveQuit chan<- struct{}, errChan chan }, H: state.Inbound} err = c.Send(answer) if err != nil { - errChan <- err + c.ErrorHandler(err) return } default: @@ -312,8 +312,9 @@ func (c *Client) recv(state SMState, keepaliveQuit chan<- struct{}, errChan chan // Loop: send whitespace keepalive to server // This is use to keep the connection open, but also to detect connection loss // and trigger proper client connection shutdown. -func keepalive(transport Transport, quit <-chan struct{}) { +func keepalive(c *Client, quit <-chan struct{}) { // TODO: Make keepalive interval configurable + transport := c.transport ticker := time.NewTicker(30 * time.Second) for { select { diff --git a/client_test.go b/client_test.go index 2636f29..15e104f 100644 --- a/client_test.go +++ b/client_test.go @@ -1,6 +1,7 @@ package xmpp import ( + "context" "encoding/xml" "errors" "fmt" @@ -14,15 +15,14 @@ import ( const ( // Default port is not standard XMPP port to avoid interfering // with local running XMPP server - testXMPPAddress = "localhost:15222" - - defaultTimeout = 2 * time.Second + testXMPPAddress = "localhost:15222" + testClientDomain = "localhost" ) func TestClient_Connect(t *testing.T) { // Setup Mock server mock := ServerMock{} - mock.Start(t, testXMPPAddress, handlerConnectSuccess) + mock.Start(t, testXMPPAddress, handlerClientConnectSuccess) // Test / Check result config := Config{ @@ -36,7 +36,7 @@ func TestClient_Connect(t *testing.T) { var client *Client var err error router := NewRouter() - if client, err = NewClient(config, router); err != nil { + if client, err = NewClient(config, router, clientDefaultErrorHandler); err != nil { t.Errorf("connect create XMPP client: %s", err) } @@ -64,7 +64,7 @@ func TestClient_NoInsecure(t *testing.T) { var client *Client var err error router := NewRouter() - if client, err = NewClient(config, router); err != nil { + if client, err = NewClient(config, router, clientDefaultErrorHandler); err != nil { t.Errorf("cannot create XMPP client: %s", err) } @@ -94,7 +94,7 @@ func TestClient_FeaturesTracking(t *testing.T) { var client *Client var err error router := NewRouter() - if client, err = NewClient(config, router); err != nil { + if client, err = NewClient(config, router, clientDefaultErrorHandler); err != nil { t.Errorf("cannot create XMPP client: %s", err) } @@ -109,7 +109,7 @@ func TestClient_FeaturesTracking(t *testing.T) { func TestClient_RFC3921Session(t *testing.T) { // Setup Mock server mock := ServerMock{} - mock.Start(t, testXMPPAddress, handlerConnectWithSession) + mock.Start(t, testXMPPAddress, handlerClientConnectWithSession) // Test / Check result config := Config{ @@ -124,7 +124,7 @@ func TestClient_RFC3921Session(t *testing.T) { var client *Client var err error router := NewRouter() - if client, err = NewClient(config, router); err != nil { + if client, err = NewClient(config, router, clientDefaultErrorHandler); err != nil { t.Errorf("connect create XMPP client: %s", err) } @@ -135,48 +135,254 @@ func TestClient_RFC3921Session(t *testing.T) { mock.Stop() } +// Testing sending an IQ to the mock server and reading its response. +func TestClient_SendIQ(t *testing.T) { + done := make(chan struct{}) + // Handler for Mock server + h := func(t *testing.T, c net.Conn) { + handlerClientConnectSuccess(t, c) + discardPresence(t, c) + respondToIQ(t, c) + done <- struct{}{} + } + client, mock := mockClientConnection(t, h, testClientIqPort) + + ctx, _ := context.WithTimeout(context.Background(), 30*time.Second) + iqReq := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeGet, From: "test1@localhost/mremond-mbp", To: defaultServerName, Id: defaultStreamID, Lang: "en"}) + disco := iqReq.DiscoInfo() + iqReq.Payload = disco + + // Handle a possible error + errChan := make(chan error) + errorHandler := func(err error) { + errChan <- err + } + client.ErrorHandler = errorHandler + res, err := client.SendIQ(ctx, iqReq) + if err != nil { + t.Errorf(err.Error()) + } + + select { + case <-res: // If the server responds with an IQ, we pass the test + case err := <-errChan: // If the server sends an error, or there is a connection error + t.Errorf(err.Error()) + case <-time.After(defaultChannelTimeout): // If we timeout + t.Errorf("Failed to receive response, to sent IQ, from mock server") + } + select { + case <-done: + mock.Stop() + case <-time.After(defaultChannelTimeout): + t.Errorf("The mock server failed to finish its job !") + } +} + +func TestClient_SendIQFail(t *testing.T) { + done := make(chan struct{}) + // Handler for Mock server + h := func(t *testing.T, c net.Conn) { + handlerClientConnectSuccess(t, c) + discardPresence(t, c) + respondToIQ(t, c) + done <- struct{}{} + } + client, mock := mockClientConnection(t, h, testClientIqFailPort) + + //================== + // Create an IQ to send + ctx, _ := context.WithTimeout(context.Background(), 30*time.Second) + iqReq := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeGet, From: "test1@localhost/mremond-mbp", To: defaultServerName, Id: defaultStreamID, Lang: "en"}) + disco := iqReq.DiscoInfo() + iqReq.Payload = disco + // Removing the id to make the stanza invalid. The IQ constructor makes a random one if none is specified + // so we need to overwrite it. + iqReq.Id = "" + + // Handle a possible error + errChan := make(chan error) + errorHandler := func(err error) { + errChan <- err + } + client.ErrorHandler = errorHandler + res, _ := client.SendIQ(ctx, iqReq) + + // Test + select { + case <-res: // If the server responds with an IQ + t.Errorf("Server should not respond with an IQ since the request is expected to be invalid !") + case <-errChan: // If the server sends an error, the test passes + case <-time.After(defaultChannelTimeout): // If we timeout + t.Errorf("Failed to receive response, to sent IQ, from mock server") + } + select { + case <-done: + mock.Stop() + case <-time.After(defaultChannelTimeout): + t.Errorf("The mock server failed to finish its job !") + } +} + +func TestClient_SendRaw(t *testing.T) { + done := make(chan struct{}) + // Handler for Mock server + h := func(t *testing.T, c net.Conn) { + handlerClientConnectSuccess(t, c) + discardPresence(t, c) + respondToIQ(t, c) + done <- struct{}{} + } + type testCase struct { + req string + shouldErr bool + port int + } + testRequests := make(map[string]testCase) + // Sending a correct IQ of type get. Not supposed to err + testRequests["Correct IQ"] = testCase{ + req: ``, + shouldErr: false, + port: testClientRawPort + 100, + } + // Sending an IQ with a missing ID. Should err + testRequests["IQ with missing ID"] = testCase{ + req: ``, + shouldErr: true, + port: testClientRawPort, + } + + // A handler for the client. + // In the failing test, the server returns a stream error, which triggers this handler, client side. + errChan := make(chan error) + errHandler := func(err error) { + errChan <- err + } + + // Tests for all the IQs + for name, tcase := range testRequests { + t.Run(name, func(st *testing.T) { + //Connecting to a mock server, initialized with given port and handler function + c, m := mockClientConnection(t, h, tcase.port) + c.ErrorHandler = errHandler + // Sending raw xml from test case + err := c.SendRaw(tcase.req) + if err != nil { + t.Errorf("Error sending Raw string") + } + // Just wait a little so the message has time to arrive + select { + // We don't use the default "long" timeout here because waiting it out means passing the test. + case <-time.After(100 * time.Millisecond): + case err = <-errChan: + if err == nil && tcase.shouldErr { + t.Errorf("Failed to get closing stream err") + } else if err != nil && !tcase.shouldErr { + t.Errorf("This test is not supposed to err !") + } + } + c.transport.Close() + select { + case <-done: + m.Stop() + case <-time.After(defaultChannelTimeout): + t.Errorf("The mock server failed to finish its job !") + } + }) + } +} + +func TestClient_Disconnect(t *testing.T) { + c, m := mockClientConnection(t, handlerClientConnectSuccess, testClientBasePort) + err := c.transport.Ping() + if err != nil { + t.Errorf("Could not ping but not disconnected yet") + } + c.Disconnect() + err = c.transport.Ping() + if err == nil { + t.Errorf("Did not disconnect properly") + } + m.Stop() +} + +func TestClient_DisconnectStreamManager(t *testing.T) { + // Init mock server + // Setup Mock server + mock := ServerMock{} + mock.Start(t, testXMPPAddress, handlerAbortTLS) + + // Test / Check result + config := Config{ + TransportConfiguration: TransportConfiguration{ + Address: testXMPPAddress, + }, + Jid: "test@localhost", + Credential: Password("test"), + } + + var client *Client + var err error + router := NewRouter() + if client, err = NewClient(config, router, clientDefaultErrorHandler); err != nil { + t.Errorf("cannot create XMPP client: %s", err) + } + + sman := NewStreamManager(client, nil) + errChan := make(chan error) + runSMan := func(errChan chan error) { + errChan <- sman.Run() + } + + go runSMan(errChan) + select { + case <-errChan: + case <-time.After(defaultChannelTimeout): + // When insecure is not allowed: + t.Errorf("should fail as insecure connection is not allowed and server does not support TLS") + } + mock.Stop() +} + //============================================================================= // Basic XMPP Server Mock Handlers. -const serverStreamOpen = "" - // Test connection with a basic straightforward workflow -func handlerConnectSuccess(t *testing.T, c net.Conn) { +func handlerClientConnectSuccess(t *testing.T, c net.Conn) { decoder := xml.NewDecoder(c) - checkOpenStream(t, c, decoder) + checkClientOpenStream(t, c, decoder) sendStreamFeatures(t, c, decoder) // Send initial features readAuth(t, decoder) fmt.Fprintln(c, "") - checkOpenStream(t, c, decoder) // Reset stream - sendBindFeature(t, c, decoder) // Send post auth features + checkClientOpenStream(t, c, decoder) // Reset stream + sendBindFeature(t, c, decoder) // Send post auth features bind(t, c, decoder) } // We expect client will abort on TLS func handlerAbortTLS(t *testing.T, c net.Conn) { decoder := xml.NewDecoder(c) - checkOpenStream(t, c, decoder) + checkClientOpenStream(t, c, decoder) sendStreamFeatures(t, c, decoder) // Send initial features } // Test connection with mandatory session (RFC-3921) -func handlerConnectWithSession(t *testing.T, c net.Conn) { +func handlerClientConnectWithSession(t *testing.T, c net.Conn) { decoder := xml.NewDecoder(c) - checkOpenStream(t, c, decoder) + checkClientOpenStream(t, c, decoder) sendStreamFeatures(t, c, decoder) // Send initial features readAuth(t, decoder) fmt.Fprintln(c, "") - checkOpenStream(t, c, decoder) // Reset stream - sendRFC3921Feature(t, c, decoder) // Send post auth features + checkClientOpenStream(t, c, decoder) // Reset stream + sendRFC3921Feature(t, c, decoder) // Send post auth features bind(t, c, decoder) session(t, c, decoder) } -func checkOpenStream(t *testing.T, c net.Conn, decoder *xml.Decoder) { +func checkClientOpenStream(t *testing.T, c net.Conn, decoder *xml.Decoder) { c.SetDeadline(time.Now().Add(defaultTimeout)) defer c.SetDeadline(time.Time{}) @@ -202,105 +408,35 @@ func checkOpenStream(t *testing.T, c net.Conn, decoder *xml.Decoder) { } } -func sendStreamFeatures(t *testing.T, c net.Conn, _ *xml.Decoder) { - // This is a basic server, supporting only 1 stream feature: SASL Plain Auth - features := ` - - PLAIN - -` - if _, err := fmt.Fprintln(c, features); err != nil { - t.Errorf("cannot send stream feature: %s", err) +func mockClientConnection(t *testing.T, serverHandler func(*testing.T, net.Conn), port int) (*Client, ServerMock) { + mock := ServerMock{} + testServerAddress := fmt.Sprintf("%s:%d", testClientDomain, port) + + mock.Start(t, testServerAddress, serverHandler) + + config := Config{ + TransportConfiguration: TransportConfiguration{ + Address: testServerAddress, + }, + Jid: "test@localhost", + Credential: Password("test"), + Insecure: true} + + var client *Client + var err error + router := NewRouter() + if client, err = NewClient(config, router, clientDefaultErrorHandler); err != nil { + t.Errorf("connect create XMPP client: %s", err) } + + if err = client.Connect(); err != nil { + t.Errorf("XMPP connection failed: %s", err) + } + + return client, mock } -// TODO return err in case of error reading the auth params -func readAuth(t *testing.T, decoder *xml.Decoder) string { - se, err := stanza.NextStart(decoder) - if err != nil { - t.Errorf("cannot read auth: %s", err) - return "" - } - - var nv interface{} - nv = &stanza.SASLAuth{} - // Decode element into pointer storage - if err = decoder.DecodeElement(nv, &se); err != nil { - t.Errorf("cannot decode auth: %s", err) - return "" - } - - switch v := nv.(type) { - case *stanza.SASLAuth: - return v.Value - } - return "" -} - -func sendBindFeature(t *testing.T, c net.Conn, _ *xml.Decoder) { - // This is a basic server, supporting only 1 stream feature after auth: resource binding - features := ` - -` - if _, err := fmt.Fprintln(c, features); err != nil { - t.Errorf("cannot send stream feature: %s", err) - } -} - -func sendRFC3921Feature(t *testing.T, c net.Conn, _ *xml.Decoder) { - // This is a basic server, supporting only 2 features after auth: resource & session binding - features := ` - - -` - if _, err := fmt.Fprintln(c, features); err != nil { - t.Errorf("cannot send stream feature: %s", err) - } -} - -func bind(t *testing.T, c net.Conn, decoder *xml.Decoder) { - se, err := stanza.NextStart(decoder) - if err != nil { - t.Errorf("cannot read bind: %s", err) - return - } - - iq := &stanza.IQ{} - // Decode element into pointer storage - if err = decoder.DecodeElement(&iq, &se); err != nil { - t.Errorf("cannot decode bind iq: %s", err) - return - } - - // TODO Check all elements - switch iq.Payload.(type) { - case *stanza.Bind: - result := ` - - %s - -` - fmt.Fprintf(c, result, iq.Id, "test@localhost/test") // TODO use real JID - } -} - -func session(t *testing.T, c net.Conn, decoder *xml.Decoder) { - se, err := stanza.NextStart(decoder) - if err != nil { - t.Errorf("cannot read session: %s", err) - return - } - - iq := &stanza.IQ{} - // Decode element into pointer storage - if err = decoder.DecodeElement(&iq, &se); err != nil { - t.Errorf("cannot decode session iq: %s", err) - return - } - - switch iq.Payload.(type) { - case *stanza.StreamSession: - result := `` - fmt.Fprintf(c, result, iq.Id) - } +// This really should not be used as is. +// It's just meant to be a placeholder when error handling is not needed at this level +func clientDefaultErrorHandler(err error) { } diff --git a/component.go b/component.go index 471f1db..2f61aef 100644 --- a/component.go +++ b/component.go @@ -48,11 +48,12 @@ type Component struct { transport Transport // read / write - socketProxy io.ReadWriter // TODO + socketProxy io.ReadWriter // TODO + ErrorHandler func(error) } -func NewComponent(opts ComponentOptions, r *Router) (*Component, error) { - c := Component{ComponentOptions: opts, router: r} +func NewComponent(opts ComponentOptions, r *Router, errorHandler func(error)) (*Component, error) { + c := Component{ComponentOptions: opts, router: r, ErrorHandler: errorHandler} return &c, nil } @@ -104,11 +105,8 @@ func (c *Component) Resume(sm SMState) error { case stanza.Handshake: // Start the receiver go routine c.updateState(StateSessionEstablished) - // Leaving this channel here for later. Not used atm. We should return this instead of an error because right - // now the returned error is lost in limbo. - errChan := make(chan error) - go c.recv(errChan) // Sends to errChan - return err // Should be empty at this point + go c.recv() + return err // Should be empty at this point default: c.updateState(StatePermanentError) return NewConnError(errors.New("expecting handshake result, got "+v.Name()), true) @@ -128,13 +126,13 @@ func (c *Component) SetHandler(handler EventHandler) { } // Receiver Go routine receiver -func (c *Component) recv(errChan chan<- error) { +func (c *Component) recv() { for { val, err := stanza.NextPacket(c.transport.GetDecoder()) if err != nil { c.updateState(StateDisconnected) - errChan <- err + c.ErrorHandler(err) return } // Handle stream errors @@ -142,7 +140,7 @@ func (c *Component) recv(errChan chan<- error) { case stanza.StreamError: c.router.route(c, val) c.streamError(p.Error.Local, p.Text) - errChan <- errors.New("stream error: " + p.Error.Local) + c.ErrorHandler(errors.New("stream error: " + p.Error.Local)) return } c.router.route(c, val) diff --git a/component_test.go b/component_test.go index 4e115f0..48963a5 100644 --- a/component_test.go +++ b/component_test.go @@ -5,6 +5,7 @@ import ( "encoding/xml" "errors" "fmt" + "github.com/google/uuid" "gosrc.io/xmpp/stanza" "net" "strings" @@ -15,19 +16,7 @@ import ( // Tests are ran in parallel, so each test creating a server must use a different port so we do not get any // conflict. Using iota for this should do the trick. const ( - testComponentDomain = "localhost" - defaultServerName = "testServer" - defaultStreamID = "91bd0bba-012f-4d92-bb17-5fc41e6fe545" - defaultComponentName = "Test Component" - - // Default port is not standard XMPP port to avoid interfering - // with local running XMPP server - testHandshakePort = iota + 15222 - testDecoderPort - testSendIqPort - testSendRawPort - testDisconnectPort - testSManDisconnectPort + defaultChannelTimeout = 5 * time.Second ) func TestHandshake(t *testing.T) { @@ -48,16 +37,14 @@ func TestHandshake(t *testing.T) { // Tests connection process with a handshake exchange // Tests multiple session IDs. All connections should generate a unique stream ID -func TestGenerateHandshake(t *testing.T) { +func TestGenerateHandshakeId(t *testing.T) { // Using this array with a channel to make a queue of values to test // These are stream IDs that will be used to test the connection process, mixing them with the "secret" to generate // some handshake value - var uuidsArray = [5]string{ - "cc9b3249-9582-4780-825f-4311b42f9b0e", - "bba8be3c-d98e-4e26-b9bb-9ed34578a503", - "dae72822-80e8-496b-b763-ab685f53a188", - "a45d6c06-de49-4bb0-935b-1a2201b71028", - "7dc6924f-0eca-4237-9898-18654b8d891e", + var uuidsArray = [5]string{} + for i := 1; i < len(uuidsArray); i++ { + id, _ := uuid.NewRandom() + uuidsArray[i] = id.String() } // Channel to pass stream IDs as a queue @@ -95,7 +82,7 @@ func TestGenerateHandshake(t *testing.T) { Type: "service", } router := NewRouter() - c, err := NewComponent(opts, router) + c, err := NewComponent(opts, router, componentDefaultErrorHandler) if err != nil { t.Errorf("%+v", err) } @@ -126,7 +113,7 @@ func TestStreamManager(t *testing.T) { // The decoder is expected to be built after a valid connection // Based on the xmpp_component example. func TestDecoder(t *testing.T) { - c, _ := mockConnection(t, testDecoderPort, handlerForComponentHandshakeDefaultID) + c, _ := mockComponentConnection(t, testDecoderPort, handlerForComponentHandshakeDefaultID) if c.transport.GetDecoder() == nil { t.Errorf("Failed to initialize decoder. Decoder is nil.") } @@ -134,39 +121,103 @@ func TestDecoder(t *testing.T) { // Tests sending an IQ to the server, and getting the response func TestSendIq(t *testing.T) { + done := make(chan struct{}) + h := func(t *testing.T, c net.Conn) { + handlerForComponentIQSend(t, c) + done <- struct{}{} + } + //Connecting to a mock server, initialized with given port and handler function - c, m := mockConnection(t, testSendIqPort, handlerForComponentIQSend) + c, m := mockComponentConnection(t, testSendIqPort, h) ctx, _ := context.WithTimeout(context.Background(), 30*time.Second) iqReq := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeGet, From: "test1@localhost/mremond-mbp", To: defaultServerName, Id: defaultStreamID, Lang: "en"}) disco := iqReq.DiscoInfo() iqReq.Payload = disco + // Handle a possible error + errChan := make(chan error) + errorHandler := func(err error) { + errChan <- err + } + c.ErrorHandler = errorHandler + var res chan stanza.IQ res, _ = c.SendIQ(ctx, iqReq) select { case <-res: - case <-time.After(100 * time.Millisecond): + case err := <-errChan: + t.Errorf(err.Error()) + case <-time.After(defaultChannelTimeout): t.Errorf("Failed to receive response, to sent IQ, from mock server") } - m.Stop() + select { + case <-done: + m.Stop() + case <-time.After(defaultChannelTimeout): + t.Errorf("The mock server failed to finish its job !") + } +} + +// Checking that error handling is done properly client side when an invalid IQ is sent and the server responds in kind. +func TestSendIqFail(t *testing.T) { + done := make(chan struct{}) + h := func(t *testing.T, c net.Conn) { + handlerForComponentIQSend(t, c) + done <- struct{}{} + } + //Connecting to a mock server, initialized with given port and handler function + c, m := mockComponentConnection(t, testSendIqFailPort, h) + + ctx, _ := context.WithTimeout(context.Background(), 30*time.Second) + iqReq := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeGet, From: "test1@localhost/mremond-mbp", To: defaultServerName, Id: defaultStreamID, Lang: "en"}) + + // Removing the id to make the stanza invalid. The IQ constructor makes a random one if none is specified + // so we need to overwrite it. + iqReq.Id = "" + disco := iqReq.DiscoInfo() + iqReq.Payload = disco + + errChan := make(chan error) + errorHandler := func(err error) { + errChan <- err + } + c.ErrorHandler = errorHandler + + var res chan stanza.IQ + res, _ = c.SendIQ(ctx, iqReq) + + select { + case r := <-res: // Do we get an IQ response from the server ? + t.Errorf("We should not be getting an IQ response here : this should fail !") + fmt.Println(r) + case <-errChan: // Do we get a stream error from the server ? + // If we get an error from the server, the test passes. + case <-time.After(defaultChannelTimeout): // Timeout ? + t.Errorf("Failed to receive response, to sent IQ, from mock server") + } + + select { + case <-done: + m.Stop() + case <-time.After(defaultChannelTimeout): + t.Errorf("The mock server failed to finish its job !") + } } // Tests sending raw xml to the mock server. -// TODO : check the server response client side ? // Right now, the server response is not checked and an err is passed in a channel if the test is supposed to err. // In this test, we use IQs func TestSendRaw(t *testing.T) { - // Error channel for the handler - errChan := make(chan error) + done := make(chan struct{}) // Handler for the mock server h := func(t *testing.T, c net.Conn) { // Completes the connection by exchanging handshakes handlerForComponentHandshakeDefaultID(t, c) - receiveRawIq(t, c, errChan) - return + receiveIq(c, xml.NewDecoder(c)) + done <- struct{}{} } type testCase struct { @@ -185,12 +236,19 @@ func TestSendRaw(t *testing.T) { shouldErr: true, } + // A handler for the component. + // In the failing test, the server returns a stream error, which triggers this handler, component side. + errChan := make(chan error) + errHandler := func(err error) { + errChan <- err + } + // Tests for all the IQs for name, tcase := range testRequests { t.Run(name, func(st *testing.T) { //Connecting to a mock server, initialized with given port and handler function - c, m := mockConnection(t, testSendRawPort, h) - + c, m := mockComponentConnection(t, testSendRawPort, h) + c.ErrorHandler = errHandler // Sending raw xml from test case err := c.SendRaw(tcase.req) if err != nil { @@ -198,21 +256,29 @@ func TestSendRaw(t *testing.T) { } // Just wait a little so the message has time to arrive select { - case <-time.After(100 * time.Millisecond): + // We don't use the default "long" timeout here because waiting it out means passing the test. + case <-time.After(200 * time.Millisecond): case err = <-errChan: if err == nil && tcase.shouldErr { t.Errorf("Failed to get closing stream err") + } else if err != nil && !tcase.shouldErr { + t.Errorf("This test is not supposed to err ! => %s", err.Error()) } } c.transport.Close() - m.Stop() + select { + case <-done: + m.Stop() + case <-time.After(defaultChannelTimeout): + t.Errorf("The mock server failed to finish its job !") + } }) } } // Tests the Disconnect method for Components func TestDisconnect(t *testing.T) { - c, m := mockConnection(t, testDisconnectPort, handlerForComponentHandshakeDefaultID) + c, m := mockComponentConnection(t, testDisconnectPort, handlerForComponentHandshakeDefaultID) err := c.transport.Ping() if err != nil { t.Errorf("Could not ping but not disconnected yet") @@ -257,14 +323,97 @@ func TestStreamManagerDisconnect(t *testing.T) { //============================================================================= // Basic XMPP Server Mock Handlers. -// Performs a Component connection with a handshake. It uses a default ID defined in this file as a constant. -// Used in the mock server as a Handler -func handlerForComponentHandshakeDefaultID(t *testing.T, c net.Conn) { - decoder := xml.NewDecoder(c) - checkOpenStreamHandshakeDefaultID(t, c, decoder) - readHandshakeComponent(t, decoder) - fmt.Fprintln(c, "") // That's all the server needs to return (see xep-0114) - return + +//=============================== +// Init mock server and connection +// Creating a mock server and connecting a Component to it. Initialized with given port and handler function +// The Component and mock are both returned +func mockComponentConnection(t *testing.T, port int, handler func(t *testing.T, c net.Conn)) (*Component, *ServerMock) { + // Init mock server + testComponentAddress := fmt.Sprintf("%s:%d", testComponentDomain, port) + mock := ServerMock{} + mock.Start(t, testComponentAddress, handler) + + //================================== + // Create Component to connect to it + c := makeBasicComponent(defaultComponentName, testComponentAddress, t) + + //======================================== + // Connect the new Component to the server + err := c.Connect() + if err != nil { + t.Errorf("%+v", err) + } + + return c, &mock +} + +func makeBasicComponent(name string, mockServerAddr string, t *testing.T) *Component { + opts := ComponentOptions{ + TransportConfiguration: TransportConfiguration{ + Address: mockServerAddr, + Domain: "localhost", + }, + Domain: testComponentDomain, + Secret: "mypass", + Name: name, + Category: "gateway", + Type: "service", + } + router := NewRouter() + c, err := NewComponent(opts, router, componentDefaultErrorHandler) + if err != nil { + t.Errorf("%+v", err) + } + c.transport, err = NewComponentTransport(c.ComponentOptions.TransportConfiguration) + if err != nil { + t.Errorf("%+v", err) + } + return c +} + +// This really should not be used as is. +// It's just meant to be a placeholder when error handling is not needed at this level +func componentDefaultErrorHandler(err error) { + +} + +// Sends IQ response to Component request. +// No parsing of the request here. We just check that it's valid, and send the default response. +func handlerForComponentIQSend(t *testing.T, c net.Conn) { + // Completes the connection by exchanging handshakes + handlerForComponentHandshakeDefaultID(t, c) + respondToIQ(t, c) +} + +// Used for ID and handshake related tests +func checkOpenStreamHandshakeID(t *testing.T, c net.Conn, decoder *xml.Decoder, streamID string) { + c.SetDeadline(time.Now().Add(defaultTimeout)) + defer c.SetDeadline(time.Time{}) + + for { // TODO clean up. That for loop is not elegant and I prefer bounded recursion. + token, err := decoder.Token() + if err != nil { + t.Errorf("cannot read next token: %s", err) + } + + switch elem := token.(type) { + // Wait for first startElement + case xml.StartElement: + if elem.Name.Space != stanza.NSStream || elem.Name.Local != "stream" { + err = errors.New("xmpp: expected but got <" + elem.Name.Local + "> in " + elem.Name.Space) + return + } + if _, err := fmt.Fprintf(c, serverStreamOpen, "localhost", streamID, stanza.NSComponent, stanza.NSStream); err != nil { + t.Errorf("cannot write server stream open: %s", err) + } + return + } + } +} + +func checkOpenStreamHandshakeDefaultID(t *testing.T, c net.Conn, decoder *xml.Decoder) { + checkOpenStreamHandshakeID(t, c, decoder, defaultStreamID) } // Performs a Component connection with a handshake. It uses a default ID defined in this file as a constant. @@ -303,152 +452,12 @@ func readHandshakeComponent(t *testing.T, decoder *xml.Decoder) { } } -func checkOpenStreamHandshakeDefaultID(t *testing.T, c net.Conn, decoder *xml.Decoder) { - checkOpenStreamHandshakeID(t, c, decoder, defaultStreamID) -} - -// Used for ID and handshake related tests -func checkOpenStreamHandshakeID(t *testing.T, c net.Conn, decoder *xml.Decoder, streamID string) { - c.SetDeadline(time.Now().Add(defaultTimeout)) - defer c.SetDeadline(time.Time{}) - - for { // TODO clean up. That for loop is not elegant and I prefer bounded recursion. - token, err := decoder.Token() - if err != nil { - t.Errorf("cannot read next token: %s", err) - } - - switch elem := token.(type) { - // Wait for first startElement - case xml.StartElement: - if elem.Name.Space != stanza.NSStream || elem.Name.Local != "stream" { - err = errors.New("xmpp: expected but got <" + elem.Name.Local + "> in " + elem.Name.Space) - return - } - if _, err := fmt.Fprintf(c, serverStreamOpen, "localhost", streamID, stanza.NSComponent, stanza.NSStream); err != nil { - t.Errorf("cannot write server stream open: %s", err) - } - return - } - } -} - -//============================================================================= -// Sends IQ response to Component request. -// No parsing of the request here. We just check that it's valid, and send the default response. -func handlerForComponentIQSend(t *testing.T, c net.Conn) { - // Completes the connection by exchanging handshakes - handlerForComponentHandshakeDefaultID(t, c) - - // Decoder to parse the request +// Performs a Component connection with a handshake. It uses a default ID defined in this file as a constant. +// Used in the mock server as a Handler +func handlerForComponentHandshakeDefaultID(t *testing.T, c net.Conn) { decoder := xml.NewDecoder(c) - - iqReq, err := receiveIq(t, c, decoder) - if err != nil { - t.Errorf("Error receiving the IQ stanza : %v", err) - } else if !iqReq.IsValid() { - t.Errorf("server received an IQ stanza : %v", iqReq) - } - - // Crafting response - iqResp := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeResult, From: iqReq.To, To: iqReq.From, Id: iqReq.Id, Lang: "en"}) - disco := iqResp.DiscoInfo() - disco.AddFeatures("vcard-temp", - `http://jabber.org/protocol/address`) - - disco.AddIdentity("Multicast", "service", "multicast") - iqResp.Payload = disco - - // Sending response to the Component - mResp, err := xml.Marshal(iqResp) - _, err = fmt.Fprintln(c, string(mResp)) - if err != nil { - t.Errorf("Could not send response stanza : %s", err) - } + checkOpenStreamHandshakeDefaultID(t, c, decoder) + readHandshakeComponent(t, decoder) + fmt.Fprintln(c, "") // That's all the server needs to return (see xep-0114) return } - -// Reads next request coming from the Component. Expecting it to be an IQ request -func receiveIq(t *testing.T, c net.Conn, decoder *xml.Decoder) (stanza.IQ, error) { - c.SetDeadline(time.Now().Add(defaultTimeout)) - defer c.SetDeadline(time.Time{}) - var iqStz stanza.IQ - err := decoder.Decode(&iqStz) - if err != nil { - t.Errorf("cannot read the received IQ stanza: %s", err) - } - if !iqStz.IsValid() { - t.Errorf("received IQ stanza is invalid : %s", err) - } - return iqStz, nil -} - -func receiveRawIq(t *testing.T, c net.Conn, errChan chan error) { - c.SetDeadline(time.Now().Add(defaultTimeout)) - defer c.SetDeadline(time.Time{}) - decoder := xml.NewDecoder(c) - var iq stanza.IQ - err := decoder.Decode(&iq) - if err != nil || !iq.IsValid() { - s := stanza.StreamError{ - XMLName: xml.Name{Local: "stream:error"}, - Error: xml.Name{Local: "xml-not-well-formed"}, - Text: `XML was not well-formed`, - } - raw, _ := xml.Marshal(s) - fmt.Fprintln(c, string(raw)) - fmt.Fprintln(c, ``) // TODO : check this client side - errChan <- fmt.Errorf("invalid xml") - return - } - errChan <- nil - return -} - -//=============================== -// Init mock server and connection -// Creating a mock server and connecting a Component to it. Initialized with given port and handler function -// The Component and mock are both returned -func mockConnection(t *testing.T, port int, handler func(t *testing.T, c net.Conn)) (*Component, *ServerMock) { - // Init mock server - testComponentAddress := fmt.Sprintf("%s:%d", testComponentDomain, port) - mock := ServerMock{} - mock.Start(t, testComponentAddress, handler) - - //================================== - // Create Component to connect to it - c := makeBasicComponent(defaultComponentName, testComponentAddress, t) - - //======================================== - // Connect the new Component to the server - err := c.Connect() - if err != nil { - t.Errorf("%+v", err) - } - - return c, &mock -} - -func makeBasicComponent(name string, mockServerAddr string, t *testing.T) *Component { - opts := ComponentOptions{ - TransportConfiguration: TransportConfiguration{ - Address: mockServerAddr, - Domain: "localhost", - }, - Domain: testComponentDomain, - Secret: "mypass", - Name: name, - Category: "gateway", - Type: "service", - } - router := NewRouter() - c, err := NewComponent(opts, router) - if err != nil { - t.Errorf("%+v", err) - } - c.transport, err = NewComponentTransport(c.ComponentOptions.TransportConfiguration) - if err != nil { - t.Errorf("%+v", err) - } - return c -} diff --git a/tcp_server_mock.go b/tcp_server_mock.go index bdc4397..4afed80 100644 --- a/tcp_server_mock.go +++ b/tcp_server_mock.go @@ -1,12 +1,42 @@ package xmpp import ( + "encoding/xml" + "fmt" + "gosrc.io/xmpp/stanza" "net" "testing" + "time" ) //============================================================================= // TCP Server Mock +const ( + defaultTimeout = 2 * time.Second + testComponentDomain = "localhost" + defaultServerName = "testServer" + defaultStreamID = "91bd0bba-012f-4d92-bb17-5fc41e6fe545" + defaultComponentName = "Test Component" + serverStreamOpen = "" + + // Default port is not standard XMPP port to avoid interfering + // with local running XMPP server + + // Component tests + testHandshakePort = iota + 15222 + testDecoderPort + testSendIqPort + testSendIqFailPort + testSendRawPort + testDisconnectPort + testSManDisconnectPort + + // Client tests + testClientBasePort + testClientRawPort + testClientIqPort + testClientIqFailPort +) // ClientHandler is passed by the test client to provide custom behaviour to // the TCP server mock. This allows customizing the server behaviour to allow @@ -81,3 +111,180 @@ func (mock *ServerMock) loop() { go mock.handler(mock.t, conn) } } + +//====================================================================================================================== +// A few functions commonly used for tests. Trying to avoid duplicates in client and component test files. +//====================================================================================================================== + +func respondToIQ(t *testing.T, c net.Conn) { + // Decoder to parse the request + decoder := xml.NewDecoder(c) + + iqReq, err := receiveIq(c, decoder) + if err != nil { + t.Fatalf("failed to receive IQ : %s", err.Error()) + } + + if !iqReq.IsValid() { + mockIQError(c) + return + } + + // Crafting response + iqResp := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeResult, From: iqReq.To, To: iqReq.From, Id: iqReq.Id, Lang: "en"}) + disco := iqResp.DiscoInfo() + disco.AddFeatures("vcard-temp", + `http://jabber.org/protocol/address`) + + disco.AddIdentity("Multicast", "service", "multicast") + iqResp.Payload = disco + + // Sending response to the Component + mResp, err := xml.Marshal(iqResp) + _, err = fmt.Fprintln(c, string(mResp)) + if err != nil { + t.Errorf("Could not send response stanza : %s", err) + } + return +} + +// When a presence stanza is automatically sent (right now it's the case in the client), we may want to discard it +// and test further stanzas. +func discardPresence(t *testing.T, c net.Conn) { + decoder := xml.NewDecoder(c) + c.SetDeadline(time.Now().Add(defaultTimeout)) + defer c.SetDeadline(time.Time{}) + var presenceStz stanza.Presence + err := decoder.Decode(&presenceStz) + if err != nil { + t.Errorf("Expected presence but this happened : %s", err.Error()) + } +} + +// Reads next request coming from the Component. Expecting it to be an IQ request +func receiveIq(c net.Conn, decoder *xml.Decoder) (*stanza.IQ, error) { + c.SetDeadline(time.Now().Add(defaultTimeout)) + defer c.SetDeadline(time.Time{}) + var iqStz stanza.IQ + err := decoder.Decode(&iqStz) + if err != nil { + return nil, err + } + return &iqStz, nil +} + +// Should be used in server handlers when an IQ sent by a client or component is invalid. +// This responds as expected from a "real" server, aside from the error message. +func mockIQError(c net.Conn) { + s := stanza.StreamError{ + XMLName: xml.Name{Local: "stream:error"}, + Error: xml.Name{Local: "xml-not-well-formed"}, + Text: `XML was not well-formed`, + } + raw, _ := xml.Marshal(s) + fmt.Fprintln(c, string(raw)) + fmt.Fprintln(c, ``) +} + +func sendStreamFeatures(t *testing.T, c net.Conn, _ *xml.Decoder) { + // This is a basic server, supporting only 1 stream feature: SASL Plain Auth + features := ` + + PLAIN + +` + if _, err := fmt.Fprintln(c, features); err != nil { + t.Errorf("cannot send stream feature: %s", err) + } +} + +// TODO return err in case of error reading the auth params +func readAuth(t *testing.T, decoder *xml.Decoder) string { + se, err := stanza.NextStart(decoder) + if err != nil { + t.Errorf("cannot read auth: %s", err) + return "" + } + + var nv interface{} + nv = &stanza.SASLAuth{} + // Decode element into pointer storage + if err = decoder.DecodeElement(nv, &se); err != nil { + t.Errorf("cannot decode auth: %s", err) + return "" + } + + switch v := nv.(type) { + case *stanza.SASLAuth: + return v.Value + } + return "" +} + +func sendBindFeature(t *testing.T, c net.Conn, _ *xml.Decoder) { + // This is a basic server, supporting only 1 stream feature after auth: resource binding + features := ` + +` + if _, err := fmt.Fprintln(c, features); err != nil { + t.Errorf("cannot send stream feature: %s", err) + } +} + +func sendRFC3921Feature(t *testing.T, c net.Conn, _ *xml.Decoder) { + // This is a basic server, supporting only 2 features after auth: resource & session binding + features := ` + + +` + if _, err := fmt.Fprintln(c, features); err != nil { + t.Errorf("cannot send stream feature: %s", err) + } +} + +func bind(t *testing.T, c net.Conn, decoder *xml.Decoder) { + se, err := stanza.NextStart(decoder) + if err != nil { + t.Errorf("cannot read bind: %s", err) + return + } + + iq := &stanza.IQ{} + // Decode element into pointer storage + if err = decoder.DecodeElement(&iq, &se); err != nil { + t.Errorf("cannot decode bind iq: %s", err) + return + } + + // TODO Check all elements + switch iq.Payload.(type) { + case *stanza.Bind: + result := ` + + %s + +` + fmt.Fprintf(c, result, iq.Id, "test@localhost/test") // TODO use real JID + } +} + +func session(t *testing.T, c net.Conn, decoder *xml.Decoder) { + se, err := stanza.NextStart(decoder) + if err != nil { + t.Errorf("cannot read session: %s", err) + return + } + + iq := &stanza.IQ{} + // Decode element into pointer storage + if err = decoder.DecodeElement(&iq, &se); err != nil { + t.Errorf("cannot decode session iq: %s", err) + return + } + + switch iq.Payload.(type) { + case *stanza.StreamSession: + result := `` + fmt.Fprintf(c, result, iq.Id) + } +} From e675e65a592d6ccd15714a524d047fa14c09aed9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?CORNIERE=20R=C3=A9mi?= Date: Thu, 5 Dec 2019 18:12:00 +0100 Subject: [PATCH 2/5] Added callback to process errors after connection. Added tests and refactored a bit. --- .../xmpp_chat_client/xmpp_chat_client.go | 95 +++++ _examples/xmpp_component/xmpp_component.go | 6 +- _examples/xmpp_jukebox/xmpp_jukebox.go | 5 +- _examples/xmpp_oauth2/xmpp_oauth2.go | 6 +- _examples/xmpp_websocket/xmpp_websocket.go | 6 +- client.go | 18 +- client_test.go | 374 +++++++++++------ component.go | 20 +- component_test.go | 387 +++++++++--------- tcp_server_mock.go | 207 ++++++++++ 10 files changed, 792 insertions(+), 332 deletions(-) create mode 100644 _examples/xmpp_chat_client/xmpp_chat_client.go diff --git a/_examples/xmpp_chat_client/xmpp_chat_client.go b/_examples/xmpp_chat_client/xmpp_chat_client.go new file mode 100644 index 0000000..2b2d2e7 --- /dev/null +++ b/_examples/xmpp_chat_client/xmpp_chat_client.go @@ -0,0 +1,95 @@ +package main + +/* +xmpp_chat_client is a demo client that connect on an XMPP server to chat with other members +Note that this example sends to a very specific user. User logic is not implemented here. +*/ + +import ( + . "bufio" + "fmt" + "os" + + "gosrc.io/xmpp" + "gosrc.io/xmpp/stanza" +) + +const ( + currentUserAddress = "localhost:5222" + currentUserJid = "testuser@localhost" + currentUserPass = "testpass" + correspondantJid = "testuser2@localhost" +) + +func main() { + config := xmpp.Config{ + TransportConfiguration: xmpp.TransportConfiguration{ + Address: currentUserAddress, + }, + Jid: currentUserJid, + Credential: xmpp.Password(currentUserPass), + Insecure: true} + + var client *xmpp.Client + var err error + router := xmpp.NewRouter() + router.HandleFunc("message", handleMessage) + if client, err = xmpp.NewClient(config, router, errorHandler); err != nil { + fmt.Println("Error new client") + } + + // Connecting client and handling messages + // To use a stream manager, just write something like this instead : + //cm := xmpp.NewStreamManager(client, startMessaging) + //log.Fatal(cm.Run()) //=> this will lock the calling goroutine + + if err = client.Connect(); err != nil { + fmt.Printf("XMPP connection failed: %s", err) + return + } + startMessaging(client) + +} + +func startMessaging(client xmpp.Sender) { + reader := NewReader(os.Stdin) + textChan := make(chan string) + var text string + for { + fmt.Print("Enter text: ") + go readInput(reader, textChan) + select { + case <-killChan: + return + case text = <-textChan: + reply := stanza.Message{Attrs: stanza.Attrs{To: correspondantJid}, Body: text} + err := client.Send(reply) + if err != nil { + fmt.Printf("There was a problem sending the message : %v", reply) + return + } + } + } +} + +func readInput(reader *Reader, textChan chan string) { + text, _ := reader.ReadString('\n') + textChan <- text +} + +var killChan = make(chan struct{}) + +// If an error occurs, this is used +func errorHandler(err error) { + fmt.Printf("%v", err) + killChan <- struct{}{} +} + +func handleMessage(s xmpp.Sender, p stanza.Packet) { + msg, ok := p.(stanza.Message) + if !ok { + _, _ = fmt.Fprintf(os.Stdout, "Ignoring packet: %T\n", p) + return + } + _, _ = fmt.Fprintf(os.Stdout, "Body = %s - from = %s\n", msg.Body, msg.From) +} diff --git a/_examples/xmpp_component/xmpp_component.go b/_examples/xmpp_component/xmpp_component.go index 0452888..7f676cb 100644 --- a/_examples/xmpp_component/xmpp_component.go +++ b/_examples/xmpp_component/xmpp_component.go @@ -35,7 +35,7 @@ func main() { IQNamespaces("jabber:iq:version"). HandlerFunc(handleVersion) - component, err := xmpp.NewComponent(opts, router) + component, err := xmpp.NewComponent(opts, router, handleError) if err != nil { log.Fatalf("%+v", err) } @@ -47,6 +47,10 @@ func main() { log.Fatal(cm.Run()) } +func handleError(err error) { + fmt.Println(err.Error()) +} + func handleMessage(_ xmpp.Sender, p stanza.Packet) { msg, ok := p.(stanza.Message) if !ok { diff --git a/_examples/xmpp_jukebox/xmpp_jukebox.go b/_examples/xmpp_jukebox/xmpp_jukebox.go index 91f453c..ce7ebc9 100644 --- a/_examples/xmpp_jukebox/xmpp_jukebox.go +++ b/_examples/xmpp_jukebox/xmpp_jukebox.go @@ -53,7 +53,7 @@ func main() { handleIQ(s, p, player) }) - client, err := xmpp.NewClient(config, router) + client, err := xmpp.NewClient(config, router, errorHandler) if err != nil { log.Fatalf("%+v", err) } @@ -61,6 +61,9 @@ func main() { cm := xmpp.NewStreamManager(client, nil) log.Fatal(cm.Run()) } +func errorHandler(err error) { + fmt.Println(err.Error()) +} func handleMessage(s xmpp.Sender, p stanza.Packet, player *mpg123.Player) { msg, ok := p.(stanza.Message) diff --git a/_examples/xmpp_oauth2/xmpp_oauth2.go b/_examples/xmpp_oauth2/xmpp_oauth2.go index f322447..89b2639 100644 --- a/_examples/xmpp_oauth2/xmpp_oauth2.go +++ b/_examples/xmpp_oauth2/xmpp_oauth2.go @@ -28,7 +28,7 @@ func main() { router := xmpp.NewRouter() router.HandleFunc("message", handleMessage) - client, err := xmpp.NewClient(config, router) + client, err := xmpp.NewClient(config, router, errorHandler) if err != nil { log.Fatalf("%+v", err) } @@ -39,6 +39,10 @@ func main() { log.Fatal(cm.Run()) } +func errorHandler(err error) { + fmt.Println(err.Error()) +} + func handleMessage(s xmpp.Sender, p stanza.Packet) { msg, ok := p.(stanza.Message) if !ok { diff --git a/_examples/xmpp_websocket/xmpp_websocket.go b/_examples/xmpp_websocket/xmpp_websocket.go index 428a1d1..c8c0620 100644 --- a/_examples/xmpp_websocket/xmpp_websocket.go +++ b/_examples/xmpp_websocket/xmpp_websocket.go @@ -26,7 +26,7 @@ func main() { router := xmpp.NewRouter() router.HandleFunc("message", handleMessage) - client, err := xmpp.NewClient(config, router) + client, err := xmpp.NewClient(config, router, errorHandler) if err != nil { log.Fatalf("%+v", err) } @@ -37,6 +37,10 @@ func main() { log.Fatal(cm.Run()) } +func errorHandler(err error) { + fmt.Println(err.Error()) +} + func handleMessage(s xmpp.Sender, p stanza.Packet) { msg, ok := p.(stanza.Message) if !ok { diff --git a/client.go b/client.go index ecb2aad..a5ad1bf 100644 --- a/client.go +++ b/client.go @@ -98,6 +98,8 @@ type Client struct { router *Router // Track and broadcast connection state EventManager + // Handle errors from client execution + ErrorHandler func(error) } /* @@ -107,7 +109,7 @@ Setting up the client / Checking the parameters // NewClient generates a new XMPP client, based on Config passed as parameters. // If host is not specified, the DNS SRV should be used to find the host from the domainpart of the JID. // Default the port to 5222. -func NewClient(config Config, r *Router) (c *Client, err error) { +func NewClient(config Config, r *Router, errorHandler func(error)) (c *Client, err error) { if config.KeepaliveInterval == 0 { config.KeepaliveInterval = time.Second * 30 } @@ -143,6 +145,7 @@ func NewClient(config Config, r *Router) (c *Client, err error) { c = new(Client) c.config = config c.router = r + c.ErrorHandler = errorHandler if c.config.ConnectTimeout == 0 { c.config.ConnectTimeout = 15 // 15 second as default @@ -191,10 +194,7 @@ func (c *Client) Resume(state SMState) error { go keepalive(c.transport, c.config.KeepaliveInterval, keepaliveQuit) // Start the receiver go routine state = c.Session.SMState - // Leaving this channel here for later. Not used atm. We should return this instead of an error because right - // now the returned error is lost in limbo. - errChan := make(chan error) - go c.recv(state, keepaliveQuit, errChan) + go c.recv(state, keepaliveQuit) // We're connected and can now receive and send messages. //fmt.Fprintf(client.conn, "%s%s", "chat", "Online") @@ -273,11 +273,11 @@ func (c *Client) sendWithWriter(writer io.Writer, packet []byte) error { // Go routines // Loop: Receive data from server -func (c *Client) recv(state SMState, keepaliveQuit chan<- struct{}, errChan chan<- error) { +func (c *Client) recv(state SMState, keepaliveQuit chan<- struct{}) { for { val, err := stanza.NextPacket(c.transport.GetDecoder()) if err != nil { - errChan <- err + c.ErrorHandler(err) close(keepaliveQuit) c.disconnected(state) return @@ -289,7 +289,7 @@ func (c *Client) recv(state SMState, keepaliveQuit chan<- struct{}, errChan chan c.router.route(c, val) close(keepaliveQuit) c.streamError(packet.Error.Local, packet.Text) - errChan <- errors.New("stream error: " + packet.Error.Local) + c.ErrorHandler(errors.New("stream error: " + packet.Error.Local)) return // Process Stream management nonzas case stanza.SMRequest: @@ -299,7 +299,7 @@ func (c *Client) recv(state SMState, keepaliveQuit chan<- struct{}, errChan chan }, H: state.Inbound} err = c.Send(answer) if err != nil { - errChan <- err + c.ErrorHandler(err) return } default: diff --git a/client_test.go b/client_test.go index e18e8ba..0caace0 100644 --- a/client_test.go +++ b/client_test.go @@ -1,6 +1,7 @@ package xmpp import ( + "context" "encoding/xml" "errors" "fmt" @@ -14,9 +15,8 @@ import ( const ( // Default port is not standard XMPP port to avoid interfering // with local running XMPP server - testXMPPAddress = "localhost:15222" - - defaultTimeout = 2 * time.Second + testXMPPAddress = "localhost:15222" + testClientDomain = "localhost" ) func TestEventManager(t *testing.T) { @@ -40,7 +40,7 @@ func TestEventManager(t *testing.T) { func TestClient_Connect(t *testing.T) { // Setup Mock server mock := ServerMock{} - mock.Start(t, testXMPPAddress, handlerConnectSuccess) + mock.Start(t, testXMPPAddress, handlerClientConnectSuccess) // Test / Check result config := Config{ @@ -54,7 +54,7 @@ func TestClient_Connect(t *testing.T) { var client *Client var err error router := NewRouter() - if client, err = NewClient(config, router); err != nil { + if client, err = NewClient(config, router, clientDefaultErrorHandler); err != nil { t.Errorf("connect create XMPP client: %s", err) } @@ -82,7 +82,7 @@ func TestClient_NoInsecure(t *testing.T) { var client *Client var err error router := NewRouter() - if client, err = NewClient(config, router); err != nil { + if client, err = NewClient(config, router, clientDefaultErrorHandler); err != nil { t.Errorf("cannot create XMPP client: %s", err) } @@ -112,7 +112,7 @@ func TestClient_FeaturesTracking(t *testing.T) { var client *Client var err error router := NewRouter() - if client, err = NewClient(config, router); err != nil { + if client, err = NewClient(config, router, clientDefaultErrorHandler); err != nil { t.Errorf("cannot create XMPP client: %s", err) } @@ -127,7 +127,7 @@ func TestClient_FeaturesTracking(t *testing.T) { func TestClient_RFC3921Session(t *testing.T) { // Setup Mock server mock := ServerMock{} - mock.Start(t, testXMPPAddress, handlerConnectWithSession) + mock.Start(t, testXMPPAddress, handlerClientConnectWithSession) // Test / Check result config := Config{ @@ -142,7 +142,7 @@ func TestClient_RFC3921Session(t *testing.T) { var client *Client var err error router := NewRouter() - if client, err = NewClient(config, router); err != nil { + if client, err = NewClient(config, router, clientDefaultErrorHandler); err != nil { t.Errorf("connect create XMPP client: %s", err) } @@ -153,48 +153,254 @@ func TestClient_RFC3921Session(t *testing.T) { mock.Stop() } +// Testing sending an IQ to the mock server and reading its response. +func TestClient_SendIQ(t *testing.T) { + done := make(chan struct{}) + // Handler for Mock server + h := func(t *testing.T, c net.Conn) { + handlerClientConnectSuccess(t, c) + discardPresence(t, c) + respondToIQ(t, c) + done <- struct{}{} + } + client, mock := mockClientConnection(t, h, testClientIqPort) + + ctx, _ := context.WithTimeout(context.Background(), 30*time.Second) + iqReq := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeGet, From: "test1@localhost/mremond-mbp", To: defaultServerName, Id: defaultStreamID, Lang: "en"}) + disco := iqReq.DiscoInfo() + iqReq.Payload = disco + + // Handle a possible error + errChan := make(chan error) + errorHandler := func(err error) { + errChan <- err + } + client.ErrorHandler = errorHandler + res, err := client.SendIQ(ctx, iqReq) + if err != nil { + t.Errorf(err.Error()) + } + + select { + case <-res: // If the server responds with an IQ, we pass the test + case err := <-errChan: // If the server sends an error, or there is a connection error + t.Errorf(err.Error()) + case <-time.After(defaultChannelTimeout): // If we timeout + t.Errorf("Failed to receive response, to sent IQ, from mock server") + } + select { + case <-done: + mock.Stop() + case <-time.After(defaultChannelTimeout): + t.Errorf("The mock server failed to finish its job !") + } +} + +func TestClient_SendIQFail(t *testing.T) { + done := make(chan struct{}) + // Handler for Mock server + h := func(t *testing.T, c net.Conn) { + handlerClientConnectSuccess(t, c) + discardPresence(t, c) + respondToIQ(t, c) + done <- struct{}{} + } + client, mock := mockClientConnection(t, h, testClientIqFailPort) + + //================== + // Create an IQ to send + ctx, _ := context.WithTimeout(context.Background(), 30*time.Second) + iqReq := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeGet, From: "test1@localhost/mremond-mbp", To: defaultServerName, Id: defaultStreamID, Lang: "en"}) + disco := iqReq.DiscoInfo() + iqReq.Payload = disco + // Removing the id to make the stanza invalid. The IQ constructor makes a random one if none is specified + // so we need to overwrite it. + iqReq.Id = "" + + // Handle a possible error + errChan := make(chan error) + errorHandler := func(err error) { + errChan <- err + } + client.ErrorHandler = errorHandler + res, _ := client.SendIQ(ctx, iqReq) + + // Test + select { + case <-res: // If the server responds with an IQ + t.Errorf("Server should not respond with an IQ since the request is expected to be invalid !") + case <-errChan: // If the server sends an error, the test passes + case <-time.After(defaultChannelTimeout): // If we timeout + t.Errorf("Failed to receive response, to sent IQ, from mock server") + } + select { + case <-done: + mock.Stop() + case <-time.After(defaultChannelTimeout): + t.Errorf("The mock server failed to finish its job !") + } +} + +func TestClient_SendRaw(t *testing.T) { + done := make(chan struct{}) + // Handler for Mock server + h := func(t *testing.T, c net.Conn) { + handlerClientConnectSuccess(t, c) + discardPresence(t, c) + respondToIQ(t, c) + done <- struct{}{} + } + type testCase struct { + req string + shouldErr bool + port int + } + testRequests := make(map[string]testCase) + // Sending a correct IQ of type get. Not supposed to err + testRequests["Correct IQ"] = testCase{ + req: ``, + shouldErr: false, + port: testClientRawPort + 100, + } + // Sending an IQ with a missing ID. Should err + testRequests["IQ with missing ID"] = testCase{ + req: ``, + shouldErr: true, + port: testClientRawPort, + } + + // A handler for the client. + // In the failing test, the server returns a stream error, which triggers this handler, client side. + errChan := make(chan error) + errHandler := func(err error) { + errChan <- err + } + + // Tests for all the IQs + for name, tcase := range testRequests { + t.Run(name, func(st *testing.T) { + //Connecting to a mock server, initialized with given port and handler function + c, m := mockClientConnection(t, h, tcase.port) + c.ErrorHandler = errHandler + // Sending raw xml from test case + err := c.SendRaw(tcase.req) + if err != nil { + t.Errorf("Error sending Raw string") + } + // Just wait a little so the message has time to arrive + select { + // We don't use the default "long" timeout here because waiting it out means passing the test. + case <-time.After(100 * time.Millisecond): + case err = <-errChan: + if err == nil && tcase.shouldErr { + t.Errorf("Failed to get closing stream err") + } else if err != nil && !tcase.shouldErr { + t.Errorf("This test is not supposed to err !") + } + } + c.transport.Close() + select { + case <-done: + m.Stop() + case <-time.After(defaultChannelTimeout): + t.Errorf("The mock server failed to finish its job !") + } + }) + } +} + +func TestClient_Disconnect(t *testing.T) { + c, m := mockClientConnection(t, handlerClientConnectSuccess, testClientBasePort) + err := c.transport.Ping() + if err != nil { + t.Errorf("Could not ping but not disconnected yet") + } + c.Disconnect() + err = c.transport.Ping() + if err == nil { + t.Errorf("Did not disconnect properly") + } + m.Stop() +} + +func TestClient_DisconnectStreamManager(t *testing.T) { + // Init mock server + // Setup Mock server + mock := ServerMock{} + mock.Start(t, testXMPPAddress, handlerAbortTLS) + + // Test / Check result + config := Config{ + TransportConfiguration: TransportConfiguration{ + Address: testXMPPAddress, + }, + Jid: "test@localhost", + Credential: Password("test"), + } + + var client *Client + var err error + router := NewRouter() + if client, err = NewClient(config, router, clientDefaultErrorHandler); err != nil { + t.Errorf("cannot create XMPP client: %s", err) + } + + sman := NewStreamManager(client, nil) + errChan := make(chan error) + runSMan := func(errChan chan error) { + errChan <- sman.Run() + } + + go runSMan(errChan) + select { + case <-errChan: + case <-time.After(defaultChannelTimeout): + // When insecure is not allowed: + t.Errorf("should fail as insecure connection is not allowed and server does not support TLS") + } + mock.Stop() +} + //============================================================================= // Basic XMPP Server Mock Handlers. -const serverStreamOpen = "" - // Test connection with a basic straightforward workflow -func handlerConnectSuccess(t *testing.T, c net.Conn) { +func handlerClientConnectSuccess(t *testing.T, c net.Conn) { decoder := xml.NewDecoder(c) - checkOpenStream(t, c, decoder) + checkClientOpenStream(t, c, decoder) sendStreamFeatures(t, c, decoder) // Send initial features readAuth(t, decoder) fmt.Fprintln(c, "") - checkOpenStream(t, c, decoder) // Reset stream - sendBindFeature(t, c, decoder) // Send post auth features + checkClientOpenStream(t, c, decoder) // Reset stream + sendBindFeature(t, c, decoder) // Send post auth features bind(t, c, decoder) } // We expect client will abort on TLS func handlerAbortTLS(t *testing.T, c net.Conn) { decoder := xml.NewDecoder(c) - checkOpenStream(t, c, decoder) + checkClientOpenStream(t, c, decoder) sendStreamFeatures(t, c, decoder) // Send initial features } // Test connection with mandatory session (RFC-3921) -func handlerConnectWithSession(t *testing.T, c net.Conn) { +func handlerClientConnectWithSession(t *testing.T, c net.Conn) { decoder := xml.NewDecoder(c) - checkOpenStream(t, c, decoder) + checkClientOpenStream(t, c, decoder) sendStreamFeatures(t, c, decoder) // Send initial features readAuth(t, decoder) fmt.Fprintln(c, "") - checkOpenStream(t, c, decoder) // Reset stream - sendRFC3921Feature(t, c, decoder) // Send post auth features + checkClientOpenStream(t, c, decoder) // Reset stream + sendRFC3921Feature(t, c, decoder) // Send post auth features bind(t, c, decoder) session(t, c, decoder) } -func checkOpenStream(t *testing.T, c net.Conn, decoder *xml.Decoder) { +func checkClientOpenStream(t *testing.T, c net.Conn, decoder *xml.Decoder) { c.SetDeadline(time.Now().Add(defaultTimeout)) defer c.SetDeadline(time.Time{}) @@ -220,105 +426,35 @@ func checkOpenStream(t *testing.T, c net.Conn, decoder *xml.Decoder) { } } -func sendStreamFeatures(t *testing.T, c net.Conn, _ *xml.Decoder) { - // This is a basic server, supporting only 1 stream feature: SASL Plain Auth - features := ` - - PLAIN - -` - if _, err := fmt.Fprintln(c, features); err != nil { - t.Errorf("cannot send stream feature: %s", err) +func mockClientConnection(t *testing.T, serverHandler func(*testing.T, net.Conn), port int) (*Client, ServerMock) { + mock := ServerMock{} + testServerAddress := fmt.Sprintf("%s:%d", testClientDomain, port) + + mock.Start(t, testServerAddress, serverHandler) + + config := Config{ + TransportConfiguration: TransportConfiguration{ + Address: testServerAddress, + }, + Jid: "test@localhost", + Credential: Password("test"), + Insecure: true} + + var client *Client + var err error + router := NewRouter() + if client, err = NewClient(config, router, clientDefaultErrorHandler); err != nil { + t.Errorf("connect create XMPP client: %s", err) } + + if err = client.Connect(); err != nil { + t.Errorf("XMPP connection failed: %s", err) + } + + return client, mock } -// TODO return err in case of error reading the auth params -func readAuth(t *testing.T, decoder *xml.Decoder) string { - se, err := stanza.NextStart(decoder) - if err != nil { - t.Errorf("cannot read auth: %s", err) - return "" - } - - var nv interface{} - nv = &stanza.SASLAuth{} - // Decode element into pointer storage - if err = decoder.DecodeElement(nv, &se); err != nil { - t.Errorf("cannot decode auth: %s", err) - return "" - } - - switch v := nv.(type) { - case *stanza.SASLAuth: - return v.Value - } - return "" -} - -func sendBindFeature(t *testing.T, c net.Conn, _ *xml.Decoder) { - // This is a basic server, supporting only 1 stream feature after auth: resource binding - features := ` - -` - if _, err := fmt.Fprintln(c, features); err != nil { - t.Errorf("cannot send stream feature: %s", err) - } -} - -func sendRFC3921Feature(t *testing.T, c net.Conn, _ *xml.Decoder) { - // This is a basic server, supporting only 2 features after auth: resource & session binding - features := ` - - -` - if _, err := fmt.Fprintln(c, features); err != nil { - t.Errorf("cannot send stream feature: %s", err) - } -} - -func bind(t *testing.T, c net.Conn, decoder *xml.Decoder) { - se, err := stanza.NextStart(decoder) - if err != nil { - t.Errorf("cannot read bind: %s", err) - return - } - - iq := &stanza.IQ{} - // Decode element into pointer storage - if err = decoder.DecodeElement(&iq, &se); err != nil { - t.Errorf("cannot decode bind iq: %s", err) - return - } - - // TODO Check all elements - switch iq.Payload.(type) { - case *stanza.Bind: - result := ` - - %s - -` - fmt.Fprintf(c, result, iq.Id, "test@localhost/test") // TODO use real JID - } -} - -func session(t *testing.T, c net.Conn, decoder *xml.Decoder) { - se, err := stanza.NextStart(decoder) - if err != nil { - t.Errorf("cannot read session: %s", err) - return - } - - iq := &stanza.IQ{} - // Decode element into pointer storage - if err = decoder.DecodeElement(&iq, &se); err != nil { - t.Errorf("cannot decode session iq: %s", err) - return - } - - switch iq.Payload.(type) { - case *stanza.StreamSession: - result := `` - fmt.Fprintf(c, result, iq.Id) - } +// This really should not be used as is. +// It's just meant to be a placeholder when error handling is not needed at this level +func clientDefaultErrorHandler(err error) { } diff --git a/component.go b/component.go index 471f1db..2f61aef 100644 --- a/component.go +++ b/component.go @@ -48,11 +48,12 @@ type Component struct { transport Transport // read / write - socketProxy io.ReadWriter // TODO + socketProxy io.ReadWriter // TODO + ErrorHandler func(error) } -func NewComponent(opts ComponentOptions, r *Router) (*Component, error) { - c := Component{ComponentOptions: opts, router: r} +func NewComponent(opts ComponentOptions, r *Router, errorHandler func(error)) (*Component, error) { + c := Component{ComponentOptions: opts, router: r, ErrorHandler: errorHandler} return &c, nil } @@ -104,11 +105,8 @@ func (c *Component) Resume(sm SMState) error { case stanza.Handshake: // Start the receiver go routine c.updateState(StateSessionEstablished) - // Leaving this channel here for later. Not used atm. We should return this instead of an error because right - // now the returned error is lost in limbo. - errChan := make(chan error) - go c.recv(errChan) // Sends to errChan - return err // Should be empty at this point + go c.recv() + return err // Should be empty at this point default: c.updateState(StatePermanentError) return NewConnError(errors.New("expecting handshake result, got "+v.Name()), true) @@ -128,13 +126,13 @@ func (c *Component) SetHandler(handler EventHandler) { } // Receiver Go routine receiver -func (c *Component) recv(errChan chan<- error) { +func (c *Component) recv() { for { val, err := stanza.NextPacket(c.transport.GetDecoder()) if err != nil { c.updateState(StateDisconnected) - errChan <- err + c.ErrorHandler(err) return } // Handle stream errors @@ -142,7 +140,7 @@ func (c *Component) recv(errChan chan<- error) { case stanza.StreamError: c.router.route(c, val) c.streamError(p.Error.Local, p.Text) - errChan <- errors.New("stream error: " + p.Error.Local) + c.ErrorHandler(errors.New("stream error: " + p.Error.Local)) return } c.router.route(c, val) diff --git a/component_test.go b/component_test.go index 4e115f0..48963a5 100644 --- a/component_test.go +++ b/component_test.go @@ -5,6 +5,7 @@ import ( "encoding/xml" "errors" "fmt" + "github.com/google/uuid" "gosrc.io/xmpp/stanza" "net" "strings" @@ -15,19 +16,7 @@ import ( // Tests are ran in parallel, so each test creating a server must use a different port so we do not get any // conflict. Using iota for this should do the trick. const ( - testComponentDomain = "localhost" - defaultServerName = "testServer" - defaultStreamID = "91bd0bba-012f-4d92-bb17-5fc41e6fe545" - defaultComponentName = "Test Component" - - // Default port is not standard XMPP port to avoid interfering - // with local running XMPP server - testHandshakePort = iota + 15222 - testDecoderPort - testSendIqPort - testSendRawPort - testDisconnectPort - testSManDisconnectPort + defaultChannelTimeout = 5 * time.Second ) func TestHandshake(t *testing.T) { @@ -48,16 +37,14 @@ func TestHandshake(t *testing.T) { // Tests connection process with a handshake exchange // Tests multiple session IDs. All connections should generate a unique stream ID -func TestGenerateHandshake(t *testing.T) { +func TestGenerateHandshakeId(t *testing.T) { // Using this array with a channel to make a queue of values to test // These are stream IDs that will be used to test the connection process, mixing them with the "secret" to generate // some handshake value - var uuidsArray = [5]string{ - "cc9b3249-9582-4780-825f-4311b42f9b0e", - "bba8be3c-d98e-4e26-b9bb-9ed34578a503", - "dae72822-80e8-496b-b763-ab685f53a188", - "a45d6c06-de49-4bb0-935b-1a2201b71028", - "7dc6924f-0eca-4237-9898-18654b8d891e", + var uuidsArray = [5]string{} + for i := 1; i < len(uuidsArray); i++ { + id, _ := uuid.NewRandom() + uuidsArray[i] = id.String() } // Channel to pass stream IDs as a queue @@ -95,7 +82,7 @@ func TestGenerateHandshake(t *testing.T) { Type: "service", } router := NewRouter() - c, err := NewComponent(opts, router) + c, err := NewComponent(opts, router, componentDefaultErrorHandler) if err != nil { t.Errorf("%+v", err) } @@ -126,7 +113,7 @@ func TestStreamManager(t *testing.T) { // The decoder is expected to be built after a valid connection // Based on the xmpp_component example. func TestDecoder(t *testing.T) { - c, _ := mockConnection(t, testDecoderPort, handlerForComponentHandshakeDefaultID) + c, _ := mockComponentConnection(t, testDecoderPort, handlerForComponentHandshakeDefaultID) if c.transport.GetDecoder() == nil { t.Errorf("Failed to initialize decoder. Decoder is nil.") } @@ -134,39 +121,103 @@ func TestDecoder(t *testing.T) { // Tests sending an IQ to the server, and getting the response func TestSendIq(t *testing.T) { + done := make(chan struct{}) + h := func(t *testing.T, c net.Conn) { + handlerForComponentIQSend(t, c) + done <- struct{}{} + } + //Connecting to a mock server, initialized with given port and handler function - c, m := mockConnection(t, testSendIqPort, handlerForComponentIQSend) + c, m := mockComponentConnection(t, testSendIqPort, h) ctx, _ := context.WithTimeout(context.Background(), 30*time.Second) iqReq := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeGet, From: "test1@localhost/mremond-mbp", To: defaultServerName, Id: defaultStreamID, Lang: "en"}) disco := iqReq.DiscoInfo() iqReq.Payload = disco + // Handle a possible error + errChan := make(chan error) + errorHandler := func(err error) { + errChan <- err + } + c.ErrorHandler = errorHandler + var res chan stanza.IQ res, _ = c.SendIQ(ctx, iqReq) select { case <-res: - case <-time.After(100 * time.Millisecond): + case err := <-errChan: + t.Errorf(err.Error()) + case <-time.After(defaultChannelTimeout): t.Errorf("Failed to receive response, to sent IQ, from mock server") } - m.Stop() + select { + case <-done: + m.Stop() + case <-time.After(defaultChannelTimeout): + t.Errorf("The mock server failed to finish its job !") + } +} + +// Checking that error handling is done properly client side when an invalid IQ is sent and the server responds in kind. +func TestSendIqFail(t *testing.T) { + done := make(chan struct{}) + h := func(t *testing.T, c net.Conn) { + handlerForComponentIQSend(t, c) + done <- struct{}{} + } + //Connecting to a mock server, initialized with given port and handler function + c, m := mockComponentConnection(t, testSendIqFailPort, h) + + ctx, _ := context.WithTimeout(context.Background(), 30*time.Second) + iqReq := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeGet, From: "test1@localhost/mremond-mbp", To: defaultServerName, Id: defaultStreamID, Lang: "en"}) + + // Removing the id to make the stanza invalid. The IQ constructor makes a random one if none is specified + // so we need to overwrite it. + iqReq.Id = "" + disco := iqReq.DiscoInfo() + iqReq.Payload = disco + + errChan := make(chan error) + errorHandler := func(err error) { + errChan <- err + } + c.ErrorHandler = errorHandler + + var res chan stanza.IQ + res, _ = c.SendIQ(ctx, iqReq) + + select { + case r := <-res: // Do we get an IQ response from the server ? + t.Errorf("We should not be getting an IQ response here : this should fail !") + fmt.Println(r) + case <-errChan: // Do we get a stream error from the server ? + // If we get an error from the server, the test passes. + case <-time.After(defaultChannelTimeout): // Timeout ? + t.Errorf("Failed to receive response, to sent IQ, from mock server") + } + + select { + case <-done: + m.Stop() + case <-time.After(defaultChannelTimeout): + t.Errorf("The mock server failed to finish its job !") + } } // Tests sending raw xml to the mock server. -// TODO : check the server response client side ? // Right now, the server response is not checked and an err is passed in a channel if the test is supposed to err. // In this test, we use IQs func TestSendRaw(t *testing.T) { - // Error channel for the handler - errChan := make(chan error) + done := make(chan struct{}) // Handler for the mock server h := func(t *testing.T, c net.Conn) { // Completes the connection by exchanging handshakes handlerForComponentHandshakeDefaultID(t, c) - receiveRawIq(t, c, errChan) - return + receiveIq(c, xml.NewDecoder(c)) + done <- struct{}{} } type testCase struct { @@ -185,12 +236,19 @@ func TestSendRaw(t *testing.T) { shouldErr: true, } + // A handler for the component. + // In the failing test, the server returns a stream error, which triggers this handler, component side. + errChan := make(chan error) + errHandler := func(err error) { + errChan <- err + } + // Tests for all the IQs for name, tcase := range testRequests { t.Run(name, func(st *testing.T) { //Connecting to a mock server, initialized with given port and handler function - c, m := mockConnection(t, testSendRawPort, h) - + c, m := mockComponentConnection(t, testSendRawPort, h) + c.ErrorHandler = errHandler // Sending raw xml from test case err := c.SendRaw(tcase.req) if err != nil { @@ -198,21 +256,29 @@ func TestSendRaw(t *testing.T) { } // Just wait a little so the message has time to arrive select { - case <-time.After(100 * time.Millisecond): + // We don't use the default "long" timeout here because waiting it out means passing the test. + case <-time.After(200 * time.Millisecond): case err = <-errChan: if err == nil && tcase.shouldErr { t.Errorf("Failed to get closing stream err") + } else if err != nil && !tcase.shouldErr { + t.Errorf("This test is not supposed to err ! => %s", err.Error()) } } c.transport.Close() - m.Stop() + select { + case <-done: + m.Stop() + case <-time.After(defaultChannelTimeout): + t.Errorf("The mock server failed to finish its job !") + } }) } } // Tests the Disconnect method for Components func TestDisconnect(t *testing.T) { - c, m := mockConnection(t, testDisconnectPort, handlerForComponentHandshakeDefaultID) + c, m := mockComponentConnection(t, testDisconnectPort, handlerForComponentHandshakeDefaultID) err := c.transport.Ping() if err != nil { t.Errorf("Could not ping but not disconnected yet") @@ -257,14 +323,97 @@ func TestStreamManagerDisconnect(t *testing.T) { //============================================================================= // Basic XMPP Server Mock Handlers. -// Performs a Component connection with a handshake. It uses a default ID defined in this file as a constant. -// Used in the mock server as a Handler -func handlerForComponentHandshakeDefaultID(t *testing.T, c net.Conn) { - decoder := xml.NewDecoder(c) - checkOpenStreamHandshakeDefaultID(t, c, decoder) - readHandshakeComponent(t, decoder) - fmt.Fprintln(c, "") // That's all the server needs to return (see xep-0114) - return + +//=============================== +// Init mock server and connection +// Creating a mock server and connecting a Component to it. Initialized with given port and handler function +// The Component and mock are both returned +func mockComponentConnection(t *testing.T, port int, handler func(t *testing.T, c net.Conn)) (*Component, *ServerMock) { + // Init mock server + testComponentAddress := fmt.Sprintf("%s:%d", testComponentDomain, port) + mock := ServerMock{} + mock.Start(t, testComponentAddress, handler) + + //================================== + // Create Component to connect to it + c := makeBasicComponent(defaultComponentName, testComponentAddress, t) + + //======================================== + // Connect the new Component to the server + err := c.Connect() + if err != nil { + t.Errorf("%+v", err) + } + + return c, &mock +} + +func makeBasicComponent(name string, mockServerAddr string, t *testing.T) *Component { + opts := ComponentOptions{ + TransportConfiguration: TransportConfiguration{ + Address: mockServerAddr, + Domain: "localhost", + }, + Domain: testComponentDomain, + Secret: "mypass", + Name: name, + Category: "gateway", + Type: "service", + } + router := NewRouter() + c, err := NewComponent(opts, router, componentDefaultErrorHandler) + if err != nil { + t.Errorf("%+v", err) + } + c.transport, err = NewComponentTransport(c.ComponentOptions.TransportConfiguration) + if err != nil { + t.Errorf("%+v", err) + } + return c +} + +// This really should not be used as is. +// It's just meant to be a placeholder when error handling is not needed at this level +func componentDefaultErrorHandler(err error) { + +} + +// Sends IQ response to Component request. +// No parsing of the request here. We just check that it's valid, and send the default response. +func handlerForComponentIQSend(t *testing.T, c net.Conn) { + // Completes the connection by exchanging handshakes + handlerForComponentHandshakeDefaultID(t, c) + respondToIQ(t, c) +} + +// Used for ID and handshake related tests +func checkOpenStreamHandshakeID(t *testing.T, c net.Conn, decoder *xml.Decoder, streamID string) { + c.SetDeadline(time.Now().Add(defaultTimeout)) + defer c.SetDeadline(time.Time{}) + + for { // TODO clean up. That for loop is not elegant and I prefer bounded recursion. + token, err := decoder.Token() + if err != nil { + t.Errorf("cannot read next token: %s", err) + } + + switch elem := token.(type) { + // Wait for first startElement + case xml.StartElement: + if elem.Name.Space != stanza.NSStream || elem.Name.Local != "stream" { + err = errors.New("xmpp: expected but got <" + elem.Name.Local + "> in " + elem.Name.Space) + return + } + if _, err := fmt.Fprintf(c, serverStreamOpen, "localhost", streamID, stanza.NSComponent, stanza.NSStream); err != nil { + t.Errorf("cannot write server stream open: %s", err) + } + return + } + } +} + +func checkOpenStreamHandshakeDefaultID(t *testing.T, c net.Conn, decoder *xml.Decoder) { + checkOpenStreamHandshakeID(t, c, decoder, defaultStreamID) } // Performs a Component connection with a handshake. It uses a default ID defined in this file as a constant. @@ -303,152 +452,12 @@ func readHandshakeComponent(t *testing.T, decoder *xml.Decoder) { } } -func checkOpenStreamHandshakeDefaultID(t *testing.T, c net.Conn, decoder *xml.Decoder) { - checkOpenStreamHandshakeID(t, c, decoder, defaultStreamID) -} - -// Used for ID and handshake related tests -func checkOpenStreamHandshakeID(t *testing.T, c net.Conn, decoder *xml.Decoder, streamID string) { - c.SetDeadline(time.Now().Add(defaultTimeout)) - defer c.SetDeadline(time.Time{}) - - for { // TODO clean up. That for loop is not elegant and I prefer bounded recursion. - token, err := decoder.Token() - if err != nil { - t.Errorf("cannot read next token: %s", err) - } - - switch elem := token.(type) { - // Wait for first startElement - case xml.StartElement: - if elem.Name.Space != stanza.NSStream || elem.Name.Local != "stream" { - err = errors.New("xmpp: expected but got <" + elem.Name.Local + "> in " + elem.Name.Space) - return - } - if _, err := fmt.Fprintf(c, serverStreamOpen, "localhost", streamID, stanza.NSComponent, stanza.NSStream); err != nil { - t.Errorf("cannot write server stream open: %s", err) - } - return - } - } -} - -//============================================================================= -// Sends IQ response to Component request. -// No parsing of the request here. We just check that it's valid, and send the default response. -func handlerForComponentIQSend(t *testing.T, c net.Conn) { - // Completes the connection by exchanging handshakes - handlerForComponentHandshakeDefaultID(t, c) - - // Decoder to parse the request +// Performs a Component connection with a handshake. It uses a default ID defined in this file as a constant. +// Used in the mock server as a Handler +func handlerForComponentHandshakeDefaultID(t *testing.T, c net.Conn) { decoder := xml.NewDecoder(c) - - iqReq, err := receiveIq(t, c, decoder) - if err != nil { - t.Errorf("Error receiving the IQ stanza : %v", err) - } else if !iqReq.IsValid() { - t.Errorf("server received an IQ stanza : %v", iqReq) - } - - // Crafting response - iqResp := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeResult, From: iqReq.To, To: iqReq.From, Id: iqReq.Id, Lang: "en"}) - disco := iqResp.DiscoInfo() - disco.AddFeatures("vcard-temp", - `http://jabber.org/protocol/address`) - - disco.AddIdentity("Multicast", "service", "multicast") - iqResp.Payload = disco - - // Sending response to the Component - mResp, err := xml.Marshal(iqResp) - _, err = fmt.Fprintln(c, string(mResp)) - if err != nil { - t.Errorf("Could not send response stanza : %s", err) - } + checkOpenStreamHandshakeDefaultID(t, c, decoder) + readHandshakeComponent(t, decoder) + fmt.Fprintln(c, "") // That's all the server needs to return (see xep-0114) return } - -// Reads next request coming from the Component. Expecting it to be an IQ request -func receiveIq(t *testing.T, c net.Conn, decoder *xml.Decoder) (stanza.IQ, error) { - c.SetDeadline(time.Now().Add(defaultTimeout)) - defer c.SetDeadline(time.Time{}) - var iqStz stanza.IQ - err := decoder.Decode(&iqStz) - if err != nil { - t.Errorf("cannot read the received IQ stanza: %s", err) - } - if !iqStz.IsValid() { - t.Errorf("received IQ stanza is invalid : %s", err) - } - return iqStz, nil -} - -func receiveRawIq(t *testing.T, c net.Conn, errChan chan error) { - c.SetDeadline(time.Now().Add(defaultTimeout)) - defer c.SetDeadline(time.Time{}) - decoder := xml.NewDecoder(c) - var iq stanza.IQ - err := decoder.Decode(&iq) - if err != nil || !iq.IsValid() { - s := stanza.StreamError{ - XMLName: xml.Name{Local: "stream:error"}, - Error: xml.Name{Local: "xml-not-well-formed"}, - Text: `XML was not well-formed`, - } - raw, _ := xml.Marshal(s) - fmt.Fprintln(c, string(raw)) - fmt.Fprintln(c, ``) // TODO : check this client side - errChan <- fmt.Errorf("invalid xml") - return - } - errChan <- nil - return -} - -//=============================== -// Init mock server and connection -// Creating a mock server and connecting a Component to it. Initialized with given port and handler function -// The Component and mock are both returned -func mockConnection(t *testing.T, port int, handler func(t *testing.T, c net.Conn)) (*Component, *ServerMock) { - // Init mock server - testComponentAddress := fmt.Sprintf("%s:%d", testComponentDomain, port) - mock := ServerMock{} - mock.Start(t, testComponentAddress, handler) - - //================================== - // Create Component to connect to it - c := makeBasicComponent(defaultComponentName, testComponentAddress, t) - - //======================================== - // Connect the new Component to the server - err := c.Connect() - if err != nil { - t.Errorf("%+v", err) - } - - return c, &mock -} - -func makeBasicComponent(name string, mockServerAddr string, t *testing.T) *Component { - opts := ComponentOptions{ - TransportConfiguration: TransportConfiguration{ - Address: mockServerAddr, - Domain: "localhost", - }, - Domain: testComponentDomain, - Secret: "mypass", - Name: name, - Category: "gateway", - Type: "service", - } - router := NewRouter() - c, err := NewComponent(opts, router) - if err != nil { - t.Errorf("%+v", err) - } - c.transport, err = NewComponentTransport(c.ComponentOptions.TransportConfiguration) - if err != nil { - t.Errorf("%+v", err) - } - return c -} diff --git a/tcp_server_mock.go b/tcp_server_mock.go index bdc4397..4afed80 100644 --- a/tcp_server_mock.go +++ b/tcp_server_mock.go @@ -1,12 +1,42 @@ package xmpp import ( + "encoding/xml" + "fmt" + "gosrc.io/xmpp/stanza" "net" "testing" + "time" ) //============================================================================= // TCP Server Mock +const ( + defaultTimeout = 2 * time.Second + testComponentDomain = "localhost" + defaultServerName = "testServer" + defaultStreamID = "91bd0bba-012f-4d92-bb17-5fc41e6fe545" + defaultComponentName = "Test Component" + serverStreamOpen = "" + + // Default port is not standard XMPP port to avoid interfering + // with local running XMPP server + + // Component tests + testHandshakePort = iota + 15222 + testDecoderPort + testSendIqPort + testSendIqFailPort + testSendRawPort + testDisconnectPort + testSManDisconnectPort + + // Client tests + testClientBasePort + testClientRawPort + testClientIqPort + testClientIqFailPort +) // ClientHandler is passed by the test client to provide custom behaviour to // the TCP server mock. This allows customizing the server behaviour to allow @@ -81,3 +111,180 @@ func (mock *ServerMock) loop() { go mock.handler(mock.t, conn) } } + +//====================================================================================================================== +// A few functions commonly used for tests. Trying to avoid duplicates in client and component test files. +//====================================================================================================================== + +func respondToIQ(t *testing.T, c net.Conn) { + // Decoder to parse the request + decoder := xml.NewDecoder(c) + + iqReq, err := receiveIq(c, decoder) + if err != nil { + t.Fatalf("failed to receive IQ : %s", err.Error()) + } + + if !iqReq.IsValid() { + mockIQError(c) + return + } + + // Crafting response + iqResp := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeResult, From: iqReq.To, To: iqReq.From, Id: iqReq.Id, Lang: "en"}) + disco := iqResp.DiscoInfo() + disco.AddFeatures("vcard-temp", + `http://jabber.org/protocol/address`) + + disco.AddIdentity("Multicast", "service", "multicast") + iqResp.Payload = disco + + // Sending response to the Component + mResp, err := xml.Marshal(iqResp) + _, err = fmt.Fprintln(c, string(mResp)) + if err != nil { + t.Errorf("Could not send response stanza : %s", err) + } + return +} + +// When a presence stanza is automatically sent (right now it's the case in the client), we may want to discard it +// and test further stanzas. +func discardPresence(t *testing.T, c net.Conn) { + decoder := xml.NewDecoder(c) + c.SetDeadline(time.Now().Add(defaultTimeout)) + defer c.SetDeadline(time.Time{}) + var presenceStz stanza.Presence + err := decoder.Decode(&presenceStz) + if err != nil { + t.Errorf("Expected presence but this happened : %s", err.Error()) + } +} + +// Reads next request coming from the Component. Expecting it to be an IQ request +func receiveIq(c net.Conn, decoder *xml.Decoder) (*stanza.IQ, error) { + c.SetDeadline(time.Now().Add(defaultTimeout)) + defer c.SetDeadline(time.Time{}) + var iqStz stanza.IQ + err := decoder.Decode(&iqStz) + if err != nil { + return nil, err + } + return &iqStz, nil +} + +// Should be used in server handlers when an IQ sent by a client or component is invalid. +// This responds as expected from a "real" server, aside from the error message. +func mockIQError(c net.Conn) { + s := stanza.StreamError{ + XMLName: xml.Name{Local: "stream:error"}, + Error: xml.Name{Local: "xml-not-well-formed"}, + Text: `XML was not well-formed`, + } + raw, _ := xml.Marshal(s) + fmt.Fprintln(c, string(raw)) + fmt.Fprintln(c, ``) +} + +func sendStreamFeatures(t *testing.T, c net.Conn, _ *xml.Decoder) { + // This is a basic server, supporting only 1 stream feature: SASL Plain Auth + features := ` + + PLAIN + +` + if _, err := fmt.Fprintln(c, features); err != nil { + t.Errorf("cannot send stream feature: %s", err) + } +} + +// TODO return err in case of error reading the auth params +func readAuth(t *testing.T, decoder *xml.Decoder) string { + se, err := stanza.NextStart(decoder) + if err != nil { + t.Errorf("cannot read auth: %s", err) + return "" + } + + var nv interface{} + nv = &stanza.SASLAuth{} + // Decode element into pointer storage + if err = decoder.DecodeElement(nv, &se); err != nil { + t.Errorf("cannot decode auth: %s", err) + return "" + } + + switch v := nv.(type) { + case *stanza.SASLAuth: + return v.Value + } + return "" +} + +func sendBindFeature(t *testing.T, c net.Conn, _ *xml.Decoder) { + // This is a basic server, supporting only 1 stream feature after auth: resource binding + features := ` + +` + if _, err := fmt.Fprintln(c, features); err != nil { + t.Errorf("cannot send stream feature: %s", err) + } +} + +func sendRFC3921Feature(t *testing.T, c net.Conn, _ *xml.Decoder) { + // This is a basic server, supporting only 2 features after auth: resource & session binding + features := ` + + +` + if _, err := fmt.Fprintln(c, features); err != nil { + t.Errorf("cannot send stream feature: %s", err) + } +} + +func bind(t *testing.T, c net.Conn, decoder *xml.Decoder) { + se, err := stanza.NextStart(decoder) + if err != nil { + t.Errorf("cannot read bind: %s", err) + return + } + + iq := &stanza.IQ{} + // Decode element into pointer storage + if err = decoder.DecodeElement(&iq, &se); err != nil { + t.Errorf("cannot decode bind iq: %s", err) + return + } + + // TODO Check all elements + switch iq.Payload.(type) { + case *stanza.Bind: + result := ` + + %s + +` + fmt.Fprintf(c, result, iq.Id, "test@localhost/test") // TODO use real JID + } +} + +func session(t *testing.T, c net.Conn, decoder *xml.Decoder) { + se, err := stanza.NextStart(decoder) + if err != nil { + t.Errorf("cannot read session: %s", err) + return + } + + iq := &stanza.IQ{} + // Decode element into pointer storage + if err = decoder.DecodeElement(&iq, &se); err != nil { + t.Errorf("cannot decode session iq: %s", err) + return + } + + switch iq.Payload.(type) { + case *stanza.StreamSession: + result := `` + fmt.Fprintf(c, result, iq.Id) + } +} From 6d8e9d325a7862f5c3b9d51206374621aefa9417 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?CORNIERE=20R=C3=A9mi?= Date: Mon, 9 Dec 2019 13:31:01 +0100 Subject: [PATCH 3/5] Try removing decoder from IQ tests and changing writing method --- client.go | 3 +-- component.go | 12 +++++++++--- tcp_server_mock.go | 18 +++++++++++------- 3 files changed, 21 insertions(+), 12 deletions(-) diff --git a/client.go b/client.go index a5ad1bf..4d7857e 100644 --- a/client.go +++ b/client.go @@ -4,7 +4,6 @@ import ( "context" "encoding/xml" "errors" - "fmt" "io" "net" "time" @@ -200,7 +199,7 @@ func (c *Client) Resume(state SMState) error { //fmt.Fprintf(client.conn, "%s%s", "chat", "Online") // TODO: Do we always want to send initial presence automatically ? // Do we need an option to avoid that or do we rely on client to send the presence itself ? - _, err = fmt.Fprintf(c.transport, "") + err = c.sendWithWriter(c.transport, []byte("")) return err } diff --git a/component.go b/component.go index 2f61aef..8b96240 100644 --- a/component.go +++ b/component.go @@ -85,7 +85,7 @@ func (c *Component) Resume(sm SMState) error { c.updateState(StateConnected) // Authentication - if _, err := fmt.Fprintf(c.transport, "%s", c.handshake(streamId)); err != nil { + if err := c.sendWithWriter(c.transport, []byte(fmt.Sprintf("%s", c.handshake(streamId)))); err != nil { c.updateState(StateStreamError) return NewConnError(errors.New("cannot send handshake "+err.Error()), false) @@ -159,12 +159,18 @@ func (c *Component) Send(packet stanza.Packet) error { return errors.New("cannot marshal packet " + err.Error()) } - if _, err := fmt.Fprintf(transport, string(data)); err != nil { + if err := c.sendWithWriter(transport, data); err != nil { return errors.New("cannot send packet " + err.Error()) } return nil } +func (c *Component) sendWithWriter(writer io.Writer, packet []byte) error { + var err error + _, err = writer.Write(packet) + return err +} + // SendIQ sends an IQ set or get stanza to the server. If a result is received // the provided handler function will automatically be called. // @@ -195,7 +201,7 @@ func (c *Component) SendRaw(packet string) error { } var err error - _, err = fmt.Fprintf(transport, packet) + err = c.sendWithWriter(transport, []byte(packet)) return err } diff --git a/tcp_server_mock.go b/tcp_server_mock.go index 4afed80..efdda23 100644 --- a/tcp_server_mock.go +++ b/tcp_server_mock.go @@ -117,21 +117,25 @@ func (mock *ServerMock) loop() { //====================================================================================================================== func respondToIQ(t *testing.T, c net.Conn) { - // Decoder to parse the request - decoder := xml.NewDecoder(c) - - iqReq, err := receiveIq(c, decoder) + recvBuf := make([]byte, 1024) + var iqR stanza.IQ + _, err := c.Read(recvBuf[:]) // recv data if err != nil { - t.Fatalf("failed to receive IQ : %s", err.Error()) + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + t.Errorf("read timeout: %s", err) + } else { + t.Errorf("read error: %s", err) + } } + xml.Unmarshal(recvBuf, &iqR) - if !iqReq.IsValid() { + if !iqR.IsValid() { mockIQError(c) return } // Crafting response - iqResp := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeResult, From: iqReq.To, To: iqReq.From, Id: iqReq.Id, Lang: "en"}) + iqResp := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeResult, From: iqR.To, To: iqR.From, Id: iqR.Id, Lang: "en"}) disco := iqResp.DiscoInfo() disco.AddFeatures("vcard-temp", `http://jabber.org/protocol/address`) From fd48f52f3db309764f32bb63a2b0c3848006ff86 Mon Sep 17 00:00:00 2001 From: rcorniere Date: Tue, 10 Dec 2019 14:30:15 +0100 Subject: [PATCH 4/5] Using precisely sized buffers for tcp tests --- client.go | 3 ++- client_test.go | 6 +++--- tcp_server_mock.go | 16 ++++++++++++++-- 3 files changed, 19 insertions(+), 6 deletions(-) diff --git a/client.go b/client.go index 4d7857e..254a793 100644 --- a/client.go +++ b/client.go @@ -20,6 +20,7 @@ type ConnState = uint8 // This is a the list of events happening on the connection that the // client can be notified about. const ( + InitialPresence = "" StateDisconnected ConnState = iota StateConnected StateSessionEstablished @@ -199,7 +200,7 @@ func (c *Client) Resume(state SMState) error { //fmt.Fprintf(client.conn, "%s%s", "chat", "Online") // TODO: Do we always want to send initial presence automatically ? // Do we need an option to avoid that or do we rely on client to send the presence itself ? - err = c.sendWithWriter(c.transport, []byte("")) + err = c.sendWithWriter(c.transport, []byte(InitialPresence)) return err } diff --git a/client_test.go b/client_test.go index 0caace0..f2b775a 100644 --- a/client_test.go +++ b/client_test.go @@ -184,15 +184,15 @@ func TestClient_SendIQ(t *testing.T) { select { case <-res: // If the server responds with an IQ, we pass the test case err := <-errChan: // If the server sends an error, or there is a connection error - t.Errorf(err.Error()) + t.Fatal(err.Error()) case <-time.After(defaultChannelTimeout): // If we timeout - t.Errorf("Failed to receive response, to sent IQ, from mock server") + t.Fatal("Failed to receive response, to sent IQ, from mock server") } select { case <-done: mock.Stop() case <-time.After(defaultChannelTimeout): - t.Errorf("The mock server failed to finish its job !") + t.Fatal("The mock server failed to finish its job !") } } diff --git a/tcp_server_mock.go b/tcp_server_mock.go index efdda23..1084cbd 100644 --- a/tcp_server_mock.go +++ b/tcp_server_mock.go @@ -120,6 +120,7 @@ func respondToIQ(t *testing.T, c net.Conn) { recvBuf := make([]byte, 1024) var iqR stanza.IQ _, err := c.Read(recvBuf[:]) // recv data + if err != nil { if netErr, ok := err.(net.Error); ok && netErr.Timeout() { t.Errorf("read timeout: %s", err) @@ -155,11 +156,22 @@ func respondToIQ(t *testing.T, c net.Conn) { // When a presence stanza is automatically sent (right now it's the case in the client), we may want to discard it // and test further stanzas. func discardPresence(t *testing.T, c net.Conn) { - decoder := xml.NewDecoder(c) c.SetDeadline(time.Now().Add(defaultTimeout)) defer c.SetDeadline(time.Time{}) var presenceStz stanza.Presence - err := decoder.Decode(&presenceStz) + + recvBuf := make([]byte, len(InitialPresence)) + _, err := c.Read(recvBuf[:]) // recv data + + if err != nil { + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + t.Errorf("read timeout: %s", err) + } else { + t.Errorf("read error: %s", err) + } + } + xml.Unmarshal(recvBuf, &presenceStz) + if err != nil { t.Errorf("Expected presence but this happened : %s", err.Error()) } From 3c9b0db5b80ea26f031ee38c8ab56783faf2ad72 Mon Sep 17 00:00:00 2001 From: rcorniere Date: Tue, 10 Dec 2019 17:15:16 +0100 Subject: [PATCH 5/5] Fixed decoder usage. Decoders have internal buffering, and creating many on a single TCP connection can cause issues in parsing exchanged XML documents. --- client_test.go | 83 ++++++++++++++++++--------------------- component_test.go | 78 ++++++++++++++++++------------------ doc.go | 2 +- session.go | 2 +- tcp_server_mock.go | 98 +++++++++++++++++++++++----------------------- 5 files changed, 130 insertions(+), 133 deletions(-) diff --git a/client_test.go b/client_test.go index f2b775a..8d109d0 100644 --- a/client_test.go +++ b/client_test.go @@ -5,7 +5,6 @@ import ( "encoding/xml" "errors" "fmt" - "net" "testing" "time" @@ -157,10 +156,10 @@ func TestClient_RFC3921Session(t *testing.T) { func TestClient_SendIQ(t *testing.T) { done := make(chan struct{}) // Handler for Mock server - h := func(t *testing.T, c net.Conn) { - handlerClientConnectSuccess(t, c) - discardPresence(t, c) - respondToIQ(t, c) + h := func(t *testing.T, sc *ServerConn) { + handlerClientConnectSuccess(t, sc) + discardPresence(t, sc) + respondToIQ(t, sc) done <- struct{}{} } client, mock := mockClientConnection(t, h, testClientIqPort) @@ -199,10 +198,10 @@ func TestClient_SendIQ(t *testing.T) { func TestClient_SendIQFail(t *testing.T) { done := make(chan struct{}) // Handler for Mock server - h := func(t *testing.T, c net.Conn) { - handlerClientConnectSuccess(t, c) - discardPresence(t, c) - respondToIQ(t, c) + h := func(t *testing.T, sc *ServerConn) { + handlerClientConnectSuccess(t, sc) + discardPresence(t, sc) + respondToIQ(t, sc) done <- struct{}{} } client, mock := mockClientConnection(t, h, testClientIqFailPort) @@ -244,10 +243,10 @@ func TestClient_SendIQFail(t *testing.T) { func TestClient_SendRaw(t *testing.T) { done := make(chan struct{}) // Handler for Mock server - h := func(t *testing.T, c net.Conn) { - handlerClientConnectSuccess(t, c) - discardPresence(t, c) - respondToIQ(t, c) + h := func(t *testing.T, sc *ServerConn) { + handlerClientConnectSuccess(t, sc) + discardPresence(t, sc) + respondToIQ(t, sc) done <- struct{}{} } type testCase struct { @@ -365,48 +364,44 @@ func TestClient_DisconnectStreamManager(t *testing.T) { // Basic XMPP Server Mock Handlers. // Test connection with a basic straightforward workflow -func handlerClientConnectSuccess(t *testing.T, c net.Conn) { - decoder := xml.NewDecoder(c) - checkClientOpenStream(t, c, decoder) +func handlerClientConnectSuccess(t *testing.T, sc *ServerConn) { + checkClientOpenStream(t, sc) + sendStreamFeatures(t, sc) // Send initial features + readAuth(t, sc.decoder) + fmt.Fprintln(sc.connection, "") - sendStreamFeatures(t, c, decoder) // Send initial features - readAuth(t, decoder) - fmt.Fprintln(c, "") - - checkClientOpenStream(t, c, decoder) // Reset stream - sendBindFeature(t, c, decoder) // Send post auth features - bind(t, c, decoder) + checkClientOpenStream(t, sc) // Reset stream + sendBindFeature(t, sc) // Send post auth features + bind(t, sc) } // We expect client will abort on TLS -func handlerAbortTLS(t *testing.T, c net.Conn) { - decoder := xml.NewDecoder(c) - checkClientOpenStream(t, c, decoder) - sendStreamFeatures(t, c, decoder) // Send initial features +func handlerAbortTLS(t *testing.T, sc *ServerConn) { + checkClientOpenStream(t, sc) + sendStreamFeatures(t, sc) // Send initial features } // Test connection with mandatory session (RFC-3921) -func handlerClientConnectWithSession(t *testing.T, c net.Conn) { - decoder := xml.NewDecoder(c) - checkClientOpenStream(t, c, decoder) +func handlerClientConnectWithSession(t *testing.T, sc *ServerConn) { + checkClientOpenStream(t, sc) - sendStreamFeatures(t, c, decoder) // Send initial features - readAuth(t, decoder) - fmt.Fprintln(c, "") + sendStreamFeatures(t, sc) // Send initial features + readAuth(t, sc.decoder) + fmt.Fprintln(sc.connection, "") - checkClientOpenStream(t, c, decoder) // Reset stream - sendRFC3921Feature(t, c, decoder) // Send post auth features - bind(t, c, decoder) - session(t, c, decoder) + checkClientOpenStream(t, sc) // Reset stream + sendRFC3921Feature(t, sc) // Send post auth features + bind(t, sc) + session(t, sc) } -func checkClientOpenStream(t *testing.T, c net.Conn, decoder *xml.Decoder) { - c.SetDeadline(time.Now().Add(defaultTimeout)) - defer c.SetDeadline(time.Time{}) +func checkClientOpenStream(t *testing.T, sc *ServerConn) { + sc.connection.SetDeadline(time.Now().Add(defaultTimeout)) + defer sc.connection.SetDeadline(time.Time{}) for { // TODO clean up. That for loop is not elegant and I prefer bounded recursion. var token xml.Token - token, err := decoder.Token() + token, err := sc.decoder.Token() if err != nil { t.Errorf("cannot read next token: %s", err) } @@ -418,7 +413,7 @@ func checkClientOpenStream(t *testing.T, c net.Conn, decoder *xml.Decoder) { err = errors.New("xmpp: expected but got <" + elem.Name.Local + "> in " + elem.Name.Space) return } - if _, err := fmt.Fprintf(c, serverStreamOpen, "localhost", "streamid1", stanza.NSClient, stanza.NSStream); err != nil { + if _, err := fmt.Fprintf(sc.connection, serverStreamOpen, "localhost", "streamid1", stanza.NSClient, stanza.NSStream); err != nil { t.Errorf("cannot write server stream open: %s", err) } return @@ -426,8 +421,8 @@ func checkClientOpenStream(t *testing.T, c net.Conn, decoder *xml.Decoder) { } } -func mockClientConnection(t *testing.T, serverHandler func(*testing.T, net.Conn), port int) (*Client, ServerMock) { - mock := ServerMock{} +func mockClientConnection(t *testing.T, serverHandler func(*testing.T, *ServerConn), port int) (*Client, *ServerMock) { + mock := &ServerMock{} testServerAddress := fmt.Sprintf("%s:%d", testClientDomain, port) mock.Start(t, testServerAddress, serverHandler) diff --git a/component_test.go b/component_test.go index 48963a5..f4d1a07 100644 --- a/component_test.go +++ b/component_test.go @@ -7,7 +7,6 @@ import ( "fmt" "github.com/google/uuid" "gosrc.io/xmpp/stanza" - "net" "strings" "testing" "time" @@ -36,7 +35,7 @@ func TestHandshake(t *testing.T) { } // Tests connection process with a handshake exchange -// Tests multiple session IDs. All connections should generate a unique stream ID +// Tests multiple session IDs. All serverConnections should generate a unique stream ID func TestGenerateHandshakeId(t *testing.T) { // Using this array with a channel to make a queue of values to test // These are stream IDs that will be used to test the connection process, mixing them with the "secret" to generate @@ -56,11 +55,11 @@ func TestGenerateHandshakeId(t *testing.T) { // Performs a Component connection with a handshake. It expects to have an ID sent its way through the "uchan" // channel of this file. Otherwise it will hang for ever. - h := func(t *testing.T, c net.Conn) { - decoder := xml.NewDecoder(c) - checkOpenStreamHandshakeID(t, c, decoder, <-uchan) - readHandshakeComponent(t, decoder) - fmt.Fprintln(c, "") // That's all the server needs to return (see xep-0114) + h := func(t *testing.T, sc *ServerConn) { + + checkOpenStreamHandshakeID(t, sc, <-uchan) + readHandshakeComponent(t, sc.decoder) + fmt.Fprintln(sc.connection, "") // That's all the server needs to return (see xep-0114) return } @@ -122,8 +121,8 @@ func TestDecoder(t *testing.T) { // Tests sending an IQ to the server, and getting the response func TestSendIq(t *testing.T) { done := make(chan struct{}) - h := func(t *testing.T, c net.Conn) { - handlerForComponentIQSend(t, c) + h := func(t *testing.T, sc *ServerConn) { + handlerForComponentIQSend(t, sc) done <- struct{}{} } @@ -164,8 +163,8 @@ func TestSendIq(t *testing.T) { // Checking that error handling is done properly client side when an invalid IQ is sent and the server responds in kind. func TestSendIqFail(t *testing.T) { done := make(chan struct{}) - h := func(t *testing.T, c net.Conn) { - handlerForComponentIQSend(t, c) + h := func(t *testing.T, sc *ServerConn) { + handlerForComponentIQSend(t, sc) done <- struct{}{} } //Connecting to a mock server, initialized with given port and handler function @@ -213,27 +212,30 @@ func TestSendIqFail(t *testing.T) { func TestSendRaw(t *testing.T) { done := make(chan struct{}) // Handler for the mock server - h := func(t *testing.T, c net.Conn) { + h := func(t *testing.T, sc *ServerConn) { // Completes the connection by exchanging handshakes - handlerForComponentHandshakeDefaultID(t, c) - receiveIq(c, xml.NewDecoder(c)) + handlerForComponentHandshakeDefaultID(t, sc) + respondToIQ(t, sc) done <- struct{}{} } type testCase struct { req string shouldErr bool + port int } testRequests := make(map[string]testCase) // Sending a correct IQ of type get. Not supposed to err testRequests["Correct IQ"] = testCase{ req: ``, shouldErr: false, + port: testSendRawPort + 100, } // Sending an IQ with a missing ID. Should err testRequests["IQ with missing ID"] = testCase{ req: ``, shouldErr: true, + port: testSendRawPort + 200, } // A handler for the component. @@ -247,7 +249,7 @@ func TestSendRaw(t *testing.T) { for name, tcase := range testRequests { t.Run(name, func(st *testing.T) { //Connecting to a mock server, initialized with given port and handler function - c, m := mockComponentConnection(t, testSendRawPort, h) + c, m := mockComponentConnection(t, tcase.port, h) c.ErrorHandler = errHandler // Sending raw xml from test case err := c.SendRaw(tcase.req) @@ -328,10 +330,10 @@ func TestStreamManagerDisconnect(t *testing.T) { // Init mock server and connection // Creating a mock server and connecting a Component to it. Initialized with given port and handler function // The Component and mock are both returned -func mockComponentConnection(t *testing.T, port int, handler func(t *testing.T, c net.Conn)) (*Component, *ServerMock) { +func mockComponentConnection(t *testing.T, port int, handler func(t *testing.T, sc *ServerConn)) (*Component, *ServerMock) { // Init mock server testComponentAddress := fmt.Sprintf("%s:%d", testComponentDomain, port) - mock := ServerMock{} + mock := &ServerMock{} mock.Start(t, testComponentAddress, handler) //================================== @@ -345,7 +347,9 @@ func mockComponentConnection(t *testing.T, port int, handler func(t *testing.T, t.Errorf("%+v", err) } - return c, &mock + // Now that the Component is connected, let's set the xml.Decoder for the server + + return c, mock } func makeBasicComponent(name string, mockServerAddr string, t *testing.T) *Component { @@ -380,19 +384,19 @@ func componentDefaultErrorHandler(err error) { // Sends IQ response to Component request. // No parsing of the request here. We just check that it's valid, and send the default response. -func handlerForComponentIQSend(t *testing.T, c net.Conn) { +func handlerForComponentIQSend(t *testing.T, sc *ServerConn) { // Completes the connection by exchanging handshakes - handlerForComponentHandshakeDefaultID(t, c) - respondToIQ(t, c) + handlerForComponentHandshakeDefaultID(t, sc) + respondToIQ(t, sc) } // Used for ID and handshake related tests -func checkOpenStreamHandshakeID(t *testing.T, c net.Conn, decoder *xml.Decoder, streamID string) { - c.SetDeadline(time.Now().Add(defaultTimeout)) - defer c.SetDeadline(time.Time{}) +func checkOpenStreamHandshakeID(t *testing.T, sc *ServerConn, streamID string) { + sc.connection.SetDeadline(time.Now().Add(defaultTimeout)) + defer sc.connection.SetDeadline(time.Time{}) for { // TODO clean up. That for loop is not elegant and I prefer bounded recursion. - token, err := decoder.Token() + token, err := sc.decoder.Token() if err != nil { t.Errorf("cannot read next token: %s", err) } @@ -404,7 +408,7 @@ func checkOpenStreamHandshakeID(t *testing.T, c net.Conn, decoder *xml.Decoder, err = errors.New("xmpp: expected but got <" + elem.Name.Local + "> in " + elem.Name.Space) return } - if _, err := fmt.Fprintf(c, serverStreamOpen, "localhost", streamID, stanza.NSComponent, stanza.NSStream); err != nil { + if _, err := fmt.Fprintf(sc.connection, serverStreamOpen, "localhost", streamID, stanza.NSComponent, stanza.NSStream); err != nil { t.Errorf("cannot write server stream open: %s", err) } return @@ -412,16 +416,15 @@ func checkOpenStreamHandshakeID(t *testing.T, c net.Conn, decoder *xml.Decoder, } } -func checkOpenStreamHandshakeDefaultID(t *testing.T, c net.Conn, decoder *xml.Decoder) { - checkOpenStreamHandshakeID(t, c, decoder, defaultStreamID) +func checkOpenStreamHandshakeDefaultID(t *testing.T, sc *ServerConn) { + checkOpenStreamHandshakeID(t, sc, defaultStreamID) } // Performs a Component connection with a handshake. It uses a default ID defined in this file as a constant. // This handler is supposed to fail by sending a "message" stanza instead of a stanza to finalize the handshake. -func handlerComponentFailedHandshakeDefaultID(t *testing.T, c net.Conn) { - decoder := xml.NewDecoder(c) - checkOpenStreamHandshakeDefaultID(t, c, decoder) - readHandshakeComponent(t, decoder) +func handlerComponentFailedHandshakeDefaultID(t *testing.T, sc *ServerConn) { + checkOpenStreamHandshakeDefaultID(t, sc) + readHandshakeComponent(t, sc.decoder) // Send a message, instead of a "" tag, to fail the handshake process dans disconnect the client. me := stanza.Message{ @@ -429,7 +432,7 @@ func handlerComponentFailedHandshakeDefaultID(t *testing.T, c net.Conn) { Body: "Fail my handshake.", } s, _ := xml.Marshal(me) - fmt.Fprintln(c, string(s)) + fmt.Fprintln(sc.connection, string(s)) return } @@ -454,10 +457,9 @@ func readHandshakeComponent(t *testing.T, decoder *xml.Decoder) { // Performs a Component connection with a handshake. It uses a default ID defined in this file as a constant. // Used in the mock server as a Handler -func handlerForComponentHandshakeDefaultID(t *testing.T, c net.Conn) { - decoder := xml.NewDecoder(c) - checkOpenStreamHandshakeDefaultID(t, c, decoder) - readHandshakeComponent(t, decoder) - fmt.Fprintln(c, "") // That's all the server needs to return (see xep-0114) +func handlerForComponentHandshakeDefaultID(t *testing.T, sc *ServerConn) { + checkOpenStreamHandshakeDefaultID(t, sc) + readHandshakeComponent(t, sc.decoder) + fmt.Fprintln(sc.connection, "") // That's all the server needs to return (see xep-0114) return } diff --git a/doc.go b/doc.go index 40f4f6a..f29bbf6 100644 --- a/doc.go +++ b/doc.go @@ -29,7 +29,7 @@ Components XMPP components can typically be used to extends the features of an XMPP server, in a portable way, using component protocol over persistent TCP -connections. +serverConnections. Component protocol is defined in XEP-114 (https://xmpp.org/extensions/xep-0114.html). diff --git a/session.go b/session.go index 22d76b2..6b9c75a 100644 --- a/session.go +++ b/session.go @@ -119,7 +119,7 @@ func (s *Session) startTlsIfSupported(o Config) { return } - // If we do not allow cleartext connections, make it explicit that server do not support starttls + // If we do not allow cleartext serverConnections, make it explicit that server do not support starttls if !o.Insecure { s.err = errors.New("XMPP server does not advertise support for starttls") } diff --git a/tcp_server_mock.go b/tcp_server_mock.go index 1084cbd..c8f5d97 100644 --- a/tcp_server_mock.go +++ b/tcp_server_mock.go @@ -41,16 +41,21 @@ const ( // ClientHandler is passed by the test client to provide custom behaviour to // the TCP server mock. This allows customizing the server behaviour to allow // testing clients under various scenarii. -type ClientHandler func(t *testing.T, conn net.Conn) +type ClientHandler func(t *testing.T, serverConn *ServerConn) // ServerMock is a simple TCP server that can be use to mock basic server // behaviour to test clients. type ServerMock struct { - t *testing.T - handler ClientHandler - listener net.Listener - connections []net.Conn - done chan struct{} + t *testing.T + handler ClientHandler + listener net.Listener + serverConnections []*ServerConn + done chan struct{} +} + +type ServerConn struct { + connection net.Conn + decoder *xml.Decoder } // Start launches the mock TCP server, listening to an actual address / port. @@ -68,9 +73,9 @@ func (mock *ServerMock) Stop() { if mock.listener != nil { mock.listener.Close() } - // Close all existing connections - for _, c := range mock.connections { - c.Close() + // Close all existing serverConnections + for _, c := range mock.serverConnections { + c.connection.Close() } } @@ -90,13 +95,14 @@ func (mock *ServerMock) init(addr string) error { return nil } -// loop accepts connections and creates a go routine per connection. +// loop accepts serverConnections and creates a go routine per connection. // The go routine is running the client handler, that is used to provide the // real TCP server behaviour. func (mock *ServerMock) loop() { listener := mock.listener for { conn, err := listener.Accept() + serverConn := &ServerConn{conn, xml.NewDecoder(conn)} if err != nil { select { case <-mock.done: @@ -106,9 +112,10 @@ func (mock *ServerMock) loop() { } return } - mock.connections = append(mock.connections, conn) + mock.serverConnections = append(mock.serverConnections, serverConn) + // TODO Create and pass a context to cancel the handler if they are still around = avoid possible leak on complex handlers - go mock.handler(mock.t, conn) + go mock.handler(mock.t, serverConn) } } @@ -116,27 +123,20 @@ func (mock *ServerMock) loop() { // A few functions commonly used for tests. Trying to avoid duplicates in client and component test files. //====================================================================================================================== -func respondToIQ(t *testing.T, c net.Conn) { - recvBuf := make([]byte, 1024) - var iqR stanza.IQ - _, err := c.Read(recvBuf[:]) // recv data - +func respondToIQ(t *testing.T, sc *ServerConn) { + // Decoder to parse the request + iqReq, err := receiveIq(sc) if err != nil { - if netErr, ok := err.(net.Error); ok && netErr.Timeout() { - t.Errorf("read timeout: %s", err) - } else { - t.Errorf("read error: %s", err) - } + t.Fatalf("failed to receive IQ : %s", err.Error()) } - xml.Unmarshal(recvBuf, &iqR) - if !iqR.IsValid() { - mockIQError(c) + if !iqReq.IsValid() { + mockIQError(sc.connection) return } // Crafting response - iqResp := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeResult, From: iqR.To, To: iqR.From, Id: iqR.Id, Lang: "en"}) + iqResp := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeResult, From: iqReq.To, To: iqReq.From, Id: iqReq.Id, Lang: "en"}) disco := iqResp.DiscoInfo() disco.AddFeatures("vcard-temp", `http://jabber.org/protocol/address`) @@ -146,7 +146,7 @@ func respondToIQ(t *testing.T, c net.Conn) { // Sending response to the Component mResp, err := xml.Marshal(iqResp) - _, err = fmt.Fprintln(c, string(mResp)) + _, err = fmt.Fprintln(sc.connection, string(mResp)) if err != nil { t.Errorf("Could not send response stanza : %s", err) } @@ -155,13 +155,13 @@ func respondToIQ(t *testing.T, c net.Conn) { // When a presence stanza is automatically sent (right now it's the case in the client), we may want to discard it // and test further stanzas. -func discardPresence(t *testing.T, c net.Conn) { - c.SetDeadline(time.Now().Add(defaultTimeout)) - defer c.SetDeadline(time.Time{}) +func discardPresence(t *testing.T, sc *ServerConn) { + sc.connection.SetDeadline(time.Now().Add(defaultTimeout)) + defer sc.connection.SetDeadline(time.Time{}) var presenceStz stanza.Presence recvBuf := make([]byte, len(InitialPresence)) - _, err := c.Read(recvBuf[:]) // recv data + _, err := sc.connection.Read(recvBuf[:]) // recv data if err != nil { if netErr, ok := err.(net.Error); ok && netErr.Timeout() { @@ -178,11 +178,11 @@ func discardPresence(t *testing.T, c net.Conn) { } // Reads next request coming from the Component. Expecting it to be an IQ request -func receiveIq(c net.Conn, decoder *xml.Decoder) (*stanza.IQ, error) { - c.SetDeadline(time.Now().Add(defaultTimeout)) - defer c.SetDeadline(time.Time{}) +func receiveIq(sc *ServerConn) (*stanza.IQ, error) { + sc.connection.SetDeadline(time.Now().Add(defaultTimeout)) + defer sc.connection.SetDeadline(time.Time{}) var iqStz stanza.IQ - err := decoder.Decode(&iqStz) + err := sc.decoder.Decode(&iqStz) if err != nil { return nil, err } @@ -202,14 +202,14 @@ func mockIQError(c net.Conn) { fmt.Fprintln(c, ``) } -func sendStreamFeatures(t *testing.T, c net.Conn, _ *xml.Decoder) { +func sendStreamFeatures(t *testing.T, sc *ServerConn) { // This is a basic server, supporting only 1 stream feature: SASL Plain Auth features := ` PLAIN ` - if _, err := fmt.Fprintln(c, features); err != nil { + if _, err := fmt.Fprintln(sc.connection, features); err != nil { t.Errorf("cannot send stream feature: %s", err) } } @@ -237,29 +237,29 @@ func readAuth(t *testing.T, decoder *xml.Decoder) string { return "" } -func sendBindFeature(t *testing.T, c net.Conn, _ *xml.Decoder) { +func sendBindFeature(t *testing.T, sc *ServerConn) { // This is a basic server, supporting only 1 stream feature after auth: resource binding features := ` ` - if _, err := fmt.Fprintln(c, features); err != nil { + if _, err := fmt.Fprintln(sc.connection, features); err != nil { t.Errorf("cannot send stream feature: %s", err) } } -func sendRFC3921Feature(t *testing.T, c net.Conn, _ *xml.Decoder) { +func sendRFC3921Feature(t *testing.T, sc *ServerConn) { // This is a basic server, supporting only 2 features after auth: resource & session binding features := ` ` - if _, err := fmt.Fprintln(c, features); err != nil { + if _, err := fmt.Fprintln(sc.connection, features); err != nil { t.Errorf("cannot send stream feature: %s", err) } } -func bind(t *testing.T, c net.Conn, decoder *xml.Decoder) { - se, err := stanza.NextStart(decoder) +func bind(t *testing.T, sc *ServerConn) { + se, err := stanza.NextStart(sc.decoder) if err != nil { t.Errorf("cannot read bind: %s", err) return @@ -267,7 +267,7 @@ func bind(t *testing.T, c net.Conn, decoder *xml.Decoder) { iq := &stanza.IQ{} // Decode element into pointer storage - if err = decoder.DecodeElement(&iq, &se); err != nil { + if err = sc.decoder.DecodeElement(&iq, &se); err != nil { t.Errorf("cannot decode bind iq: %s", err) return } @@ -280,12 +280,12 @@ func bind(t *testing.T, c net.Conn, decoder *xml.Decoder) { %s ` - fmt.Fprintf(c, result, iq.Id, "test@localhost/test") // TODO use real JID + fmt.Fprintf(sc.connection, result, iq.Id, "test@localhost/test") // TODO use real JID } } -func session(t *testing.T, c net.Conn, decoder *xml.Decoder) { - se, err := stanza.NextStart(decoder) +func session(t *testing.T, sc *ServerConn) { + se, err := stanza.NextStart(sc.decoder) if err != nil { t.Errorf("cannot read session: %s", err) return @@ -293,7 +293,7 @@ func session(t *testing.T, c net.Conn, decoder *xml.Decoder) { iq := &stanza.IQ{} // Decode element into pointer storage - if err = decoder.DecodeElement(&iq, &se); err != nil { + if err = sc.decoder.DecodeElement(&iq, &se); err != nil { t.Errorf("cannot decode session iq: %s", err) return } @@ -301,6 +301,6 @@ func session(t *testing.T, c net.Conn, decoder *xml.Decoder) { switch iq.Payload.(type) { case *stanza.StreamSession: result := `` - fmt.Fprintf(c, result, iq.Id) + fmt.Fprintf(sc.connection, result, iq.Id) } }