// Copyright (c) 2021 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. package socket import ( "context" "crypto/cipher" "encoding/binary" "sync" "sync/atomic" "github.com/gorilla/websocket" ) type NoiseSocket struct { fs *FrameSocket onFrame FrameHandler writeKey cipher.AEAD readKey cipher.AEAD writeCounter uint32 readCounter uint32 writeLock sync.Mutex destroyed uint32 stopConsumer chan struct{} } type DisconnectHandler func(socket *NoiseSocket, remote bool) type FrameHandler func([]byte) func newNoiseSocket(fs *FrameSocket, writeKey, readKey cipher.AEAD, frameHandler FrameHandler, disconnectHandler DisconnectHandler) (*NoiseSocket, error) { ns := &NoiseSocket{ fs: fs, writeKey: writeKey, readKey: readKey, onFrame: frameHandler, stopConsumer: make(chan struct{}), } fs.OnDisconnect = func(remote bool) { disconnectHandler(ns, remote) } go ns.consumeFrames(fs.ctx, fs.Frames) return ns, nil } func (ns *NoiseSocket) consumeFrames(ctx context.Context, frames <-chan []byte) { ctxDone := ctx.Done() for { select { case frame := <-frames: ns.receiveEncryptedFrame(frame) case <-ctxDone: return case <-ns.stopConsumer: return } } } func generateIV(count uint32) []byte { iv := make([]byte, 12) binary.BigEndian.PutUint32(iv[8:], count) return iv } func (ns *NoiseSocket) Context() context.Context { return ns.fs.Context() } func (ns *NoiseSocket) Stop(disconnect bool) { if atomic.CompareAndSwapUint32(&ns.destroyed, 0, 1) { close(ns.stopConsumer) ns.fs.OnDisconnect = nil if disconnect { ns.fs.Close(websocket.CloseNormalClosure) } } } func (ns *NoiseSocket) SendFrame(plaintext []byte) error { ns.writeLock.Lock() ciphertext := ns.writeKey.Seal(nil, generateIV(ns.writeCounter), plaintext, nil) ns.writeCounter++ err := ns.fs.SendFrame(ciphertext) ns.writeLock.Unlock() return err } func (ns *NoiseSocket) receiveEncryptedFrame(ciphertext []byte) { count := atomic.AddUint32(&ns.readCounter, 1) - 1 plaintext, err := ns.readKey.Open(nil, generateIV(count), ciphertext, nil) if err != nil { ns.fs.log.Warnf("Failed to decrypt frame: %v", err) return } ns.onFrame(plaintext) } func (ns *NoiseSocket) IsConnected() bool { return ns.fs.IsConnected() }