diff --git a/bot.go b/bot.go index 58effcb..4a0fc88 100644 --- a/bot.go +++ b/bot.go @@ -37,7 +37,7 @@ func NewBot(pref Settings) (*Bot, error) { Token: pref.Token, URL: pref.URL, Poller: pref.Poller, - OnError: pref.OnError, + onError: pref.OnError, Updates: make(chan Update, pref.Updates), handlers: make(map[string]HandlerFunc), @@ -70,7 +70,7 @@ type Bot struct { URL string Updates chan Update Poller Poller - OnError func(error, Context) + onError func(error, Context) group *Group handlers map[string]HandlerFunc @@ -149,6 +149,10 @@ type Command struct { Description string `json:"description"` } +func (b *Bot) OnError(err error, c Context) { + b.onError(err, c) +} + // Group returns a new group. func (b *Bot) Group() *Group { return &Group{b: b} diff --git a/bot_test.go b/bot_test.go index 031f393..8c91761 100644 --- a/bot_test.go +++ b/bot_test.go @@ -354,7 +354,7 @@ func TestBotOnError(t *testing.T) { } var ok bool - b.OnError = func(err error, c Context) { + b.onError = func(err error, c Context) { assert.Equal(t, b, c.(*nativeContext).b) assert.NotNil(t, err) ok = true diff --git a/middleware/logger.go b/middleware/logger.go index cfe01be..19f5a74 100644 --- a/middleware/logger.go +++ b/middleware/logger.go @@ -7,6 +7,8 @@ import ( tele "gopkg.in/telebot.v3" ) +// Logger returns a middleware that logs incoming updates. +// If no custom logger provided, log.Default() will be used. func Logger(logger ...*log.Logger) tele.MiddlewareFunc { var l *log.Logger if len(logger) > 0 { diff --git a/middleware/middleware.go b/middleware/middleware.go new file mode 100644 index 0000000..91babd5 --- /dev/null +++ b/middleware/middleware.go @@ -0,0 +1,62 @@ +package middleware + +import ( + "errors" + + tele "gopkg.in/telebot.v3" +) + +// AutoRespond returns a middleware that automatically responds +// to every callback. +func AutoRespond() tele.MiddlewareFunc { + return func(next tele.HandlerFunc) tele.HandlerFunc { + return func(c tele.Context) error { + if c.Callback() != nil { + defer c.Respond() + } + return next(c) + } + } +} + +// IgnoreVia returns a middleware that ignores all the +// "sent via" messages. +func IgnoreVia() tele.MiddlewareFunc { + return func(next tele.HandlerFunc) tele.HandlerFunc { + return func(c tele.Context) error { + if msg := c.Message(); msg != nil && msg.Via != nil { + return nil + } + return next(c) + } + } +} + +// Recover returns a middleware that recovers a panic happened in +// the handler. +func Recover(onError ...func(error)) tele.MiddlewareFunc { + return func(next tele.HandlerFunc) tele.HandlerFunc { + return func(c tele.Context) error { + var f func(error) + if len(onError) > 0 { + f = onError[0] + } else { + f = func(err error) { + c.Bot().OnError(err, nil) + } + } + + defer func() { + if r := recover(); r != nil { + if err, ok := r.(error); ok { + f(err) + } else if s, ok := r.(string); ok { + f(errors.New(s)) + } + } + }() + + return next(c) + } + } +} diff --git a/middleware/middleware_test.go b/middleware/middleware_test.go new file mode 100644 index 0000000..bd2f63e --- /dev/null +++ b/middleware/middleware_test.go @@ -0,0 +1,30 @@ +package middleware + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + tele "gopkg.in/telebot.v3" +) + +var b, _ = tele.NewBot(tele.Settings{Offline: true}) + +func TestRecover(t *testing.T) { + onError := func(err error) { + require.Error(t, err, "recover test") + } + + h := func(c tele.Context) error { + panic("recover test") + } + + assert.Panics(t, func() { + h(nil) + }) + + assert.NotPanics(t, func() { + Recover(onError)(h)(nil) + }) +} diff --git a/middleware/misc.go b/middleware/misc.go deleted file mode 100644 index 1a0cafa..0000000 --- a/middleware/misc.go +++ /dev/null @@ -1,25 +0,0 @@ -package middleware - -import tele "gopkg.in/telebot.v3" - -func AutoRespond() tele.MiddlewareFunc { - return func(next tele.HandlerFunc) tele.HandlerFunc { - return func(c tele.Context) error { - if c.Callback() != nil { - defer c.Respond() - } - return next(c) - } - } -} - -func IgnoreVia() tele.MiddlewareFunc { - return func(next tele.HandlerFunc) tele.HandlerFunc { - return func(c tele.Context) error { - if msg := c.Message(); msg != nil && msg.Via != nil { - return nil - } - return next(c) - } - } -} diff --git a/middleware/restrict.go b/middleware/restrict.go index 0bd83d9..59f79c7 100644 --- a/middleware/restrict.go +++ b/middleware/restrict.go @@ -2,11 +2,25 @@ package middleware import tele "gopkg.in/telebot.v3" +// RestrictConfig defines config for Restrict middleware. type RestrictConfig struct { - Chats []int64 - In, Out tele.HandlerFunc + // Chats is a list of chats that are going to be affected + // by either In or Out function. + Chats []int64 + + // In defines a function that will be called if the chat + // of an update will be found in the Chats list. + In tele.HandlerFunc + + // Out defines a function that will be called if the chat + // of an update will NOT be found in the Chats list. + Out tele.HandlerFunc } +// Restrict returns a middleware that handles a list of provided +// chats with the logic defined by In and Out functions. +// If the chat is found in the Chats field, In function will be called, +// otherwise Out function will be called. func Restrict(v RestrictConfig) tele.MiddlewareFunc { return func(next tele.HandlerFunc) tele.HandlerFunc { if v.In == nil { @@ -26,22 +40,26 @@ func Restrict(v RestrictConfig) tele.MiddlewareFunc { } } -func Whitelist(chats ...int64) tele.MiddlewareFunc { +// Blacklist returns a middleware that skips the update for users +// specified in the chats field. +func Blacklist(chats ...int64) tele.MiddlewareFunc { return func(next tele.HandlerFunc) tele.HandlerFunc { return Restrict(RestrictConfig{ Chats: chats, - In: next, - Out: func(c tele.Context) error { return nil }, + Out: next, + In: func(c tele.Context) error { return nil }, })(next) } } -func Blacklist(chats ...int64) tele.MiddlewareFunc { +// Whitelist returns a middleware that skips the update for users +// NOT specified in the chats field. +func Whitelist(chats ...int64) tele.MiddlewareFunc { return func(next tele.HandlerFunc) tele.HandlerFunc { return Restrict(RestrictConfig{ Chats: chats, - Out: next, - In: func(c tele.Context) error { return nil }, + In: next, + Out: func(c tele.Context) error { return nil }, })(next) } } diff --git a/util.go b/util.go index b37d02f..5e52a70 100644 --- a/util.go +++ b/util.go @@ -10,28 +10,21 @@ import ( ) var defaultOnError = func(err error, c Context) { - log.Println(c.Update().ID, err) -} - -func (b *Bot) debug(err error) { - if b.verbose { + if c != nil { + log.Println(c.Update().ID, err) + } else { log.Println(err) } } -func (b *Bot) deferDebug() { - if r := recover(); r != nil { - if err, ok := r.(error); ok { - b.debug(err) - } else if str, ok := r.(string); ok { - b.debug(fmt.Errorf("%s", str)) - } +func (b *Bot) debug(err error) { + if b.verbose { + b.OnError(err, nil) } } func (b *Bot) runHandler(h HandlerFunc, c Context) { f := func() { - defer b.deferDebug() if err := h(c); err != nil { b.OnError(err, c) } @@ -43,9 +36,9 @@ func (b *Bot) runHandler(h HandlerFunc, c Context) { } } -func applyMiddleware(h HandlerFunc, middleware ...MiddlewareFunc) HandlerFunc { - for i := len(middleware) - 1; i >= 0; i-- { - h = middleware[i](h) +func applyMiddleware(h HandlerFunc, m ...MiddlewareFunc) HandlerFunc { + for i := len(m) - 1; i >= 0; i-- { + h = m[i](h) } return h }