275 lines
6.3 KiB
Go
275 lines
6.3 KiB
Go
package ws
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"log"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"git.sp4ke.com/sp4ke/bit4sat/bus"
|
|
"git.sp4ke.com/sp4ke/bit4sat/db"
|
|
"github.com/gin-contrib/sessions"
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/gorilla/websocket"
|
|
"github.com/mediocregopher/radix/v3"
|
|
"github.com/segmentio/ksuid"
|
|
)
|
|
|
|
// Websocket timing and size limits, plus the session key under which a
// client's socket id is kept so a reconnecting browser reuses it.
const (
	// writeWait bounds how long a single write to the peer may take.
	writeWait = 10 * time.Second

	// pongWait is how long the read side waits for the next pong before its
	// read deadline expires.
	pongWait = 30 * time.Second

	// pingPeriod is how often a ping is sent. It must stay below pongWait so
	// the pong can arrive before the deadline; here it is 90% of pongWait.
	pingPeriod = pongWait - pongWait/10

	// maxMessageSize caps the size of a single message read from the peer.
	maxMessageSize = 512

	// WebsocketIdName is the session key holding a client's websocket id.
	WebsocketIdName = "websocket-id"
)
|
|
|
|
var upgrader = websocket.Upgrader{
|
|
ReadBufferSize: 1024,
|
|
WriteBufferSize: 1024,
|
|
CheckOrigin: func(r *http.Request) bool {
|
|
return true
|
|
},
|
|
}
|
|
|
|
// Keep reference to clients so they can be used in goroutines
|
|
var C = make(map[*Client]bool)
|
|
|
|
// Interface to talk to websocket
|
|
type Client struct {
|
|
id ksuid.KSUID
|
|
|
|
// websocket connection
|
|
conn *websocket.Conn
|
|
|
|
// PubSub message channel for this websocket
|
|
subChannel chan radix.PubSubMessage
|
|
|
|
// Name of main subscribed channel for this websocket
|
|
channelName string
|
|
|
|
// upload ids registered to this client
|
|
uploadId string
|
|
|
|
// upload notifications channel
|
|
uploadNotifChannel chan radix.PubSubMessage
|
|
}
|
|
|
|
func (c *Client) readPump() {
|
|
defer func() {
|
|
// Delete client from referenc list
|
|
delete(C, c)
|
|
|
|
// Close ws connection
|
|
c.conn.Close()
|
|
}()
|
|
|
|
c.conn.SetReadLimit(maxMessageSize)
|
|
c.conn.SetReadDeadline(time.Now().Add(pongWait))
|
|
c.conn.SetPongHandler(func(string) error {
|
|
//log.Println("pong")
|
|
c.conn.SetReadDeadline(time.Now().Add(pongWait))
|
|
return nil
|
|
})
|
|
|
|
for {
|
|
_, message, err := c.conn.ReadMessage()
|
|
if err != nil {
|
|
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway,
|
|
websocket.CloseAbnormalClosure) {
|
|
log.Printf("error: %v", err)
|
|
}
|
|
break
|
|
}
|
|
|
|
recvData := make(map[string]interface{})
|
|
json.Unmarshal(message, &recvData)
|
|
log.Println(recvData)
|
|
}
|
|
}
|
|
|
|
func (c *Client) writePump() {
|
|
pingTicker := time.NewTicker(pingPeriod)
|
|
defer func() {
|
|
log.Println("websocket closing")
|
|
pingTicker.Stop()
|
|
c.conn.Close()
|
|
}()
|
|
|
|
// Subscribe to general notifications for this client
|
|
log.Printf("socket subscribing to %s", c.channelName)
|
|
if err := db.DB.RedisPubSub.Subscribe(c.subChannel,
|
|
c.channelName); err != nil {
|
|
log.Println(err)
|
|
return
|
|
}
|
|
|
|
// subscribe to notifications related to this channel's registered upload
|
|
// ids
|
|
if err := db.DB.RedisPubSub.PSubscribe(c.uploadNotifChannel,
|
|
fmt.Sprintf("%s_*", bus.UploadUpdateChannelPrefix)); err != nil {
|
|
log.Println(err)
|
|
return
|
|
}
|
|
|
|
for {
|
|
select {
|
|
case msg := <-c.subChannel:
|
|
//log.Printf("received msg %s on socket main channel", msg)
|
|
jsonMsg := bus.Message{}
|
|
if err := json.Unmarshal(msg.Message, &jsonMsg); err != nil {
|
|
log.Printf("ws error reading from pubsub: %s", err)
|
|
break
|
|
}
|
|
|
|
if jsonMsg.Type == bus.SetUploadId {
|
|
log.Printf("registering uploadId: %s to socket", jsonMsg.UploadId)
|
|
c.uploadId = jsonMsg.UploadId
|
|
}
|
|
|
|
case msg := <-c.uploadNotifChannel:
|
|
log.Printf("websocket received upload paid notification on channel %s", msg.Channel)
|
|
log.Printf("our registered ids are %s", c.uploadId)
|
|
|
|
// If we have no upload id registered for this client break
|
|
if c.uploadId == "" {
|
|
continue
|
|
}
|
|
|
|
// Check if the message matches our upload id
|
|
slice := strings.SplitN(msg.Channel, "_", 3)
|
|
if len(slice) != 3 {
|
|
log.Printf("error decoding channel name %s", msg.Channel)
|
|
}
|
|
|
|
msgTarget := slice[2]
|
|
|
|
// If the target is any of our registered upload ids
|
|
// broadcast the message
|
|
|
|
if c.uploadId == msgTarget {
|
|
log.Println("this message is for us sending to client")
|
|
|
|
jsonMsg := bus.Message{}
|
|
err := json.Unmarshal(msg.Message, &jsonMsg)
|
|
if err != nil {
|
|
log.Printf("unmarshal error %s", err)
|
|
break
|
|
}
|
|
|
|
switch jsonMsg.Type {
|
|
|
|
case bus.PaymentReceived:
|
|
|
|
c.conn.SetWriteDeadline(time.Now().Add(writeWait))
|
|
|
|
w, err := c.conn.NextWriter(websocket.TextMessage)
|
|
if err != nil {
|
|
log.Println(err)
|
|
return
|
|
}
|
|
|
|
socketMsg := ProtoMessage{}
|
|
socketMsg.Type = messageTypes[invoicePaid]
|
|
socketMsg.Data = jsonMsg.Data
|
|
|
|
enc := json.NewEncoder(w)
|
|
err = enc.Encode(socketMsg)
|
|
if err != nil {
|
|
log.Println(err)
|
|
return
|
|
}
|
|
|
|
// No need to encode the message it's already in json
|
|
if err = w.Close(); err != nil {
|
|
fmt.Println(err)
|
|
return
|
|
}
|
|
|
|
}
|
|
|
|
// Handle different message types
|
|
|
|
}
|
|
|
|
case <-pingTicker.C:
|
|
//log.Println("ping")
|
|
|
|
c.conn.SetWriteDeadline(time.Now().Add(writeWait))
|
|
if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
|
|
log.Printf("websocket ping error: %s", err)
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func Serve(c *gin.Context) {
|
|
var err error
|
|
log.Println("websocket request")
|
|
|
|
session := sessions.Default(c)
|
|
|
|
// First check if cookie is already set with upload id in case this is a
|
|
// reconnection
|
|
//
|
|
var finalSessId ksuid.KSUID
|
|
|
|
socketSessionId := session.Get(WebsocketIdName)
|
|
if socketSessionId == nil {
|
|
|
|
// Create unique id for this session in order to store it
|
|
// in the websocket bus
|
|
finalSessId, err = ksuid.NewRandomWithTime(time.Now())
|
|
if err != nil {
|
|
log.Println("error generating ksuid for websocket")
|
|
return
|
|
}
|
|
|
|
log.Printf("using socket id: %s", finalSessId)
|
|
session.Set(WebsocketIdName, finalSessId.String())
|
|
session.Save()
|
|
log.Printf("Writing socket session id to header: %s", c.Writer.Header())
|
|
} else {
|
|
// Reuse socket session id
|
|
log.Printf("reusing websocket id %s", socketSessionId)
|
|
finalSessId, err = ksuid.Parse(socketSessionId.(string))
|
|
if err != nil {
|
|
log.Println("could not parse websocket session id")
|
|
}
|
|
}
|
|
|
|
conn, err := upgrader.Upgrade(c.Writer, c.Request, c.Writer.Header())
|
|
if err != nil {
|
|
if _, ok := err.(websocket.HandshakeError); !ok {
|
|
log.Printf("handshake error: %s", err)
|
|
}
|
|
|
|
log.Println(err)
|
|
return
|
|
}
|
|
|
|
client := &Client{
|
|
id: finalSessId,
|
|
conn: conn,
|
|
subChannel: make(chan radix.PubSubMessage),
|
|
uploadNotifChannel: make(chan radix.PubSubMessage),
|
|
channelName: fmt.Sprintf("%s_%s", bus.WebsocketPubSubPrefix, finalSessId),
|
|
}
|
|
|
|
C[client] = true
|
|
|
|
go client.writePump()
|
|
go client.readPump()
|
|
}
|