diff --git a/bot.go b/bot.go index 52d6bcc..2a8a5af 100644 --- a/bot.go +++ b/bot.go @@ -173,12 +173,15 @@ var ( // b.Handle("/ban", onBan, middleware.Whitelist(ids...)) // func (b *Bot) Handle(endpoint interface{}, h HandlerFunc, m ...MiddlewareFunc) { + mw := m if len(b.group.middleware) > 0 { - m = append(b.group.middleware, m...) + mw = make([]MiddlewareFunc, 0, len(b.group.middleware)+len(m)) + mw = append(mw, b.group.middleware...) + mw = append(mw, m...) } handler := func(c Context) error { - return applyMiddleware(h, m...)(c) + return applyMiddleware(h, mw...)(c) } switch end := endpoint.(type) { diff --git a/bot_test.go b/bot_test.go index 030cef4..3f2f347 100644 --- a/bot_test.go +++ b/bot_test.go @@ -373,6 +373,109 @@ func TestBotOnError(t *testing.T) { assert.True(t, ok) } +func TestBot_Middleware(t *testing.T) { + t.Run("call order", func(t *testing.T) { + var trace []string + + mwTrace := func(name string) MiddlewareFunc { + return func(next HandlerFunc) HandlerFunc { + return func(c Context) error { + trace = append(trace, name+":in") + err := next(c) + trace = append(trace, name+":out") + return err + } + } + } + + b, err := NewBot(Settings{Synchronous: true, Offline: true}) + if err != nil { + t.Fatal(err) + } + b.Use(mwTrace("global-1"), mwTrace("global-2")) + + b.Handle("/a", func(c Context) error { + trace = append(trace, "/a") + return nil + }, mwTrace("handler-1-a"), mwTrace("handler-2-a")) + + group := b.Group() + group.Use(mwTrace("group-1"), mwTrace("group-2")) + + group.Handle("/b", func(c Context) error { + trace = append(trace, "/b") + return nil + }, mwTrace("handler-1-b")) + + b.ProcessUpdate(Update{Message: &Message{Text: "/a"}}) + + expectedOrder := []string{ + "global-1:in", "global-2:in", + "handler-1-a:in", "handler-2-a:in", + "/a", + "handler-2-a:out", "handler-1-a:out", + "global-2:out", "global-1:out", + } + assert.Equal(t, expectedOrder, trace) + + trace = trace[:0] + b.ProcessUpdate(Update{Message: &Message{Text: "/b"}}) + + expectedOrder = []string{ + "global-1:in", "global-2:in", + "group-1:in", "group-2:in", + "handler-1-b:in", + "/b", + "handler-1-b:out", + "group-2:out", "group-1:out", + "global-2:out", "global-1:out", + } + assert.Equal(t, expectedOrder, trace) + }) + + fatalMiddleware := func(next HandlerFunc) HandlerFunc { + return func(c Context) error { + t.Fatal("fatalMiddleware should not be called") + return nil + } + } + nopMiddleware := func(next HandlerFunc) HandlerFunc { + return func(c Context) error { return next(c) } + } + + t.Run("handler middleware is not clobbered when combined with global middleware", func(t *testing.T) { + b, err := NewBot(Settings{Synchronous: true, Offline: true}) + if err != nil { + t.Fatal(err) + } + // Pre-allocate middleware slice to make sure it has extra capacity after group-level middleware is added. + b.group.middleware = make([]MiddlewareFunc, 0, 2) + b.Use(nopMiddleware) + + b.Handle("/a", func(c Context) error { return nil }, nopMiddleware) + b.Handle("/b", func(c Context) error { return nil }, fatalMiddleware) + + b.ProcessUpdate(Update{Message: &Message{Text: "/a"}}) + }) + + t.Run("handler middleware is not clobbered when combined with group middleware", func(t *testing.T) { + b, err := NewBot(Settings{Synchronous: true, Offline: true}) + if err != nil { + t.Fatal(err) + } + + g := b.Group() + // Pre-allocate middleware slice to make sure it has extra capacity after group-level middleware is added. + g.middleware = make([]MiddlewareFunc, 0, 2) + g.Use(nopMiddleware) + + g.Handle("/a", func(c Context) error { return nil }, nopMiddleware) + g.Handle("/b", func(c Context) error { return nil }, fatalMiddleware) + + b.ProcessUpdate(Update{Message: &Message{Text: "/a"}}) + }) +} + func TestBot(t *testing.T) { if b == nil { t.Skip("Cached bot instance is bad (probably wrong or empty TELEBOT_SECRET)") diff --git a/middleware.go b/middleware.go index aa21ca2..e57465e 100644 --- a/middleware.go +++ b/middleware.go @@ -25,5 +25,11 @@ func (g *Group) Use(middleware ...MiddlewareFunc) { // Handle adds endpoint handler to the bot, combining group's middleware // with the optional given middleware. func (g *Group) Handle(endpoint interface{}, h HandlerFunc, m ...MiddlewareFunc) { - g.b.Handle(endpoint, h, append(g.middleware, m...)...) + mw := m + if len(g.middleware) > 0 { + mw = make([]MiddlewareFunc, 0, len(g.middleware)+len(m)) + mw = append(mw, g.middleware...) + mw = append(mw, m...) + } + g.b.Handle(endpoint, h, mw...) }