Fix handler middleware clobbering when group middleware slice has extra capacity

pull/588/head
Max Kuznetsov 1 year ago
parent fbd35f2103
commit bae47b52d4

@ -173,12 +173,15 @@ var (
// b.Handle("/ban", onBan, middleware.Whitelist(ids...))
//
func (b *Bot) Handle(endpoint interface{}, h HandlerFunc, m ...MiddlewareFunc) {
mw := m
if len(b.group.middleware) > 0 {
m = append(b.group.middleware, m...)
mw = make([]MiddlewareFunc, 0, len(b.group.middleware)+len(m))
mw = append(mw, b.group.middleware...)
mw = append(mw, m...)
}
handler := func(c Context) error {
return applyMiddleware(h, m...)(c)
return applyMiddleware(h, mw...)(c)
}
switch end := endpoint.(type) {

@ -373,6 +373,109 @@ func TestBotOnError(t *testing.T) {
assert.True(t, ok)
}
func TestBot_Middleware(t *testing.T) {
t.Run("call order", func(t *testing.T) {
var trace []string
mwTrace := func(name string) MiddlewareFunc {
return func(next HandlerFunc) HandlerFunc {
return func(c Context) error {
trace = append(trace, name+":in")
err := next(c)
trace = append(trace, name+":out")
return err
}
}
}
b, err := NewBot(Settings{Synchronous: true, Offline: true})
if err != nil {
t.Fatal(err)
}
b.Use(mwTrace("global-1"), mwTrace("global-2"))
b.Handle("/a", func(c Context) error {
trace = append(trace, "/a")
return nil
}, mwTrace("handler-1-a"), mwTrace("handler-2-a"))
group := b.Group()
group.Use(mwTrace("group-1"), mwTrace("group-2"))
group.Handle("/b", func(c Context) error {
trace = append(trace, "/b")
return nil
}, mwTrace("handler-1-b"))
b.ProcessUpdate(Update{Message: &Message{Text: "/a"}})
expectedOrder := []string{
"global-1:in", "global-2:in",
"handler-1-a:in", "handler-2-a:in",
"/a",
"handler-2-a:out", "handler-1-a:out",
"global-2:out", "global-1:out",
}
assert.Equal(t, expectedOrder, trace)
trace = trace[:0]
b.ProcessUpdate(Update{Message: &Message{Text: "/b"}})
expectedOrder = []string{
"global-1:in", "global-2:in",
"group-1:in", "group-2:in",
"handler-1-b:in",
"/b",
"handler-1-b:out",
"group-2:out", "group-1:out",
"global-2:out", "global-1:out",
}
assert.Equal(t, expectedOrder, trace)
})
fatalMiddleware := func(next HandlerFunc) HandlerFunc {
return func(c Context) error {
t.Fatal("fatalMiddleware should not be called")
return nil
}
}
nopMiddleware := func(next HandlerFunc) HandlerFunc {
return func(c Context) error { return next(c) }
}
t.Run("handler middleware is not clobbered when combined with global middleware", func(t *testing.T) {
b, err := NewBot(Settings{Synchronous: true, Offline: true})
if err != nil {
t.Fatal(err)
}
// Pre-allocate middleware slice to make sure it has extra capacity after group-level middleware is added.
b.group.middleware = make([]MiddlewareFunc, 0, 2)
b.Use(nopMiddleware)
b.Handle("/a", func(c Context) error { return nil }, nopMiddleware)
b.Handle("/b", func(c Context) error { return nil }, fatalMiddleware)
b.ProcessUpdate(Update{Message: &Message{Text: "/a"}})
})
t.Run("handler middleware is not clobbered when combined with group middleware", func(t *testing.T) {
b, err := NewBot(Settings{Synchronous: true, Offline: true})
if err != nil {
t.Fatal(err)
}
g := b.Group()
// Pre-allocate middleware slice to make sure it has extra capacity after group-level middleware is added.
g.middleware = make([]MiddlewareFunc, 0, 2)
g.Use(nopMiddleware)
g.Handle("/a", func(c Context) error { return nil }, nopMiddleware)
g.Handle("/b", func(c Context) error { return nil }, fatalMiddleware)
b.ProcessUpdate(Update{Message: &Message{Text: "/a"}})
})
}
func TestBot(t *testing.T) {
if b == nil {
t.Skip("Cached bot instance is bad (probably wrong or empty TELEBOT_SECRET)")

@ -25,5 +25,11 @@ func (g *Group) Use(middleware ...MiddlewareFunc) {
// Handle adds endpoint handler to the bot, combining group's middleware
// with the optional given middleware.
func (g *Group) Handle(endpoint interface{}, h HandlerFunc, m ...MiddlewareFunc) {
g.b.Handle(endpoint, h, append(g.middleware, m...)...)
mw := m
if len(g.middleware) > 0 {
mw = make([]MiddlewareFunc, 0, len(g.middleware)+len(m))
mw = append(mw, g.middleware...)
mw = append(mw, m...)
}
g.b.Handle(endpoint, h, mw...)
}

Loading…
Cancel
Save