package ws import ( "encoding/json" "fmt" "log" "net/http" "strings" "time" "git.sp4ke.com/sp4ke/bit4sat/bus" "git.sp4ke.com/sp4ke/bit4sat/db" "github.com/gin-contrib/sessions" "github.com/gin-gonic/gin" "github.com/gorilla/websocket" "github.com/mediocregopher/radix/v3" "github.com/segmentio/ksuid" ) const ( // Time allowed to write message to peer writeWait = 10 * time.Second // Time allowed to read the next pong message from the client. pongWait = 30 * time.Second // Send pings to client with this period. Must be less than pongWait. pingPeriod = (pongWait * 9) / 10 //pingPeriod = 5 * time.Second // Maximum message size maxMessageSize = 512 WebsocketIdName = "websocket-id" ) var upgrader = websocket.Upgrader{ ReadBufferSize: 1024, WriteBufferSize: 1024, CheckOrigin: func(r *http.Request) bool { return true }, } // Keep reference to clients so they can be used in goroutines var C = make(map[*Client]bool) // Interface to talk to websocket type Client struct { id ksuid.KSUID // websocket connection conn *websocket.Conn // PubSub message channel for this websocket subChannel chan radix.PubSubMessage // Name of main subscribed channel for this websocket channelName string // upload ids registered to this client uploadId string // upload notifications channel uploadNotifChannel chan radix.PubSubMessage } func (c *Client) readPump() { defer func() { // Delete client from referenc list delete(C, c) // Close ws connection c.conn.Close() }() c.conn.SetReadLimit(maxMessageSize) c.conn.SetReadDeadline(time.Now().Add(pongWait)) c.conn.SetPongHandler(func(string) error { //log.Println("pong") c.conn.SetReadDeadline(time.Now().Add(pongWait)) return nil }) for { _, message, err := c.conn.ReadMessage() if err != nil { if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { log.Printf("error: %v", err) } break } recvData := make(map[string]interface{}) json.Unmarshal(message, &recvData) log.Println(recvData) } } func (c *Client) writePump() { pingTicker := time.NewTicker(pingPeriod) defer func() { log.Println("websocket closing") pingTicker.Stop() c.conn.Close() }() // Subscribe to general notifications for this client log.Printf("socket subscribing to %s", c.channelName) if err := db.DB.RedisPubSub.Subscribe(c.subChannel, c.channelName); err != nil { log.Println(err) return } // subscribe to notifications related to this channel's registered upload // ids if err := db.DB.RedisPubSub.PSubscribe(c.uploadNotifChannel, fmt.Sprintf("%s_*", bus.UploadUpdateChannelPrefix)); err != nil { log.Println(err) return } for { select { case msg := <-c.subChannel: //log.Printf("received msg %s on socket main channel", msg) jsonMsg := bus.Message{} if err := json.Unmarshal(msg.Message, &jsonMsg); err != nil { log.Printf("ws error reading from pubsub: %s", err) break } if jsonMsg.Type == bus.SetUploadId { log.Printf("registering uploadId: %s to socket", jsonMsg.UploadId) c.uploadId = jsonMsg.UploadId } case msg := <-c.uploadNotifChannel: log.Printf("websocket received upload paid notification on channel %s", msg.Channel) log.Printf("our registered ids are %s", c.uploadId) // If we have no upload id registered for this client break if c.uploadId == "" { continue } // Check if the message matches our upload id slice := strings.SplitN(msg.Channel, "_", 3) if len(slice) != 3 { log.Printf("error decoding channel name %s", msg.Channel) } msgTarget := slice[2] // If the target is any of our registered upload ids // broadcast the message if c.uploadId == msgTarget { log.Println("this message is for us sending to client") jsonMsg := bus.Message{} err := json.Unmarshal(msg.Message, &jsonMsg) if err != nil { log.Printf("unmarshal error %s", err) break } switch jsonMsg.Type { case bus.PaymentReceived: c.conn.SetWriteDeadline(time.Now().Add(writeWait)) w, err := c.conn.NextWriter(websocket.TextMessage) if err != nil { log.Println(err) return } socketMsg := ProtoMessage{} socketMsg.Type = messageTypes[invoicePaid] socketMsg.Data = jsonMsg.Data enc := json.NewEncoder(w) err = enc.Encode(socketMsg) if err != nil { log.Println(err) return } // No need to encode the message it's already in json if err = w.Close(); err != nil { fmt.Println(err) return } } // Handle different message types } case <-pingTicker.C: //log.Println("ping") c.conn.SetWriteDeadline(time.Now().Add(writeWait)) if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil { log.Printf("websocket ping error: %s", err) return } } } } func Serve(c *gin.Context) { var err error log.Println("websocket request") session := sessions.Default(c) // First check if cookie is already set with upload id in case this is a // reconnection // var finalSessId ksuid.KSUID socketSessionId := session.Get(WebsocketIdName) if socketSessionId == nil { // Create unique id for this session in order to store it // in the websocket bus finalSessId, err = ksuid.NewRandomWithTime(time.Now()) if err != nil { log.Println("error generating ksuid for websocket") return } log.Printf("using socket id: %s", finalSessId) session.Set(WebsocketIdName, finalSessId.String()) session.Save() log.Printf("Writing socket session id to header: %s", c.Writer.Header()) } else { // Reuse socket session id log.Printf("reusing websocket id %s", socketSessionId) finalSessId, err = ksuid.Parse(socketSessionId.(string)) if err != nil { log.Println("could not parse websocket session id") } } conn, err := upgrader.Upgrade(c.Writer, c.Request, c.Writer.Header()) if err != nil { if _, ok := err.(websocket.HandshakeError); !ok { log.Printf("handshake error: %s", err) } log.Println(err) return } client := &Client{ id: finalSessId, conn: conn, subChannel: make(chan radix.PubSubMessage), uploadNotifChannel: make(chan radix.PubSubMessage), channelName: fmt.Sprintf("%s_%s", bus.WebsocketPubSubPrefix, finalSessId), } C[client] = true go client.writePump() go client.readPump() }