bit4sat/ws/server.go
2019-04-02 15:53:00 +02:00

275 lines
6.3 KiB
Go

package ws
import (
"encoding/json"
"fmt"
"log"
"net/http"
"strings"
"time"
"git.sp4ke.com/sp4ke/bit4sat/bus"
"git.sp4ke.com/sp4ke/bit4sat/db"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/mediocregopher/radix/v3"
"github.com/segmentio/ksuid"
)
// Timing and size limits applied to every websocket connection.
const (
	// writeWait bounds how long a single write to the peer may take.
	writeWait = 10 * time.Second

	// pongWait is how long the read side waits for the next pong before
	// giving up on the peer.
	pongWait = 30 * time.Second

	// pingPeriod is how often pings are sent. It must stay below pongWait so
	// a pong can arrive before the read deadline expires (90% of it here).
	pingPeriod = pongWait * 9 / 10

	// maxMessageSize is the read limit, in bytes, for a message from the peer.
	maxMessageSize = 512
)

// WebsocketIdName is the session key under which a browser's websocket id is
// stored, so that a reconnecting browser reuses the same id.
const WebsocketIdName = "websocket-id"
// upgrader turns incoming HTTP requests into websocket connections using
// 1 KiB read and write buffers.
//
// NOTE(review): CheckOrigin accepts every Origin. Serve identifies the
// browser by its session cookie, so a third-party page could open a socket
// on a visitor's behalf (cross-site websocket hijacking). Consider
// restricting allowed origins.
var upgrader = websocket.Upgrader{
	ReadBufferSize:  1 << 10,
	WriteBufferSize: 1 << 10,
	CheckOrigin:     func(*http.Request) bool { return true },
}
// C holds every live Client. Serve adds each new client, and readPump removes
// it when its read loop ends.
//
// NOTE(review): Serve runs on the HTTP handler goroutine and readPump on its
// own goroutine; both touch this map without a lock, which is a data race.
// Guard it with a sync.Mutex.
var C = map[*Client]bool{}
// Client is the server-side half of one browser websocket. Serve builds it
// and then hands it to its readPump and writePump goroutines.
type Client struct {
	// id is the websocket id taken from (or stored into) the session.
	id ksuid.KSUID

	// conn is the underlying websocket connection.
	conn *websocket.Conn

	// subChannel receives pubsub messages published on channelName.
	subChannel chan radix.PubSubMessage

	// channelName is this client's own pubsub channel, built by Serve from
	// bus.WebsocketPubSubPrefix and id.
	channelName string

	// uploadId is the single upload id this client is waiting on. It is
	// empty until a bus.SetUploadId message arrives on subChannel.
	uploadId string

	// uploadNotifChannel receives upload-update pubsub messages.
	uploadNotifChannel chan radix.PubSubMessage
}
// readPump reads frames from c.conn until reading fails. The read deadline
// starts at pongWait and is pushed forward by pongWait on every pong, so a
// peer that stops answering pings makes ReadMessage fail. Each incoming
// message is decoded as a JSON object and logged; nothing else is done with
// it. On exit the client is removed from C and the connection is closed.
func (c *Client) readPump() {
	defer func() {
		// Drop the client from the reference list, then close the socket.
		delete(C, c)
		c.conn.Close()
	}()

	c.conn.SetReadLimit(maxMessageSize)
	c.conn.SetReadDeadline(time.Now().Add(pongWait))
	c.conn.SetPongHandler(func(string) error {
		c.conn.SetReadDeadline(time.Now().Add(pongWait))
		return nil
	})

	for {
		_, message, err := c.conn.ReadMessage()
		if err != nil {
			// Normal going-away / abnormal closures are expected on disconnect.
			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway,
				websocket.CloseAbnormalClosure) {
				log.Printf("error: %v", err)
			}
			return
		}

		recvData := make(map[string]interface{})
		if err := json.Unmarshal(message, &recvData); err != nil {
			// Malformed client input is ignored rather than logged as an empty map.
			log.Printf("websocket: ignoring malformed client message: %v", err)
			continue
		}
		log.Println(recvData)
	}
}
// writePump is the only goroutine that writes to c.conn. It subscribes
// c.subChannel to c.channelName and c.uploadNotifChannel to the pattern
// "<bus.UploadUpdateChannelPrefix>_*", then loops:
//   - a bus.SetUploadId message on c.subChannel records the upload id this
//     client is waiting on;
//   - upload-update messages are passed to forwardUploadNotification;
//   - every pingPeriod a websocket ping is sent.
//
// It returns on the first failed write or ping, or if subscribing fails. On
// return it removes its pubsub subscriptions and closes the connection.
func (c *Client) writePump() {
	pingTicker := time.NewTicker(pingPeriod)
	defer func() {
		log.Println("websocket closing")
		pingTicker.Stop()
		c.conn.Close()
	}()

	// Subscribe to general notifications for this client.
	log.Printf("socket subscribing to %s", c.channelName)
	if err := db.DB.RedisPubSub.Subscribe(c.subChannel, c.channelName); err != nil {
		log.Println(err)
		return
	}
	// These channels are unbuffered and only read by this loop. Leaving them
	// subscribed after we return would leave the pubsub publisher blocked
	// sending to a channel nobody reads.
	defer c.unsubscribeDraining(func() error {
		return db.DB.RedisPubSub.Unsubscribe(c.subChannel, c.channelName)
	})

	// Subscribe to every upload-update channel; filtering by this client's
	// upload id happens in forwardUploadNotification.
	uploadPattern := fmt.Sprintf("%s_*", bus.UploadUpdateChannelPrefix)
	if err := db.DB.RedisPubSub.PSubscribe(c.uploadNotifChannel, uploadPattern); err != nil {
		log.Println(err)
		return
	}
	defer c.unsubscribeDraining(func() error {
		return db.DB.RedisPubSub.PUnsubscribe(c.uploadNotifChannel, uploadPattern)
	})

	for {
		select {
		case msg := <-c.subChannel:
			jsonMsg := bus.Message{}
			if err := json.Unmarshal(msg.Message, &jsonMsg); err != nil {
				log.Printf("ws error reading from pubsub: %s", err)
				continue
			}
			if jsonMsg.Type == bus.SetUploadId {
				log.Printf("registering uploadId: %s to socket", jsonMsg.UploadId)
				c.uploadId = jsonMsg.UploadId
			}

		case msg := <-c.uploadNotifChannel:
			if err := c.forwardUploadNotification(msg); err != nil {
				log.Println(err)
				return
			}

		case <-pingTicker.C:
			c.conn.SetWriteDeadline(time.Now().Add(writeWait))
			if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
				log.Printf("websocket ping error: %s", err)
				return
			}
		}
	}
}

// forwardUploadNotification sends msg to the browser as an invoicePaid
// ProtoMessage when msg targets c.uploadId and carries bus.PaymentReceived.
// It returns an error only when writing to the websocket fails. Malformed
// or unrelated notifications are logged or ignored and yield nil.
func (c *Client) forwardUploadNotification(msg radix.PubSubMessage) error {
	log.Printf("websocket received upload paid notification on channel %s", msg.Channel)
	log.Printf("our registered ids are %s", c.uploadId)
	// Nothing can match until an upload id has been registered.
	if c.uploadId == "" {
		return nil
	}

	// The channel name must split into three "_"-separated parts, the last
	// being the target upload id. Previously a short name fell through to
	// slice[2] and panicked.
	parts := strings.SplitN(msg.Channel, "_", 3)
	if len(parts) != 3 {
		log.Printf("error decoding channel name %s", msg.Channel)
		return nil
	}
	if parts[2] != c.uploadId {
		return nil
	}

	log.Println("this message is for us sending to client")
	jsonMsg := bus.Message{}
	if err := json.Unmarshal(msg.Message, &jsonMsg); err != nil {
		log.Printf("unmarshal error %s", err)
		return nil
	}

	// Handle different message types.
	switch jsonMsg.Type {
	case bus.PaymentReceived:
		socketMsg := ProtoMessage{}
		socketMsg.Type = messageTypes[invoicePaid]
		socketMsg.Data = jsonMsg.Data

		c.conn.SetWriteDeadline(time.Now().Add(writeWait))
		w, err := c.conn.NextWriter(websocket.TextMessage)
		if err != nil {
			return err
		}
		// Encode the protocol envelope as JSON directly into the frame.
		if err := json.NewEncoder(w).Encode(socketMsg); err != nil {
			return err
		}
		// Closing the writer flushes the complete frame.
		return w.Close()
	}
	return nil
}

// unsubscribeDraining runs unsub while another goroutine discards any
// message still being delivered on c's pubsub channels. Without draining,
// a publisher blocked sending to one of these unread channels could keep
// unsub from ever completing.
func (c *Client) unsubscribeDraining(unsub func() error) {
	done := make(chan struct{})
	go func() {
		for {
			select {
			case <-c.subChannel:
			case <-c.uploadNotifChannel:
			case <-done:
				return
			}
		}
	}()
	if err := unsub(); err != nil {
		log.Printf("websocket unsubscribe error: %s", err)
	}
	close(done)
}
// Serve is the gin handler for the websocket endpoint. It resolves the
// caller's websocket id from the session, creating and saving one when
// needed. It then upgrades the connection, registers a Client in C, and
// starts the client's write and read pumps.
func Serve(c *gin.Context) {
	log.Println("websocket request")
	session := sessions.Default(c)

	// Reuse the id stored in the session so a reconnecting browser keeps
	// listening on the same pubsub channel.
	sessID, err := socketSessionID(session)
	if err != nil {
		log.Println("error generating ksuid for websocket")
		return
	}

	// Passing the response headers sends anything already set on c.Writer
	// (e.g. a session cookie written by Save) with the handshake response.
	conn, err := upgrader.Upgrade(c.Writer, c.Request, c.Writer.Header())
	if err != nil {
		// On failure Upgrade has already replied with an HTTP error response.
		log.Printf("websocket upgrade error: %s", err)
		return
	}

	client := &Client{
		id:                 sessID,
		conn:               conn,
		subChannel:         make(chan radix.PubSubMessage),
		uploadNotifChannel: make(chan radix.PubSubMessage),
		channelName:        fmt.Sprintf("%s_%s", bus.WebsocketPubSubPrefix, sessID),
	}
	// NOTE(review): C is also mutated by readPump on another goroutine
	// without a lock.
	C[client] = true

	go client.writePump()
	go client.readPump()
}

// socketSessionID returns the websocket id stored in session under
// WebsocketIdName. When none is stored, or the stored value is not a string
// or not a valid ksuid, a fresh id is generated, stored, and saved.
// Previously an unparsable id fell through as the zero KSUID, so all such
// clients shared one pubsub channel. The error is non-nil only if a new id
// could not be generated.
func socketSessionID(session sessions.Session) (ksuid.KSUID, error) {
	if stored := session.Get(WebsocketIdName); stored != nil {
		log.Printf("reusing websocket id %s", stored)
		if s, ok := stored.(string); ok {
			if id, err := ksuid.Parse(s); err == nil {
				return id, nil
			}
		}
		log.Println("could not parse websocket session id")
	}

	// Create a unique id for this session so it can be used in the
	// websocket bus.
	id, err := ksuid.NewRandomWithTime(time.Now())
	if err != nil {
		return ksuid.KSUID{}, err
	}
	log.Printf("using socket id: %s", id)
	session.Set(WebsocketIdName, id.String())
	if err := session.Save(); err != nil {
		// The id still works for this connection; only reuse on reconnect is lost.
		log.Printf("could not save websocket session id: %s", err)
	}
	return id, nil
}