diff --git a/app.go b/app.go index ce7b5ea..ef26652 100644 --- a/app.go +++ b/app.go @@ -26,7 +26,7 @@ type App struct { ready *containers.RWSync[bool] db *database.LocalDBClient - twitchClient *twitch.Client + twitchManager *twitch.Manager httpServer *http.Server loyaltyManager *loyalty.Manager } @@ -72,11 +72,11 @@ func (a *App) startup(ctx context.Context) { failOnError(err, "could not initialize http server") // Create twitch client - a.twitchClient, err = twitch.NewClient(a.db, a.httpServer, logger) + a.twitchManager, err = twitch.NewManager(a.db, a.httpServer, logger) failOnError(err, "could not initialize twitch client") // Initialize loyalty system - a.loyaltyManager, err = loyalty.NewManager(a.db, a.twitchClient, logger) + a.loyaltyManager, err = loyalty.NewManager(a.db, a.twitchManager, logger) failOnError(err, "could not initialize loyalty manager") a.ready.Set(true) @@ -98,8 +98,8 @@ func (a *App) stop(context.Context) { if a.loyaltyManager != nil { warnOnError(a.loyaltyManager.Close(), "could not cleanly close loyalty manager") } - if a.twitchClient != nil { - warnOnError(a.twitchClient.Close(), "could not cleanly close twitch client") + if a.twitchManager != nil { + warnOnError(a.twitchManager.Close(), "could not cleanly close twitch client") } if a.httpServer != nil { warnOnError(a.httpServer.Close(), "could not cleanly close HTTP server") @@ -129,11 +129,11 @@ func (a *App) GetKilovoltBind() string { } func (a *App) GetTwitchAuthURL() string { - return a.twitchClient.GetAuthorizationURL() + return a.twitchManager.Client().GetAuthorizationURL() } func (a *App) GetTwitchLoggedUser() (helix.User, error) { - return a.twitchClient.GetLoggedUser() + return a.twitchManager.Client().GetLoggedUser() } func (a *App) GetLastLogs() []LogEntry { diff --git a/database/database.go b/database/database.go index 73e4504..f3a0842 100644 --- a/database/database.go +++ b/database/database.go @@ -9,6 +9,8 @@ import ( "go.uber.org/zap" ) +type CancelFunc func() + var json = jsoniter.ConfigFastest var ( @@ -74,29 +76,36 @@ func (mod *LocalDBClient) PutKey(key string, data string) error { return err } -func (mod *LocalDBClient) SubscribePrefix(fn kv.SubscriptionCallback, prefixes ...string) error { +func (mod *LocalDBClient) SubscribePrefix(fn kv.SubscriptionCallback, prefixes ...string) (err error, cancelFn func()) { + var ids []int64 for _, prefix := range prefixes { - _, err := mod.makeRequest(kv.CmdSubscribePrefix, map[string]interface{}{"prefix": prefix}) + _, err = mod.makeRequest(kv.CmdSubscribePrefix, map[string]interface{}{"prefix": prefix}) if err != nil { - return err + return err, nil + } + ids = append(ids, mod.client.SetPrefixSubCallback(prefix, fn)) + } + return nil, func() { + for _, id := range ids { + mod.client.UnsetCallback(id) } - go mod.client.SetPrefixSubCallback(prefix, fn) } - return nil } -func (mod *LocalDBClient) SubscribeKey(key string, fn func(string)) error { - _, err := mod.makeRequest(kv.CmdSubscribePrefix, map[string]interface{}{"prefix": key}) +func (mod *LocalDBClient) SubscribeKey(key string, fn func(string)) (err error, cancelFn CancelFunc) { + _, err = mod.makeRequest(kv.CmdSubscribePrefix, map[string]interface{}{"prefix": key}) if err != nil { - return err + return err, nil } - go mod.client.SetPrefixSubCallback(key, func(changedKey string, value string) { + id := mod.client.SetPrefixSubCallback(key, func(changedKey string, value string) { if key != changedKey { return } fn(value) }) - return nil + return nil, func() { + mod.client.UnsetCallback(id) + } } func (mod *LocalDBClient) GetJSON(key string, dst interface{}) error { diff --git a/database/driver.interface.go b/database/driver.interface.go index 9e6e2f2..9bb8d80 100644 --- a/database/driver.interface.go +++ b/database/driver.interface.go @@ -6,6 +6,8 @@ import ( "os" "path/filepath" + "github.com/strimertul/strimertul/utils" + "go.uber.org/zap" kv "github.com/strimertul/kilovolt/v9" @@ -48,7 +50,7 @@ func getDatabaseDriverName(ctx *cli.Context) string { func GetDatabaseDriver(ctx *cli.Context) (DatabaseDriver, error) { name := getDatabaseDriverName(ctx) dbDirectory := ctx.String("database-dir") - logger := ctx.Context.Value("logger").(*zap.Logger) + logger := ctx.Context.Value(utils.ContextLogger).(*zap.Logger) switch name { case "badger": diff --git a/go.mod b/go.mod index d1bd13b..423b624 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/strimertul/strimertul go 1.19 require ( - git.sr.ht/~hamcha/containers v0.2.0 + git.sr.ht/~hamcha/containers v0.2.1 github.com/Masterminds/sprig/v3 v3.2.2 github.com/apenwarr/fixconsole v0.0.0-20191012055117-5a9f6489cc29 github.com/cockroachdb/pebble v0.0.0-20221116223310-87eccabb90a3 diff --git a/go.sum b/go.sum index d47ec98..fea2d81 100644 --- a/go.sum +++ b/go.sum @@ -31,8 +31,8 @@ cloud.google.com/go/storage v1.6.0/go.mod h1:N7U0C8pVQ/+NIKOBQyamJIeKQKkZ+mxpohl cloud.google.com/go/storage v1.8.0/go.mod h1:Wv1Oy7z6Yz3DshWRJFhqM/UCfaWIRTdp0RXyy7KQOVs= cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9ullr3+Kg0= dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= -git.sr.ht/~hamcha/containers v0.2.0 h1:fv8HQ6fsJUa1w46sH9KluW6dfJEh3uZN3QNLJvuCIm4= -git.sr.ht/~hamcha/containers v0.2.0/go.mod h1:RiZphUpy9t6EnL4Gf6uzByM9QrBoqRCEPo7kz2wzbhE= +git.sr.ht/~hamcha/containers v0.2.1 h1:mJ8b4fQhDKU73VRK1SjeIzJ5YnZYHeFHLJvHl6yKtNg= +git.sr.ht/~hamcha/containers v0.2.1/go.mod h1:RiZphUpy9t6EnL4Gf6uzByM9QrBoqRCEPo7kz2wzbhE= github.com/AndreasBriese/bbloom v0.0.0-20190306092124-e2d15f34fcf9/go.mod h1:bOvUY6CB00SOBii9/FifXqc0awNKxLFCL/+pkDPuyl8= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/toml v1.2.1 h1:9F2/+DoOYIOksmaJFPw1tGFy1eDnIJXg+UHjuD8lTak= diff --git a/http/server.go b/http/server.go index 0619750..e73fd2b 100644 --- a/http/server.go +++ b/http/server.go @@ -27,7 +27,8 @@ type Server struct { frontend fs.FS hub *kv.Hub mux *http.ServeMux - requestedRoutes map[string]http.HandlerFunc + requestedRoutes map[string]http.Handler + cancelConfigSub database.CancelFunc } func NewServer(db *database.LocalDBClient, logger *zap.Logger) (*Server, error) { @@ -35,7 +36,7 @@ func NewServer(db *database.LocalDBClient, logger *zap.Logger) (*Server, error) logger: logger, db: db, server: &http.Server{}, - requestedRoutes: make(map[string]http.HandlerFunc), + requestedRoutes: make(map[string]http.Handler), } err := db.GetJSON(ServerConfigKey, &server.Config) @@ -70,6 +71,10 @@ type StatusData struct { } func (s *Server) Close() error { + if s.cancelConfigSub != nil { + s.cancelConfigSub() + } + return s.server.Close() } @@ -99,53 +104,55 @@ func (s *Server) makeMux() *http.ServeMux { mux.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.Dir(s.Config.Path)))) } for route, handler := range s.requestedRoutes { - mux.HandleFunc(route, handler) + mux.Handle(route, handler) } return mux } -func (s *Server) SetRoute(route string, handler http.HandlerFunc) { +func (s *Server) RegisterRoute(route string, handler http.Handler) { s.requestedRoutes[route] = handler - if s.mux != nil { - s.mux.HandleFunc(route, handler) - } + s.mux = s.makeMux() +} + +func (s *Server) UnregisterRoute(route string) { + delete(s.requestedRoutes, route) + s.mux = s.makeMux() } func (s *Server) Listen() error { // Start HTTP server restart := containers.NewRWSync(false) exit := make(chan error) - go func() { - err := s.db.SubscribeKey(ServerConfigKey, func(value string) { - oldBind := s.Config.Bind - oldPassword := s.Config.KVPassword - err := json.Unmarshal([]byte(value), &s.Config) + var err error + err, s.cancelConfigSub = s.db.SubscribeKey(ServerConfigKey, func(value string) { + oldBind := s.Config.Bind + oldPassword := s.Config.KVPassword + err := json.Unmarshal([]byte(value), &s.Config) + if err != nil { + s.logger.Error("Failed to unmarshal config", zap.Error(err)) + return + } + s.mux = s.makeMux() + // Restart hub if password changed + if oldPassword != s.Config.KVPassword { + s.hub.SetOptions(kv.HubOptions{ + Password: s.Config.KVPassword, + }) + } + // Restart server if bind changed + if oldBind != s.Config.Bind { + restart.Set(true) + err = s.server.Shutdown(context.Background()) if err != nil { - s.logger.Error("Failed to unmarshal config", zap.Error(err)) + s.logger.Error("Failed to shutdown server", zap.Error(err)) return } - s.mux = s.makeMux() - // Restart hub if password changed - if oldPassword != s.Config.KVPassword { - s.hub.SetOptions(kv.HubOptions{ - Password: s.Config.KVPassword, - }) - } - // Restart server if bind changed - if oldBind != s.Config.Bind { - restart.Set(true) - err = s.server.Shutdown(context.Background()) - if err != nil { - s.logger.Error("Failed to shutdown server", zap.Error(err)) - return - } - } - }) - if err != nil { - exit <- fmt.Errorf("error while handling subscription to HTTP config changes: %w", err) } - }() + }) + if err != nil { + exit <- fmt.Errorf("error while handling subscription to HTTP config changes: %w", err) + } go func() { for { s.logger.Info("Starting HTTP server", zap.String("bind", s.Config.Bind)) diff --git a/loyalty/manager.go b/loyalty/manager.go index 21718ce..2fc3649 100644 --- a/loyalty/manager.go +++ b/loyalty/manager.go @@ -28,22 +28,23 @@ var ( ) type Manager struct { - points *containers.SyncMap[string, PointsEntry] - Config *containers.RWSync[Config] - Rewards *containers.Sync[RewardStorage] - Goals *containers.Sync[GoalStorage] - Queue *containers.Sync[RedeemQueueStorage] - db *database.LocalDBClient - logger *zap.Logger - cooldowns map[string]time.Time - banlist map[string]bool - activeUsers *containers.SyncMap[string, bool] - twitchClient *twitch.Client - ctx context.Context - cancelFn context.CancelFunc + points *containers.SyncMap[string, PointsEntry] + Config *containers.RWSync[Config] + Rewards *containers.Sync[RewardStorage] + Goals *containers.Sync[GoalStorage] + Queue *containers.Sync[RedeemQueueStorage] + db *database.LocalDBClient + logger *zap.Logger + cooldowns map[string]time.Time + banlist map[string]bool + activeUsers *containers.SyncMap[string, bool] + twitchManager *twitch.Manager + ctx context.Context + cancelFn context.CancelFunc + cancelSub database.CancelFunc } -func NewManager(db *database.LocalDBClient, twitchClient *twitch.Client, logger *zap.Logger) (*Manager, error) { +func NewManager(db *database.LocalDBClient, twitchManager *twitch.Manager, logger *zap.Logger) (*Manager, error) { ctx, cancelFn := context.WithCancel(context.Background()) loyalty := &Manager{ Config: containers.NewRWSync(Config{Enabled: false}), @@ -51,19 +52,19 @@ func NewManager(db *database.LocalDBClient, twitchClient *twitch.Client, logger Goals: containers.NewSync(GoalStorage{}), Queue: containers.NewSync(RedeemQueueStorage{}), - logger: logger, - db: db, - points: containers.NewSyncMap[string, PointsEntry](), - cooldowns: make(map[string]time.Time), - banlist: make(map[string]bool), - activeUsers: containers.NewSyncMap[string, bool](), - twitchClient: twitchClient, - ctx: ctx, - cancelFn: cancelFn, + logger: logger, + db: db, + points: containers.NewSyncMap[string, PointsEntry](), + cooldowns: make(map[string]time.Time), + banlist: make(map[string]bool), + activeUsers: containers.NewSyncMap[string, bool](), + twitchManager: twitchManager, + ctx: ctx, + cancelFn: cancelFn, } // Get data from DB var config Config - if err := db.GetJSON(ConfigKey, config); err == nil { + if err := db.GetJSON(ConfigKey, &config); err == nil { loyalty.Config.Set(config) } else { if !errors.Is(err, database.ErrEmptyKey) { @@ -118,7 +119,7 @@ func NewManager(db *database.LocalDBClient, twitchClient *twitch.Client, logger } // SubscribePrefix for changes - err = db.SubscribePrefix(loyalty.update, "loyalty/") + err, loyalty.cancelSub = db.SubscribePrefix(loyalty.update, "loyalty/") if err != nil { logger.Error("could not setup loyalty reload subscription", zap.Error(err)) } @@ -132,6 +133,11 @@ func NewManager(db *database.LocalDBClient, twitchClient *twitch.Client, logger } func (m *Manager) Close() error { + // Stop subscription + if m.cancelSub != nil { + m.cancelSub() + } + // Send cancellation m.cancelFn() diff --git a/loyalty/twitch-bot.go b/loyalty/twitch-bot.go index 48dfaf3..4a3ed3a 100644 --- a/loyalty/twitch-bot.go +++ b/loyalty/twitch-bot.go @@ -16,7 +16,7 @@ import ( ) func (m *Manager) SetupTwitch() { - bot := m.twitchClient.Bot + bot := m.twitchManager.Client().Bot if bot == nil { return } @@ -57,7 +57,7 @@ func (m *Manager) SetupTwitch() { // Setup handler for adding points over time go func() { config := m.Config.Get() - if config.Enabled && bot != nil { + if config.Enabled { for { if config.Points.Interval > 0 { // Wait for next poll @@ -68,11 +68,18 @@ func (m *Manager) SetupTwitch() { } // If stream is confirmed offline, don't give points away! - isOnline := m.twitchClient.IsLive() + isOnline := m.twitchManager.Client().IsLive() if !isOnline { continue } + // Check that bot is online and working + bot := m.twitchManager.Client().Bot + if bot == nil { + m.logger.Warn("bot is offline or not configured, could not assign points") + continue + } + m.logger.Debug("awarding points") // Get user list @@ -116,7 +123,7 @@ func (m *Manager) SetupTwitch() { } func (m *Manager) StopTwitch() { - bot := m.twitchClient.Bot + bot := m.twitchManager.Client().Bot if bot != nil { bot.RemoveCommand("!redeem") bot.RemoveCommand("!balance") diff --git a/main.go b/main.go index 826f331..7a30cc4 100644 --- a/main.go +++ b/main.go @@ -1,6 +1,7 @@ package main import ( + "context" "embed" "fmt" "log" @@ -8,6 +9,8 @@ import ( "os" "time" + "github.com/strimertul/strimertul/utils" + "github.com/apenwarr/fixconsole" "go.uber.org/zap/zapcore" @@ -86,6 +89,7 @@ func main() { level = zapcore.InfoLevel } initLogger(level) + ctx.Context = context.WithValue(ctx.Context, utils.ContextLogger, logger) return nil }, After: func(ctx *cli.Context) error { diff --git a/twitch/bot.alerts.go b/twitch/bot.alerts.go index 4d928be..30266e4 100644 --- a/twitch/bot.alerts.go +++ b/twitch/bot.alerts.go @@ -7,6 +7,8 @@ import ( "text/template" "time" + "github.com/strimertul/strimertul/database" + "go.uber.org/zap" "github.com/Masterminds/sprig/v3" @@ -91,6 +93,9 @@ type BotAlertsModule struct { bot *Bot mu sync.Mutex templates templateCache + + cancelAlertSub database.CancelFunc + cancelTwitchEventSub database.CancelFunc } func SetupAlerts(bot *Bot) *BotAlertsModule { @@ -114,7 +119,7 @@ func SetupAlerts(bot *Bot) *BotAlertsModule { mod.compileTemplates() - err = bot.api.db.SubscribeKey(BotAlertsKey, func(value string) { + err, mod.cancelAlertSub = bot.api.db.SubscribeKey(BotAlertsKey, func(value string) { err := json.UnmarshalFromString(value, &mod.Config) if err != nil { bot.logger.Debug("error reloading timer config", zap.Error(err)) @@ -238,7 +243,7 @@ func SetupAlerts(bot *Bot) *BotAlertsModule { } } - err = bot.api.db.SubscribeKey(EventSubEventKey, func(value string) { + err, mod.cancelTwitchEventSub = bot.api.db.SubscribeKey(EventSubEventKey, func(value string) { var ev eventSubNotification err := json.UnmarshalFromString(value, &ev) if err != nil { @@ -489,6 +494,15 @@ func (m *BotAlertsModule) addTemplate(templateList map[int]*template.Template, i templateList[id] = tpl } +func (m *BotAlertsModule) Close() { + if m.cancelAlertSub != nil { + m.cancelAlertSub() + } + if m.cancelTwitchEventSub != nil { + m.cancelTwitchEventSub() + } +} + // writeTemplate renders the template and sends the message to the channel func writeTemplate(bot *Bot, tpl *template.Template, data interface{}) { var buf bytes.Buffer diff --git a/twitch/bot.go b/twitch/bot.go index e106a18..3612b53 100644 --- a/twitch/bot.go +++ b/twitch/bot.go @@ -2,10 +2,12 @@ package twitch import ( "strings" - "sync" "text/template" "time" + "git.sr.ht/~hamcha/containers" + "github.com/strimertul/strimertul/database" + "github.com/strimertul/strimertul/utils" "go.uber.org/zap" @@ -21,18 +23,19 @@ type Bot struct { api *Client username string logger *zap.Logger - lastMessage time.Time - chatHistory []irc.PrivateMessage + lastMessage *containers.RWSync[time.Time] + chatHistory *containers.Sync[[]irc.PrivateMessage] - commands map[string]BotCommand - customCommands map[string]BotCustomCommand - customTemplates map[string]*template.Template + commands *containers.SyncMap[string, BotCommand] + customCommands *containers.SyncMap[string, BotCustomCommand] + customTemplates *containers.SyncMap[string, *template.Template] customFunctions template.FuncMap OnConnect *utils.PubSub[BotConnectHandler] OnMessage *utils.PubSub[BotMessageHandler] - mu sync.Mutex + cancelUpdateSub database.CancelFunc + cancelWriteRPCSub database.CancelFunc // Module specific vars Timers *BotTimerModule @@ -49,7 +52,14 @@ type BotMessageHandler interface { HandleBotMessage(message irc.PrivateMessage) } -func NewBot(api *Client, config BotConfig) *Bot { +func (b *Bot) Migrate(old *Bot) { + utils.MergeSyncMap(b.commands, old.commands) + // Get registered commands and handlers from old bot + b.OnConnect.Copy(old.OnConnect) + b.OnMessage.Copy(old.OnMessage) +} + +func newBot(api *Client, config BotConfig) *Bot { // Create client client := irc.NewClient(config.Username, config.Token) @@ -60,11 +70,10 @@ func NewBot(api *Client, config BotConfig) *Bot { username: strings.ToLower(config.Username), // Normalize username logger: api.logger, api: api, - lastMessage: time.Now(), - mu: sync.Mutex{}, - commands: make(map[string]BotCommand), - customCommands: make(map[string]BotCustomCommand), - customTemplates: make(map[string]*template.Template), + lastMessage: containers.NewRWSync(time.Now()), + commands: containers.NewSyncMap[string, BotCommand](), + customCommands: containers.NewSyncMap[string, BotCustomCommand](), + customTemplates: containers.NewSyncMap[string, *template.Template](), OnConnect: utils.NewPubSub[BotConnectHandler](), OnMessage: utils.NewPubSub[BotMessageHandler](), @@ -86,18 +95,17 @@ func NewBot(api *Client, config BotConfig) *Bot { } // Ignore messages for a while or twitch will get mad! - if message.Time.Before(bot.lastMessage.Add(time.Second * 2)) { + if message.Time.Before(bot.lastMessage.Get().Add(time.Second * 2)) { bot.logger.Debug("message received too soon, ignoring") return } - bot.mu.Lock() lowercaseMessage := strings.ToLower(message.Message) // Check if it's a command if strings.HasPrefix(message.Message, "!") { // Run through supported commands - for cmd, data := range bot.commands { + for cmd, data := range bot.commands.Copy() { if !data.Enabled { continue } @@ -109,12 +117,12 @@ func NewBot(api *Client, config BotConfig) *Bot { continue } go data.Handler(bot, message) - bot.lastMessage = time.Now() + bot.lastMessage.Set(time.Now()) } } // Run through custom commands - for cmd, data := range bot.customCommands { + for cmd, data := range bot.customCommands.Get() { if !data.Enabled { continue } @@ -127,19 +135,19 @@ func NewBot(api *Client, config BotConfig) *Bot { continue } go cmdCustom(bot, cmd, data, message) - bot.lastMessage = time.Now() + bot.lastMessage.Set(time.Now()) } - bot.mu.Unlock() err := bot.api.db.PutJSON(ChatEventKey, message) if err != nil { bot.logger.Warn("could not save chat message to key", zap.String("key", ChatEventKey), zap.Error(err)) } if bot.Config.ChatHistory > 0 { - if len(bot.chatHistory) >= bot.Config.ChatHistory { - bot.chatHistory = bot.chatHistory[len(bot.chatHistory)-bot.Config.ChatHistory+1:] + history := bot.chatHistory.Get() + if len(history) >= bot.Config.ChatHistory { + history = history[len(history)-bot.Config.ChatHistory+1:] } - bot.chatHistory = append(bot.chatHistory, message) + bot.chatHistory.Set(append(history, message)) err = bot.api.db.PutJSON(ChatHistoryKey, bot.chatHistory) if err != nil { bot.logger.Warn("could not save message to chat history", zap.Error(err)) @@ -175,20 +183,22 @@ func NewBot(api *Client, config BotConfig) *Bot { bot.Alerts = SetupAlerts(bot) // Load custom commands - err := api.db.GetJSON(CustomCommandsKey, &bot.customCommands) + var customCommands map[string]BotCustomCommand + err := api.db.GetJSON(CustomCommandsKey, &customCommands) if err != nil { bot.logger.Error("failed to load custom commands", zap.Error(err)) } + bot.customCommands.Set(customCommands) err = bot.updateTemplates() if err != nil { bot.logger.Error("failed to parse custom commands", zap.Error(err)) } - err = api.db.SubscribeKey(CustomCommandsKey, bot.updateCommands) + err, bot.cancelUpdateSub = api.db.SubscribeKey(CustomCommandsKey, bot.updateCommands) if err != nil { bot.logger.Error("could not set-up bot command reload subscription", zap.Error(err)) } - err = api.db.SubscribeKey(WriteMessageRPC, bot.handleWriteMessageRPC) + err, bot.cancelWriteRPCSub = api.db.SubscribeKey(WriteMessageRPC, bot.handleWriteMessageRPC) if err != nil { bot.logger.Error("could not set-up bot command reload subscription", zap.Error(err)) } @@ -196,12 +206,24 @@ func NewBot(api *Client, config BotConfig) *Bot { return bot } +func (b *Bot) Close() error { + if b.cancelUpdateSub != nil { + b.cancelUpdateSub() + } + if b.cancelWriteRPCSub != nil { + b.cancelWriteRPCSub() + } + if b.Timers != nil { + b.Timers.Close() + } + if b.Alerts != nil { + b.Alerts.Close() + } + return b.Client.Disconnect() +} + func (b *Bot) updateCommands(value string) { - err := func() error { - b.mu.Lock() - defer b.mu.Unlock() - return json.UnmarshalFromString(value, &b.customCommands) - }() + err := utils.LoadJSONToWrapped[map[string]BotCustomCommand](value, b.customCommands) if err != nil { b.logger.Error("failed to decode new custom commands", zap.Error(err)) return @@ -218,18 +240,21 @@ func (b *Bot) handleWriteMessageRPC(value string) { } func (b *Bot) updateTemplates() error { - for cmd, tmpl := range b.customCommands { - var err error - b.customTemplates[cmd], err = template.New("").Funcs(sprig.TxtFuncMap()).Funcs(b.customFunctions).Parse(tmpl.Response) + for cmd, tmpl := range b.customCommands.Copy() { + tpl, err := template.New("").Funcs(sprig.TxtFuncMap()).Funcs(b.customFunctions).Parse(tmpl.Response) if err != nil { return err } + b.customTemplates.SetKey(cmd, tpl) } return nil } -func (b *Bot) Connect() error { - return b.Client.Connect() +func (b *Bot) Connect() { + err := b.Client.Connect() + if err != nil { + b.logger.Error("bot connection ended", zap.Error(err)) + } } func (b *Bot) WriteMessage(message string) { @@ -237,12 +262,11 @@ func (b *Bot) WriteMessage(message string) { } func (b *Bot) RegisterCommand(trigger string, command BotCommand) { - // TODO make it goroutine safe? - b.commands[trigger] = command + b.commands.SetKey(trigger, command) } func (b *Bot) RemoveCommand(trigger string) { - delete(b.commands, trigger) + b.commands.DeleteKey(trigger) } func getUserAccessLevel(user irc.User) AccessLevelType { diff --git a/twitch/bot.timer.go b/twitch/bot.timer.go index b82c607..ff36cf5 100644 --- a/twitch/bot.timer.go +++ b/twitch/bot.timer.go @@ -5,6 +5,8 @@ import ( "sync" "time" + "github.com/strimertul/strimertul/database" + "go.uber.org/zap" irc "github.com/gempir/go-twitch-irc/v3" @@ -34,6 +36,8 @@ type BotTimerModule struct { messages [AverageMessageWindow]int mu sync.Mutex startTime time.Time + + cancelTimerSub database.CancelFunc } func SetupTimers(bot *Bot) *BotTimerModule { @@ -58,7 +62,7 @@ func SetupTimers(bot *Bot) *BotTimerModule { } } - err = bot.api.db.SubscribeKey(BotTimersKey, func(value string) { + err, mod.cancelTimerSub = bot.api.db.SubscribeKey(BotTimersKey, func(value string) { err := json.UnmarshalFromString(value, &mod.Config) if err != nil { bot.logger.Debug("error reloading timer config", zap.Error(err)) @@ -144,6 +148,12 @@ func (m *BotTimerModule) runTimers() { } } +func (m *BotTimerModule) Close() { + if m.cancelTimerSub != nil { + m.cancelTimerSub() + } +} + func (m *BotTimerModule) currentChatActivity() int { total := 0 for _, v := range m.messages { diff --git a/twitch/client.auth.go b/twitch/client.auth.go index 2bfcee6..d5b8dfd 100644 --- a/twitch/client.auth.go +++ b/twitch/client.auth.go @@ -72,7 +72,7 @@ func (c *Client) GetLoggedUser() (helix.User, error) { return users.Data.Users[0], nil } -func (c *Client) AuthorizeCallback(w http.ResponseWriter, req *http.Request) { +func (c *Client) ServeHTTP(w http.ResponseWriter, req *http.Request) { // Get code from params code := req.URL.Query().Get("code") if code == "" { @@ -110,32 +110,6 @@ type RefreshResponse struct { Scope []string `json:"scope"` } -func (c *Client) refreshAccessToken(refreshToken string) (r RefreshResponse, err error) { - // Exchange code for access/refresh tokens - query := url.Values{ - "client_id": {c.Config.APIClientID}, - "client_secret": {c.Config.APIClientSecret}, - "grant_type": {"refresh_token"}, - "refresh_token": {refreshToken}, - } - authRequest, err := http.NewRequest("POST", "https://id.twitch.tv/oauth2/token?"+query.Encode(), nil) - if err != nil { - return RefreshResponse{}, err - } - resp, err := http.DefaultClient.Do(authRequest) - if err != nil { - return RefreshResponse{}, err - } - defer resp.Body.Close() - var refreshResp RefreshResponse - err = jsoniter.ConfigFastest.NewDecoder(resp.Body).Decode(&refreshResp) - return refreshResp, err -} - -func (c *Client) getRedirectURI() (string, error) { - var severConfig struct { - Bind string `json:"bind"` - } - err := c.db.GetJSON("http/config", &severConfig) - return fmt.Sprintf("http://%s/twitch/callback", severConfig.Bind), err +func getRedirectURI(baseurl string) string { + return fmt.Sprintf("http://%s/twitch/callback", baseurl) } diff --git a/twitch/client.eventsub.go b/twitch/client.eventsub.go index ad4b44e..b1bf660 100644 --- a/twitch/client.eventsub.go +++ b/twitch/client.eventsub.go @@ -15,21 +15,41 @@ import ( const websocketEndpoint = "wss://eventsub-beta.wss.twitch.tv/ws" -func (c *Client) connectWebsocket() error { +func (c *Client) connectWebsocket() { connection, _, err := websocket.DefaultDialer.Dial(websocketEndpoint, nil) if err != nil { c.logger.Error("could not connect to eventsub ws", zap.Error(err)) - return fmt.Errorf("error connecting to websocket server: %w", err) + return } + defer connection.Close() + + received := make(chan []byte) + wsErr := make(chan error) + go func(recv chan<- []byte) { + for { + messageType, messageData, err := connection.ReadMessage() + if err != nil { + c.logger.Warn("eventsub ws read error", zap.Error(err)) + wsErr <- err + return + } + if messageType != websocket.TextMessage { + continue + } + + recv <- messageData + } + }(received) for { - messageType, messageData, err := connection.ReadMessage() - if err != nil { - c.logger.Warn("eventsub ws read error", zap.Error(err)) - break - } - if messageType != websocket.TextMessage { - continue + // Wait for next message or closing/error + var messageData []byte + select { + case <-c.ctx.Done(): + return + case <-wsErr: + return + case messageData = <-received: } var wsMessage EventSubWebsocketMessage @@ -76,8 +96,6 @@ func (c *Client) connectWebsocket() error { // TODO idk what to do here } } - - return connection.Close() } func (c *Client) processEvent(message EventSubWebsocketMessage) { diff --git a/twitch/client.go b/twitch/client.go index b15f7ae..b39aa84 100644 --- a/twitch/client.go +++ b/twitch/client.go @@ -1,6 +1,7 @@ package twitch import ( + "context" "errors" "fmt" "time" @@ -16,138 +17,203 @@ import ( var json = jsoniter.ConfigFastest +type Manager struct { + client *Client + cancelSubs func() +} + +func NewManager(db *database.LocalDBClient, server *http.Server, logger *zap.Logger) (*Manager, error) { + // Get Twitch Config + var config Config + if err := db.GetJSON(ConfigKey, &config); err != nil { + if !errors.Is(err, database.ErrEmptyKey) { + return nil, fmt.Errorf("failed to get twitch config: %w", err) + } + config.Enabled = false + } + + // Get Twitch bot Config + var botConfig BotConfig + if err := db.GetJSON(BotConfigKey, &botConfig); err != nil { + if !errors.Is(err, database.ErrEmptyKey) { + return nil, fmt.Errorf("failed to get bot config: %w", err) + } + config.EnableBot = false + } + + // Create new client + client, err := newClient(config, db, server, logger) + if err != nil { + return nil, fmt.Errorf("failed to create twitch client: %w", err) + } + + if config.EnableBot { + client.Bot = newBot(client, botConfig) + go client.Bot.Connect() + } + + manager := &Manager{ + client: client, + } + + // Listen for client config changes + err, cancelConfigSub := db.SubscribeKey(ConfigKey, func(value string) { + var newConfig Config + if err := json.UnmarshalFromString(value, &newConfig); err != nil { + logger.Error("failed to unmarshal config", zap.Error(err)) + return + } + + var updatedClient *Client + updatedClient, err = newClient(newConfig, db, server, logger) + if err != nil { + logger.Error("could not create twitch client with new config, keeping old", zap.Error(err)) + return + } + + err = manager.client.Close() + if err != nil { + logger.Warn("twitch client could not close cleanly", zap.Error(err)) + } + + // New client works, replace old + updatedClient.Merge(manager.client) + manager.client = updatedClient + + logger.Info("reloaded/updated Twitch client") + }) + if err != nil { + logger.Error("could not setup twitch config reload subscription", zap.Error(err)) + } + + // Listen for bot config changes + err, cancelBotSub := db.SubscribeKey(BotConfigKey, func(value string) { + var newBotConfig BotConfig + if err := json.UnmarshalFromString(value, &newBotConfig); err != nil { + logger.Error("failed to unmarshal Config", zap.Error(err)) + return + } + + oldBot := manager.client.Bot + err = oldBot.Close() + if err != nil { + client.logger.Warn("failed to disconnect old bot from Twitch IRC", zap.Error(err)) + } + + bot := newBot(manager.client, newBotConfig) + if client.Config.Get().EnableBot { + go bot.Connect() + } + + client.Bot = bot + client.logger.Info("reloaded/restarted Twitch bot") + }) + if err != nil { + client.logger.Error("could not setup twitch bot config reload subscription", zap.Error(err)) + } + + manager.cancelSubs = func() { + if cancelConfigSub != nil { + cancelConfigSub() + } + if cancelBotSub != nil { + cancelBotSub() + } + } + + return manager, nil +} + +func (m *Manager) Client() *Client { + return m.client +} + +func (m *Manager) Close() error { + m.cancelSubs() + + if err := m.client.Close(); err != nil { + return err + } + + return nil +} + type Client struct { - Config Config + Config *containers.RWSync[Config] Bot *Bot db *database.LocalDBClient API *helix.Client logger *zap.Logger eventCache *lru.Cache + server *http.Server + ctx context.Context + cancel context.CancelFunc restart chan bool streamOnline *containers.RWSync[bool] savedSubscriptions map[string]bool } -func NewClient(db *database.LocalDBClient, server *http.Server, logger *zap.Logger) (*Client, error) { +func (c *Client) Merge(old *Client) { + // Copy bot instance and some params + c.streamOnline.Set(old.streamOnline.Get()) + c.Bot = old.Bot +} + +func newClient(config Config, db *database.LocalDBClient, server *http.Server, logger *zap.Logger) (*Client, error) { eventCache, err := lru.New(128) if err != nil { return nil, fmt.Errorf("could not create LRU cache for events: %w", err) } - // Get Twitch Config - var config Config - err = db.GetJSON(ConfigKey, &config) - if err != nil { - if !errors.Is(err, database.ErrEmptyKey) { - return nil, fmt.Errorf("failed to get twitch Config: %w", err) - } - config.Enabled = false - } - // Create Twitch client + ctx, cancel := context.WithCancel(context.Background()) client := &Client{ - Config: config, + Config: containers.NewRWSync(config), db: db, logger: logger.With(zap.String("service", "twitch")), restart: make(chan bool, 128), streamOnline: containers.NewRWSync(false), eventCache: eventCache, savedSubscriptions: make(map[string]bool), + ctx: ctx, + cancel: cancel, + server: server, } - // Listen for Config changes - err = db.SubscribeKey(ConfigKey, func(value string) { - err := json.UnmarshalFromString(value, &config) - if err != nil { - client.logger.Error("failed to unmarshal Config", zap.Error(err)) - return - } - api, err := client.getHelixAPI(config) - if err != nil { - client.logger.Warn("failed to create new twitch client, keeping old credentials", zap.Error(err)) - return - } - client.API = api - client.Config = config - - client.logger.Info("reloaded/updated Twitch API") - }) + baseurl, err := client.baseURL() if err != nil { - client.logger.Error("could not setup twitch Config reload subscription", zap.Error(err)) - } - - err = db.SubscribeKey(BotConfigKey, func(value string) { - var twitchBotConfig BotConfig - err := json.UnmarshalFromString(value, &twitchBotConfig) - if err != nil { - client.logger.Error("failed to unmarshal Config", zap.Error(err)) - return - } - err = client.Bot.Client.Disconnect() - if err != nil { - client.logger.Warn("failed to disconnect from Twitch IRC", zap.Error(err)) - } - if client.Config.EnableBot { - if err := client.startBot(); err != nil { - if !errors.Is(err, database.ErrEmptyKey) { - client.logger.Error("failed to re-create bot", zap.Error(err)) - } - } - } - client.restart <- true - client.logger.Info("reloaded/restarted Twitch bot") - }) - if err != nil { - client.logger.Error("could not setup twitch bot Config reload subscription", zap.Error(err)) + return nil, err } if config.Enabled { - client.API, err = client.getHelixAPI(config) + api, err := getHelixAPI(config, baseurl) if err != nil { - client.logger.Error("failed to create twitch client", zap.Error(err)) - } else { - server.SetRoute("/twitch/callback", client.AuthorizeCallback) - - go client.runStatusPoll() - go client.connectWebsocket() + return nil, fmt.Errorf("failed to create twitch client: %w", err) } + + client.API = api + server.RegisterRoute(CallbackRoute, client) + + go client.runStatusPoll() + go client.connectWebsocket() } - if client.Config.EnableBot { - if err := client.startBot(); err != nil { - if !errors.Is(err, database.ErrEmptyKey) { - return nil, err - } - } - } - - go func() { - for { - if client.Config.EnableBot && client.Bot != nil { - err := client.RunBot() - if err != nil { - client.logger.Error("failed to connect to Twitch IRC", zap.Error(err)) - // Wait for Config change before retrying - <-client.restart - } - } else { - <-client.restart - } - } - }() - return client, nil } func (c *Client) runStatusPoll() { c.logger.Info("status poll started") for { - // Wait for next poll - time.Sleep(60 * time.Second) + // Wait for next poll (or cancellation) + select { + case <-c.ctx.Done(): + return + case <-time.After(60 * time.Second): + } // Make sure we're configured and connected properly first - if !c.Config.Enabled || c.Bot == nil || c.Bot.Config.Channel == "" { + if !c.Config.Get().Enabled || c.Bot == nil || c.Bot.Config.Channel == "" { continue } @@ -170,28 +236,8 @@ func (c *Client) runStatusPoll() { } } -func (c *Client) startBot() error { - // Get Twitch bot Config - var twitchBotConfig BotConfig - err := c.db.GetJSON(BotConfigKey, &twitchBotConfig) - if err != nil { - if !errors.Is(err, database.ErrEmptyKey) { - return fmt.Errorf("failed to get bot Config: %w", err) - } - c.Config.EnableBot = false - } - - // Create and run IRC bot - c.Bot = NewBot(c, twitchBotConfig) - - return nil -} - -func (c *Client) getHelixAPI(config Config) (*helix.Client, error) { - redirectURI, err := c.getRedirectURI() - if err != nil { - return nil, err - } +func getHelixAPI(config Config, baseurl string) (*helix.Client, error) { + redirectURI := getRedirectURI(baseurl) // Create Twitch client api, err := helix.NewClient(&helix.Options{ @@ -214,17 +260,12 @@ func (c *Client) getHelixAPI(config Config) (*helix.Client, error) { return api, nil } -func (c *Client) RunBot() error { - cherr := make(chan error) - go func() { - cherr <- c.Bot.Connect() - }() - select { - case <-c.restart: - return nil - case err := <-cherr: - return err +func (c *Client) baseURL() (string, error) { + var severConfig struct { + Bind string `json:"bind"` } + err := c.db.GetJSON("http/config", &severConfig) + return severConfig.Bind, err } func (c *Client) IsLive() bool { @@ -232,5 +273,12 @@ func (c *Client) IsLive() bool { } func (c *Client) Close() error { - return c.Bot.Client.Disconnect() + c.server.UnregisterRoute(CallbackRoute) + defer c.cancel() + + if err := c.Bot.Close(); err != nil { + return err + } + + return nil } diff --git a/twitch/commands.go b/twitch/commands.go index 2e1fa0c..2fcd5d6 100644 --- a/twitch/commands.go +++ b/twitch/commands.go @@ -52,8 +52,11 @@ func cmdCustom(bot *Bot, cmd string, data BotCustomCommand, message irc.PrivateM // Add future logic (like counters etc.) here, for now it's just fixed messages var buf bytes.Buffer - err := bot.customTemplates[cmd].Execute(&buf, message) - if err != nil { + tpl, ok := bot.customTemplates.GetKey(cmd) + if !ok { + return + } + if err := tpl.Execute(&buf, message); err != nil { bot.logger.Error("Failed to execute custom command template", zap.Error(err)) return } @@ -87,7 +90,7 @@ func (b *Bot) setupFunctions() { counterKey := BotCounterPrefix + name counter := 0 if byt, err := b.api.db.GetKey(counterKey); err == nil { - counter, _ = strconv.Atoi(string(byt)) + counter, _ = strconv.Atoi(byt) } counter += 1 err := b.api.db.PutKey(counterKey, strconv.Itoa(counter)) diff --git a/twitch/data.go b/twitch/data.go index 55a371f..b27f054 100644 --- a/twitch/data.go +++ b/twitch/data.go @@ -1,5 +1,7 @@ package twitch +const CallbackRoute = "/twitch/callback" + const ConfigKey = "twitch/config" type Config struct { diff --git a/utils/context.go b/utils/context.go new file mode 100644 index 0000000..c5eeee8 --- /dev/null +++ b/utils/context.go @@ -0,0 +1,7 @@ +package utils + +type ContextKey string + +const ( + ContextLogger ContextKey = "logger" +) diff --git a/utils/map.go b/utils/map.go new file mode 100644 index 0000000..0a7437c --- /dev/null +++ b/utils/map.go @@ -0,0 +1,13 @@ +package utils + +import "git.sr.ht/~hamcha/containers" + +func MergeMap[T comparable, V any](a, b map[T]V) { + for key, value := range b { + a[key] = value + } +} + +func MergeSyncMap[T comparable, V any](a, b *containers.SyncMap[T, V]) { + b.Set(a.Copy()) +} diff --git a/utils/pubsub.go b/utils/pubsub.go index c3bde17..60d4cb2 100644 --- a/utils/pubsub.go +++ b/utils/pubsub.go @@ -31,3 +31,9 @@ func (p *PubSub[T]) Unsubscribe(handler T) { func (p *PubSub[T]) Subscribers() []T { return p.subscribers.Get() } + +func (p *PubSub[T]) Copy(other *PubSub[T]) { + for _, subscriber := range other.Subscribers() { + p.Subscribe(subscriber) + } +}