diff --git a/bot.go b/bot.go index 9f4a2f4..5286b49 100644 --- a/bot.go +++ b/bot.go @@ -235,31 +235,20 @@ func (b *Bot) NewMarkup() *ReplyMarkup { // NewContext returns a new native context object, // field by the passed update. -func (b *Bot) NewContext(upd Update) Context { +func (b *Bot) NewContext(u Update) Context { return &nativeContext{ - b: b, - update: upd, - message: upd.Message, - callback: upd.Callback, - query: upd.Query, - inlineResult: upd.InlineResult, - shippingQuery: upd.ShippingQuery, - preCheckoutQuery: upd.PreCheckoutQuery, - poll: upd.Poll, - pollAnswer: upd.PollAnswer, - myChatMember: upd.MyChatMember, - chatMember: upd.ChatMember, - chatJoinRequest: upd.ChatJoinRequest, + b: b, + u: u, } } // ProcessUpdate processes a single incoming update. // A started bot calls this function automatically. -func (b *Bot) ProcessUpdate(upd Update) { - c := b.NewContext(upd).(*nativeContext) +func (b *Bot) ProcessUpdate(u Update) { + c := b.NewContext(u) - if upd.Message != nil { - m := upd.Message + if u.Message != nil { + m := u.Message if m.PinnedMessage != nil { b.handle(OnPinned, c) @@ -422,44 +411,38 @@ func (b *Bot) ProcessUpdate(upd Update) { } } - if upd.EditedMessage != nil { - c.message = upd.EditedMessage + if u.EditedMessage != nil { b.handle(OnEdited, c) return } - if upd.ChannelPost != nil { - m := upd.ChannelPost + if u.ChannelPost != nil { + m := u.ChannelPost if m.PinnedMessage != nil { - c.message = m.PinnedMessage b.handle(OnPinned, c) return } - c.message = upd.ChannelPost b.handle(OnChannelPost, c) return } - if upd.EditedChannelPost != nil { - c.message = upd.EditedChannelPost + if u.EditedChannelPost != nil { b.handle(OnEditedChannelPost, c) return } - if upd.Callback != nil { - if upd.Callback.Data != "" { - if data := upd.Callback.Data; data[0] == '\f' { - match := cbackRx.FindAllStringSubmatch(data, -1) - if match != nil { - unique, payload := match[0][1], match[0][3] - if handler, ok := b.handlers["\f"+unique]; ok { - c.callback.Unique = unique - c.callback.Data = payload - b.runHandler(handler, c) - return - } + if u.Callback != nil { + if data := u.Callback.Data; data != "" && data[0] == '\f' { + match := cbackRx.FindAllStringSubmatch(data, -1) + if match != nil { + unique, payload := match[0][1], match[0][3] + if handler, ok := b.handlers["\f"+unique]; ok { + u.Callback.Unique = unique + u.Callback.Data = payload + b.runHandler(handler, c) + return } } } @@ -468,47 +451,47 @@ func (b *Bot) ProcessUpdate(upd Update) { return } - if upd.Query != nil { + if u.Query != nil { b.handle(OnQuery, c) return } - if upd.InlineResult != nil { + if u.InlineResult != nil { b.handle(OnInlineResult, c) return } - if upd.ShippingQuery != nil { + if u.ShippingQuery != nil { b.handle(OnShipping, c) return } - if upd.PreCheckoutQuery != nil { + if u.PreCheckoutQuery != nil { b.handle(OnCheckout, c) return } - if upd.Poll != nil { + if u.Poll != nil { b.handle(OnPoll, c) return } - if upd.PollAnswer != nil { + if u.PollAnswer != nil { b.handle(OnPollAnswer, c) return } - if upd.MyChatMember != nil { + if u.MyChatMember != nil { b.handle(OnMyChatMember, c) return } - if upd.ChatMember != nil { + if u.ChatMember != nil { b.handle(OnChatMember, c) return } - if upd.ChatJoinRequest != nil { + if u.ChatJoinRequest != nil { b.handle(OnChatJoinRequest, c) return } diff --git a/context.go b/context.go index 09fee21..110e3d4 100644 --- a/context.go +++ b/context.go @@ -153,21 +153,8 @@ type Context interface { // nativeContext is a native implementation of the Context interface. // "context" is taken by context package, maybe there is a better name. type nativeContext struct { - b *Bot - - update Update - message *Message - callback *Callback - query *Query - inlineResult *InlineResult - shippingQuery *ShippingQuery - preCheckoutQuery *PreCheckoutQuery - poll *Poll - pollAnswer *PollAnswer - myChatMember *ChatMemberUpdate - chatMember *ChatMemberUpdate - chatJoinRequest *ChatJoinRequest - + b *Bot + u Update lock sync.RWMutex store map[string]interface{} } @@ -177,15 +164,24 @@ func (c *nativeContext) Bot() *Bot { } func (c *nativeContext) Update() Update { - return c.update + return c.u } func (c *nativeContext) Message() *Message { switch { - case c.message != nil: - return c.message - case c.callback != nil: - return c.callback.Message + case c.u.Message != nil: + return c.u.Message + case c.u.Callback != nil: + return c.u.Callback.Message + case c.u.EditedMessage != nil: + return c.u.EditedMessage + case c.u.ChannelPost != nil: + if c.u.ChannelPost.PinnedMessage != nil { + return c.u.ChannelPost.PinnedMessage + } + return c.u.ChannelPost + case c.u.EditedChannelPost != nil: + return c.u.EditedChannelPost default: return nil } @@ -215,74 +211,74 @@ func (c *nativeContext) Media() Media { } func (c *nativeContext) Callback() *Callback { - return c.callback + return c.u.Callback } func (c *nativeContext) Query() *Query { - return c.query + return c.u.Query } func (c *nativeContext) InlineResult() *InlineResult { - return c.inlineResult + return c.u.InlineResult } func (c *nativeContext) ShippingQuery() *ShippingQuery { - return c.shippingQuery + return c.u.ShippingQuery } func (c *nativeContext) PreCheckoutQuery() *PreCheckoutQuery { - return c.preCheckoutQuery + return c.u.PreCheckoutQuery } func (c *nativeContext) ChatMember() *ChatMemberUpdate { switch { - case c.chatMember != nil: - return c.chatMember - case c.myChatMember != nil: - return c.myChatMember + case c.u.ChatMember != nil: + return c.u.ChatMember + case c.u.MyChatMember != nil: + return c.u.MyChatMember default: return nil } } func (c *nativeContext) ChatJoinRequest() *ChatJoinRequest { - return c.chatJoinRequest + return c.u.ChatJoinRequest } func (c *nativeContext) Poll() *Poll { - return c.poll + return c.u.Poll } func (c *nativeContext) PollAnswer() *PollAnswer { - return c.pollAnswer + return c.u.PollAnswer } func (c *nativeContext) Migration() (int64, int64) { - return c.message.MigrateFrom, c.message.MigrateTo + return c.u.Message.MigrateFrom, c.u.Message.MigrateTo } func (c *nativeContext) Sender() *User { switch { - case c.message != nil: - return c.message.Sender - case c.callback != nil: - return c.callback.Sender - case c.query != nil: - return c.query.Sender - case c.inlineResult != nil: - return c.inlineResult.Sender - case c.shippingQuery != nil: - return c.shippingQuery.Sender - case c.preCheckoutQuery != nil: - return c.preCheckoutQuery.Sender - case c.pollAnswer != nil: - return c.pollAnswer.Sender - case c.myChatMember != nil: - return c.myChatMember.Sender - case c.chatMember != nil: - return c.chatMember.Sender - case c.chatJoinRequest != nil: - return c.chatJoinRequest.Sender + case c.u.Message != nil: + return c.u.Message.Sender + case c.u.Callback != nil: + return c.u.Callback.Sender + case c.u.Query != nil: + return c.u.Query.Sender + case c.u.InlineResult != nil: + return c.u.InlineResult.Sender + case c.u.ShippingQuery != nil: + return c.u.ShippingQuery.Sender + case c.u.PreCheckoutQuery != nil: + return c.u.PreCheckoutQuery.Sender + case c.u.PollAnswer != nil: + return c.u.PollAnswer.Sender + case c.u.MyChatMember != nil: + return c.u.MyChatMember.Sender + case c.u.ChatMember != nil: + return c.u.ChatMember.Sender + case c.u.ChatJoinRequest != nil: + return c.u.ChatJoinRequest.Sender default: return nil } @@ -290,16 +286,16 @@ func (c *nativeContext) Sender() *User { func (c *nativeContext) Chat() *Chat { switch { - case c.message != nil: - return c.message.Chat - case c.callback != nil && c.callback.Message != nil: - return c.callback.Message.Chat - case c.myChatMember != nil: - return c.myChatMember.Chat - case c.chatMember != nil: - return c.chatMember.Chat - case c.chatJoinRequest != nil: - return c.chatJoinRequest.Chat + case c.u.Message != nil: + return c.u.Message.Chat + case c.u.Callback != nil && c.u.Callback.Message != nil: + return c.u.Callback.Message.Chat + case c.u.MyChatMember != nil: + return c.u.MyChatMember.Chat + case c.u.ChatMember != nil: + return c.u.ChatMember.Chat + case c.u.ChatJoinRequest != nil: + return c.u.ChatJoinRequest.Chat default: return nil } @@ -317,10 +313,10 @@ func (c *nativeContext) Text() string { var m *Message switch { - case c.message != nil: - m = c.message - case c.callback != nil && c.callback.Message != nil: - m = c.callback.Message + case c.u.Message != nil: + m = c.u.Message + case c.u.Callback != nil && c.u.Callback.Message != nil: + m = c.u.Callback.Message default: return "" } @@ -334,18 +330,18 @@ func (c *nativeContext) Text() string { func (c *nativeContext) Data() string { switch { - case c.message != nil: - return c.message.Payload - case c.callback != nil: - return c.callback.Data - case c.query != nil: - return c.query.Text - case c.inlineResult != nil: - return c.inlineResult.Query - case c.shippingQuery != nil: - return c.shippingQuery.Payload - case c.preCheckoutQuery != nil: - return c.preCheckoutQuery.Payload + case c.u.Message != nil: + return c.u.Message.Payload + case c.u.Callback != nil: + return c.u.Callback.Data + case c.u.Query != nil: + return c.u.Query.Text + case c.u.InlineResult != nil: + return c.u.InlineResult.Query + case c.u.ShippingQuery != nil: + return c.u.ShippingQuery.Payload + case c.u.PreCheckoutQuery != nil: + return c.u.PreCheckoutQuery.Payload default: return "" } @@ -353,17 +349,17 @@ func (c *nativeContext) Data() string { func (c *nativeContext) Args() []string { switch { - case c.message != nil: - payload := strings.Trim(c.message.Payload, " ") + case c.u.Message != nil: + payload := strings.Trim(c.u.Message.Payload, " ") if payload != "" { return strings.Split(payload, " ") } - case c.callback != nil: - return strings.Split(c.callback.Data, "|") - case c.query != nil: - return strings.Split(c.query.Text, " ") - case c.inlineResult != nil: - return strings.Split(c.inlineResult.Query, " ") + case c.u.Callback != nil: + return strings.Split(c.u.Callback.Data, "|") + case c.u.Query != nil: + return strings.Split(c.u.Query.Text, " ") + case c.u.InlineResult != nil: + return strings.Split(c.u.InlineResult.Query, " ") } return nil } @@ -402,24 +398,24 @@ func (c *nativeContext) ForwardTo(to Recipient, opts ...interface{}) error { } func (c *nativeContext) Edit(what interface{}, opts ...interface{}) error { - if c.inlineResult != nil { - _, err := c.b.Edit(c.inlineResult, what, opts...) + if c.u.InlineResult != nil { + _, err := c.b.Edit(c.u.InlineResult, what, opts...) return err } - if c.callback != nil { - _, err := c.b.Edit(c.callback, what, opts...) + if c.u.Callback != nil { + _, err := c.b.Edit(c.u.Callback, what, opts...) return err } return ErrBadContext } func (c *nativeContext) EditCaption(caption string, opts ...interface{}) error { - if c.inlineResult != nil { - _, err := c.b.EditCaption(c.inlineResult, caption, opts...) + if c.u.InlineResult != nil { + _, err := c.b.EditCaption(c.u.InlineResult, caption, opts...) return err } - if c.callback != nil { - _, err := c.b.EditCaption(c.callback, caption, opts...) + if c.u.Callback != nil { + _, err := c.b.EditCaption(c.u.Callback, caption, opts...) return err } return ErrBadContext @@ -454,31 +450,31 @@ func (c *nativeContext) Notify(action ChatAction) error { } func (c *nativeContext) Ship(what ...interface{}) error { - if c.shippingQuery == nil { + if c.u.ShippingQuery == nil { return errors.New("telebot: context shipping query is nil") } - return c.b.Ship(c.shippingQuery, what...) + return c.b.Ship(c.u.ShippingQuery, what...) } func (c *nativeContext) Accept(errorMessage ...string) error { - if c.preCheckoutQuery == nil { + if c.u.PreCheckoutQuery == nil { return errors.New("telebot: context pre checkout query is nil") } - return c.b.Accept(c.preCheckoutQuery, errorMessage...) + return c.b.Accept(c.u.PreCheckoutQuery, errorMessage...) } func (c *nativeContext) Answer(resp *QueryResponse) error { - if c.query == nil { + if c.u.Query == nil { return errors.New("telebot: context inline query is nil") } - return c.b.Answer(c.query, resp) + return c.b.Answer(c.u.Query, resp) } func (c *nativeContext) Respond(resp ...*CallbackResponse) error { - if c.callback == nil { + if c.u.Callback == nil { return errors.New("telebot: context callback is nil") } - return c.b.Respond(c.callback, resp...) + return c.b.Respond(c.u.Callback, resp...) } func (c *nativeContext) Set(key string, value interface{}) {