diff --git a/client.go b/client.go index cc152f3..4d7857e 100644 --- a/client.go +++ b/client.go @@ -4,7 +4,6 @@ import ( "context" "encoding/xml" "errors" - "fmt" "io" "net" "time" @@ -60,21 +59,21 @@ type EventManager struct { Handler EventHandler } -func (em EventManager) updateState(state ConnState) { +func (em *EventManager) updateState(state ConnState) { em.CurrentState = state if em.Handler != nil { em.Handler(Event{State: em.CurrentState}) } } -func (em EventManager) disconnected(state SMState) { +func (em *EventManager) disconnected(state SMState) { em.CurrentState = StateDisconnected if em.Handler != nil { em.Handler(Event{State: em.CurrentState, SMState: state}) } } -func (em EventManager) streamError(error, desc string) { +func (em *EventManager) streamError(error, desc string) { em.CurrentState = StateStreamError if em.Handler != nil { em.Handler(Event{State: em.CurrentState, StreamError: error, Description: desc}) @@ -110,6 +109,9 @@ Setting up the client / Checking the parameters // If host is not specified, the DNS SRV should be used to find the host from the domainpart of the JID. // Default the port to 5222. func NewClient(config Config, r *Router, errorHandler func(error)) (c *Client, err error) { + if config.KeepaliveInterval == 0 { + config.KeepaliveInterval = time.Second * 30 + } // Parse JID if config.parsedJid, err = NewJid(config.Jid); err != nil { err = errors.New("missing jid") @@ -188,7 +190,7 @@ func (c *Client) Resume(state SMState) error { // Start the keepalive go routine keepaliveQuit := make(chan struct{}) - go keepalive(c, keepaliveQuit) + go keepalive(c.transport, c.config.KeepaliveInterval, keepaliveQuit) // Start the receiver go routine state = c.Session.SMState go c.recv(state, keepaliveQuit) @@ -197,7 +199,7 @@ func (c *Client) Resume(state SMState) error { //fmt.Fprintf(client.conn, "%s%s", "chat", "Online") // TODO: Do we always want to send initial presence automatically ? // Do we need an option to avoid that or do we rely on client to send the presence itself ? - _, err = fmt.Fprintf(c.transport, "") + err = c.sendWithWriter(c.transport, []byte("")) return err } @@ -312,10 +314,8 @@ func (c *Client) recv(state SMState, keepaliveQuit chan<- struct{}) { // Loop: send whitespace keepalive to server // This is use to keep the connection open, but also to detect connection loss // and trigger proper client connection shutdown. -func keepalive(c *Client, quit <-chan struct{}) { - // TODO: Make keepalive interval configurable - transport := c.transport - ticker := time.NewTicker(30 * time.Second) +func keepalive(transport Transport, interval time.Duration, quit <-chan struct{}) { + ticker := time.NewTicker(interval) for { select { case <-ticker.C: diff --git a/client_test.go b/client_test.go index 15e104f..0caace0 100644 --- a/client_test.go +++ b/client_test.go @@ -19,6 +19,24 @@ const ( testClientDomain = "localhost" ) +func TestEventManager(t *testing.T) { + mgr := EventManager{} + mgr.updateState(StateConnected) + if mgr.CurrentState != StateConnected { + t.Fatal("CurrentState not updated by updateState()") + } + + mgr.disconnected(SMState{}) + if mgr.CurrentState != StateDisconnected { + t.Fatalf("CurrentState not reset by disconnected()") + } + + mgr.streamError(ErrTLSNotSupported.Error(), "") + if mgr.CurrentState != StateStreamError { + t.Fatalf("CurrentState not set by streamError()") + } +} + func TestClient_Connect(t *testing.T) { // Setup Mock server mock := ServerMock{} diff --git a/component.go b/component.go index 2f61aef..8b96240 100644 --- a/component.go +++ b/component.go @@ -85,7 +85,7 @@ func (c *Component) Resume(sm SMState) error { c.updateState(StateConnected) // Authentication - if _, err := fmt.Fprintf(c.transport, "%s", c.handshake(streamId)); err != nil { + if err := c.sendWithWriter(c.transport, []byte(fmt.Sprintf("%s", c.handshake(streamId)))); err != nil { c.updateState(StateStreamError) return NewConnError(errors.New("cannot send handshake "+err.Error()), false) @@ -159,12 +159,18 @@ func (c *Component) Send(packet stanza.Packet) error { return errors.New("cannot marshal packet " + err.Error()) } - if _, err := fmt.Fprintf(transport, string(data)); err != nil { + if err := c.sendWithWriter(transport, data); err != nil { return errors.New("cannot send packet " + err.Error()) } return nil } +func (c *Component) sendWithWriter(writer io.Writer, packet []byte) error { + var err error + _, err = writer.Write(packet) + return err +} + // SendIQ sends an IQ set or get stanza to the server. If a result is received // the provided handler function will automatically be called. // @@ -195,7 +201,7 @@ func (c *Component) SendRaw(packet string) error { } var err error - _, err = fmt.Fprintf(transport, packet) + err = c.sendWithWriter(transport, []byte(packet)) return err } diff --git a/config.go b/config.go index e3ea108..da4d4ab 100644 --- a/config.go +++ b/config.go @@ -2,6 +2,7 @@ package xmpp import ( "os" + "time" ) // Config & TransportConfiguration must not be modified after having been passed to NewClient. Any @@ -9,12 +10,13 @@ import ( type Config struct { TransportConfiguration - Jid string - parsedJid *Jid // For easier manipulation - Credential Credential - StreamLogger *os.File // Used for debugging - Lang string // TODO: should default to 'en' - ConnectTimeout int // Client timeout in seconds. Default to 15 + Jid string + parsedJid *Jid // For easier manipulation + Credential Credential + StreamLogger *os.File // Used for debugging + Lang string // TODO: should default to 'en' + KeepaliveInterval time.Duration // Interval between keepalive packets + ConnectTimeout int // Client timeout in seconds. Default to 15 // Insecure can be set to true to allow to open a session without TLS. If TLS // is supported on the server, we will still try to use it. Insecure bool diff --git a/tcp_server_mock.go b/tcp_server_mock.go index 4afed80..efdda23 100644 --- a/tcp_server_mock.go +++ b/tcp_server_mock.go @@ -117,21 +117,25 @@ func (mock *ServerMock) loop() { //====================================================================================================================== func respondToIQ(t *testing.T, c net.Conn) { - // Decoder to parse the request - decoder := xml.NewDecoder(c) - - iqReq, err := receiveIq(c, decoder) + recvBuf := make([]byte, 1024) + var iqR stanza.IQ + _, err := c.Read(recvBuf[:]) // recv data if err != nil { - t.Fatalf("failed to receive IQ : %s", err.Error()) + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + t.Errorf("read timeout: %s", err) + } else { + t.Errorf("read error: %s", err) + } } + xml.Unmarshal(recvBuf, &iqR) - if !iqReq.IsValid() { + if !iqR.IsValid() { mockIQError(c) return } // Crafting response - iqResp := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeResult, From: iqReq.To, To: iqReq.From, Id: iqReq.Id, Lang: "en"}) + iqResp := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeResult, From: iqR.To, To: iqR.From, Id: iqR.Id, Lang: "en"}) disco := iqResp.DiscoInfo() disco.AddFeatures("vcard-temp", `http://jabber.org/protocol/address`)