bot: fix root group handling

This commit is contained in:
Demian 2020-09-15 18:12:43 +03:00
parent b291e4fc08
commit a70d2204db
4 changed files with 34 additions and 36 deletions

35
bot.go
View File

@ -38,9 +38,8 @@ func NewBot(pref Settings) (*Bot, error) {
Poller: pref.Poller, Poller: pref.Poller,
OnError: pref.OnError, OnError: pref.OnError,
Updates: make(chan Update, pref.Updates), Updates: make(chan Update, pref.Updates),
handlers: make(map[string]HandlerFunc), stop: make(chan struct{}),
stop: make(chan struct{}),
synchronous: pref.Synchronous, synchronous: pref.Synchronous,
verbose: pref.Verbose, verbose: pref.Verbose,
@ -72,7 +71,6 @@ type Bot struct {
OnError func(error, Context) OnError func(error, Context)
group *Group group *Group
handlers map[string]HandlerFunc
synchronous bool synchronous bool
verbose bool verbose bool
parseMode ParseMode parseMode ParseMode
@ -144,16 +142,16 @@ type Command struct {
Description string `json:"description"` Description string `json:"description"`
} }
// Group returns a new group.
func (b *Bot) Group() *Group {
return &Group{handlers: make(map[string]HandlerFunc)}
}
// Use adds middleware to the global bot chain. // Use adds middleware to the global bot chain.
func (b *Bot) Use(middleware ...MiddlewareFunc) { func (b *Bot) Use(middleware ...MiddlewareFunc) {
b.group.Use(middleware...) b.group.Use(middleware...)
} }
// Group returns a new group.
func (b *Bot) Group() *Group {
return &Group{b: b}
}
// Handle lets you set the handler for some command name or // Handle lets you set the handler for some command name or
// one of the supported endpoints. It also applies middleware // one of the supported endpoints. It also applies middleware
// if such passed to the function. // if such passed to the function.
@ -174,20 +172,7 @@ func (b *Bot) Group() *Group {
// b.Handle("/ban", onBan, protected) // b.Handle("/ban", onBan, protected)
// //
func (b *Bot) Handle(endpoint interface{}, h HandlerFunc, m ...MiddlewareFunc) { func (b *Bot) Handle(endpoint interface{}, h HandlerFunc, m ...MiddlewareFunc) {
if m != nil { b.group.Handle(endpoint, h, m...)
h = func(c Context) error {
return applyMiddleware(h, m...)(c)
}
}
switch end := endpoint.(type) {
case string:
b.handlers[end] = h
case CallbackEndpoint:
b.handlers[end.CallbackUnique()] = h
default:
panic("telebot: unsupported endpoint")
}
} }
var ( var (
@ -380,7 +365,7 @@ func (b *Bot) ProcessUpdate(upd Update) {
if match != nil { if match != nil {
unique, payload := match[0][1], match[0][3] unique, payload := match[0][1], match[0][3]
if handler, ok := b.handlers["\f"+unique]; ok { if handler, ok := b.group.handlers["\f"+unique]; ok {
upd.Callback.Data = payload upd.Callback.Data = payload
b.runHandler(handler, c) b.runHandler(handler, c)
return return
@ -425,7 +410,7 @@ func (b *Bot) ProcessUpdate(upd Update) {
} }
func (b *Bot) handle(end string, c Context) bool { func (b *Bot) handle(end string, c Context) bool {
if handler, ok := b.handlers[end]; ok { if handler, ok := b.group.handlers[end]; ok {
b.runHandler(handler, c) b.runHandler(handler, c)
return true return true
} }

View File

@ -82,7 +82,7 @@ func TestBotHandle(t *testing.T) {
} }
b.Handle("/start", func(c Context) error { return nil }) b.Handle("/start", func(c Context) error { return nil })
assert.Contains(t, b.handlers, "/start") assert.Contains(t, b.group.handlers, "/start")
reply := ReplyButton{Text: "reply"} reply := ReplyButton{Text: "reply"}
b.Handle(&reply, func(c Context) error { return nil }) b.Handle(&reply, func(c Context) error { return nil })
@ -96,10 +96,10 @@ func TestBotHandle(t *testing.T) {
btnInline := (&ReplyMarkup{}).Data("", "btnInline") btnInline := (&ReplyMarkup{}).Data("", "btnInline")
b.Handle(&btnInline, func(c Context) error { return nil }) b.Handle(&btnInline, func(c Context) error { return nil })
assert.Contains(t, b.handlers, btnReply.CallbackUnique()) assert.Contains(t, b.group.handlers, btnReply.CallbackUnique())
assert.Contains(t, b.handlers, btnInline.CallbackUnique()) assert.Contains(t, b.group.handlers, btnInline.CallbackUnique())
assert.Contains(t, b.handlers, reply.CallbackUnique()) assert.Contains(t, b.group.handlers, reply.CallbackUnique())
assert.Contains(t, b.handlers, inline.CallbackUnique()) assert.Contains(t, b.group.handlers, inline.CallbackUnique())
} }
func TestBotStart(t *testing.T) { func TestBotStart(t *testing.T) {
@ -107,9 +107,6 @@ func TestBotStart(t *testing.T) {
t.Skip("TELEBOT_SECRET is required") t.Skip("TELEBOT_SECRET is required")
} }
// cached bot has no poller
assert.Panics(t, func() { b.Start() })
pref := defaultSettings() pref := defaultSettings()
pref.Poller = &LongPoller{} pref.Poller = &LongPoller{}

View File

@ -6,8 +6,8 @@ type MiddlewareFunc func(HandlerFunc) HandlerFunc
// Group is a separated group of handlers, united by the general middleware. // Group is a separated group of handlers, united by the general middleware.
type Group struct { type Group struct {
b *Bot
middleware []MiddlewareFunc middleware []MiddlewareFunc
handlers map[string]HandlerFunc
} }
// Use adds middleware to the chain. // Use adds middleware to the chain.
@ -18,5 +18,21 @@ func (g *Group) Use(middleware ...MiddlewareFunc) {
// Handle adds endpoint handler to the bot, combining group's middleware // Handle adds endpoint handler to the bot, combining group's middleware
// with the optional given middleware. // with the optional given middleware.
func (g *Group) Handle(endpoint interface{}, h HandlerFunc, m ...MiddlewareFunc) { func (g *Group) Handle(endpoint interface{}, h HandlerFunc, m ...MiddlewareFunc) {
g.b.Handle(endpoint, h, append(g.middleware, m...)...) if len(g.middleware) > 0 {
m = append(g.middleware, m...)
}
if len(m) > 0 {
h = func(c Context) error {
return applyMiddleware(h, m...)(c)
}
}
switch end := endpoint.(type) {
case string:
g.handlers[end] = h
case CallbackEndpoint:
g.handlers[end.CallbackUnique()] = h
default:
panic("telebot: unsupported endpoint")
}
} }

View File

@ -31,7 +31,7 @@ func (b *Bot) runHandler(h HandlerFunc, c Context) {
f := func() { f := func() {
defer b.deferDebug() defer b.deferDebug()
if err := h(c); err != nil { if err := h(c); err != nil {
if err != ErrSkip { if err != ErrSkip && b.OnError != nil {
b.OnError(err, c) b.OnError(err, c)
} }
} }