diff --git a/poller.go b/poller.go index 989dde2..f21e9e4 100644 --- a/poller.go +++ b/poller.go @@ -20,6 +20,50 @@ type Poller interface { Poll(b *Bot, updates chan Update, stop chan struct{}) } +// MiddlewarePoller is a special kind of poller that acts +// like a filter for updates. It could be used for spam +// handling, banning or whatever. +// +// For heavy middleware, use increased capacity. +// +type MiddlewarePoller struct { + Capacity int // Default: 1 + Poller Poller + Filter func(*Update) bool +} + +// NewMiddlewarePoller wait for it... constructs a new middleware poller. +func NewMiddlewarePoller(original Poller, filter func(*Update) bool) *MiddlewarePoller { + return &MiddlewarePoller{ + Poller: original, + Filter: filter, + } +} + +// Poll sieves updates through middleware filter. +func (p *MiddlewarePoller) Poll(b *Bot, dest chan Update, stop chan struct{}) { + if p.Capacity < 1 { + p.Capacity = 1 + } + + middle := make(chan Update, p.Capacity) + stopPoller := make(chan struct{}) + + go p.Poller.Poll(b, middle, stopPoller) + + for { + select { + case <-stop: + close(stopPoller) + return + case upd := <-middle: + if p.Filter(&upd) { + dest <- upd + } + } + } +} + // LongPoller is a classic LongPoller with timeout. type LongPoller struct { Limit int diff --git a/poller_test.go b/poller_test.go index 936a93d..1f42f23 100644 --- a/poller_test.go +++ b/poller_test.go @@ -1,5 +1,11 @@ package telebot +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + type testPoller struct { updates chan Update done chan struct{} @@ -23,3 +29,39 @@ func (p *testPoller) Poll(b *Bot, updates chan Update, stop chan struct{}) { } } } + +func TestMiddlewarePoller(t *testing.T) { + tp := newTestPoller() + var ids []int + + pref := defaultSettings() + pref.Offline = true + + b, err := NewBot(pref) + if err != nil { + t.Fatal(err) + } + + b.Poller = NewMiddlewarePoller(tp, func(u *Update) bool { + if u.ID > 0 { + ids = append(ids, u.ID) + return true + } + + tp.done <- struct{}{} + return false + }) + + go func() { + tp.updates <- Update{ID: 1} + tp.updates <- Update{ID: 2} + tp.updates <- Update{ID: 0} + }() + + go b.Start() + <-tp.done + b.Stop() + + assert.Contains(t, ids, 1) + assert.Contains(t, ids, 2) +}