diff --git a/bot.go b/bot.go index 878a0d4..dbfa230 100644 --- a/bot.go +++ b/bot.go @@ -35,10 +35,11 @@ func NewBot(pref Settings) (*Bot, error) { Updates: make(chan Update, pref.Updates), Poller: pref.Poller, - handlers: make(map[string]interface{}), - stop: make(chan struct{}), - reporter: pref.Reporter, - client: client, + handlers: make(map[string]interface{}), + synchronous: pref.Synchronous, + stop: make(chan struct{}), + reporter: pref.Reporter, + client: client, } user, err := bot.getMe() @@ -58,10 +59,11 @@ type Bot struct { Updates chan Update Poller Poller - handlers map[string]interface{} - reporter func(error) - stop chan struct{} - client *http.Client + handlers map[string]interface{} + synchronous bool + reporter func(error) + stop chan struct{} + client *http.Client } // Settings represents a utility struct for passing certain @@ -79,6 +81,10 @@ type Settings struct { // Poller is the provider of Updates. Poller Poller + // Synchronous prevents handlers from running in parallel. + // It makes ProcessUpdate return after the handler is finished. + Synchronous bool + // Reporter is a callback function that will get called // on any panics recovered from endpoint handlers. Reporter func(error) @@ -173,7 +179,7 @@ func (b *Bot) Start() { select { // handle incoming updates case upd := <-b.Updates: - b.incomingUpdate(&upd) + b.ProcessUpdate(&upd) // call to stop polling case <-b.stop: stop <- struct{}{} @@ -187,7 +193,9 @@ func (b *Bot) Stop() { b.stop <- struct{}{} } -func (b *Bot) incomingUpdate(upd *Update) { +// ProcessUpdate processes a single incoming update. +// A started bot calls this function automatically. +func (b *Bot) ProcessUpdate(upd *Update) { if upd.Message != nil { m := upd.Message @@ -278,11 +286,8 @@ func (b *Bot) incomingUpdate(upd *Update) { panic("telebot: migration handler is bad") } - go func(b *Bot, handler func(int64, int64), from, to int64) { - if b.reporter == nil { - defer b.deferDebug() - } - handler(from, to) + func(b *Bot, handler func(int64, int64), from, to int64) { + b.runHandler(func() { handler(from, to) }) }(b, handler, m.Chat.ID, m.MigrateTo) } @@ -335,11 +340,8 @@ func (b *Bot) incomingUpdate(upd *Update) { } upd.Callback.Data = payload - go func(b *Bot, handler func(*Callback), c *Callback) { - if b.reporter == nil { - defer b.deferDebug() - } - handler(c) + func(b *Bot, handler func(*Callback), c *Callback) { + b.runHandler(func() { handler(c) }) }(b, handler, upd.Callback) return @@ -354,11 +356,8 @@ func (b *Bot) incomingUpdate(upd *Update) { panic("telebot: callback handler is bad") } - go func(b *Bot, handler func(*Callback), c *Callback) { - if b.reporter == nil { - defer b.deferDebug() - } - handler(c) + func(b *Bot, handler func(*Callback), c *Callback) { + b.runHandler(func() { handler(c) }) }(b, handler, upd.Callback) } @@ -372,11 +371,8 @@ func (b *Bot) incomingUpdate(upd *Update) { panic("telebot: query handler is bad") } - go func(b *Bot, handler func(*Query), q *Query) { - if b.reporter == nil { - defer b.deferDebug() - } - handler(q) + func(b *Bot, handler func(*Query), q *Query) { + b.runHandler(func() { handler(q) }) }(b, handler, upd.Query) } @@ -390,11 +386,8 @@ func (b *Bot) incomingUpdate(upd *Update) { panic("telebot: chosen inline result handler is bad") } - go func(b *Bot, handler func(*ChosenInlineResult), r *ChosenInlineResult) { - if b.reporter == nil { - defer b.deferDebug() - } - handler(r) + func(b *Bot, handler func(*ChosenInlineResult), r *ChosenInlineResult) { + b.runHandler(func() { handler(r) }) }(b, handler, upd.ChosenInlineResult) } @@ -408,11 +401,8 @@ func (b *Bot) incomingUpdate(upd *Update) { panic("telebot: pre checkout query handler is bad") } - go func(b *Bot, handler func(*PreCheckoutQuery), pre *PreCheckoutQuery) { - if b.reporter == nil { - defer b.deferDebug() - } - handler(pre) + func(b *Bot, handler func(*PreCheckoutQuery), pre *PreCheckoutQuery) { + b.runHandler(func() { handler(pre) }) }(b, handler, upd.PreCheckoutQuery) } @@ -426,11 +416,8 @@ func (b *Bot) incomingUpdate(upd *Update) { panic("telebot: poll handler is bad") } - go func(b *Bot, handler func(*Poll), p *Poll) { - if b.reporter == nil { - defer b.deferDebug() - } - handler(p) + func(b *Bot, handler func(*Poll), p *Poll) { + b.runHandler(func() { handler(p) }) }(b, handler, upd.Poll) } @@ -444,11 +431,8 @@ func (b *Bot) incomingUpdate(upd *Update) { panic("telebot: poll answer handler is bad") } - go func(b *Bot, handler func(*PollAnswer), pa *PollAnswer) { - if b.reporter == nil { - defer b.deferDebug() - } - handler(pa) + func(b *Bot, handler func(*PollAnswer), pa *PollAnswer) { + b.runHandler(func() { handler(pa) }) }(b, handler, upd.PollAnswer) } @@ -463,12 +447,7 @@ func (b *Bot) handle(end string, m *Message) bool { panic(fmt.Errorf("telebot: %s handler is bad", end)) } - go func(b *Bot, handler func(*Message), m *Message) { - if b.reporter == nil { - defer b.deferDebug() - } - handler(m) - }(b, handler, m) + b.runHandler(func() { handler(m) }) return true } diff --git a/util.go b/util.go index 52c12e3..eae707c 100644 --- a/util.go +++ b/util.go @@ -28,6 +28,18 @@ func (b *Bot) deferDebug() { } } +func (b *Bot) runHandler(handler func()) { + f := func() { + defer b.deferDebug() + handler() + } + if b.synchronous { + f() + } else { + go f() + } +} + // wrapError returns new wrapped telebot-related error. func wrapError(err error) error { return errors.Wrap(err, "telebot")