From df46f7dfbfd64454c68196e0d19a32d5a7db1c2a Mon Sep 17 00:00:00 2001 From: Ash Keel Date: Sun, 5 Mar 2023 20:11:19 +0100 Subject: [PATCH] fix: force refreshing the auth token when starting --- loyalty/twitch-bot.go | 2 +- twitch/client.auth.go | 6 +++--- twitch/client.eventsub.go | 19 ++++++++++++------- twitch/client.go | 2 +- 4 files changed, 17 insertions(+), 12 deletions(-) diff --git a/loyalty/twitch-bot.go b/loyalty/twitch-bot.go index fac01fe..b17fbc7 100644 --- a/loyalty/twitch-bot.go +++ b/loyalty/twitch-bot.go @@ -84,7 +84,7 @@ func (m *Manager) SetupTwitch() { cursor := "" var users []string for { - userClient, err := client.GetUserClient() + userClient, err := client.GetUserClient(false) if err != nil { m.logger.Error("could not get user api client for list of chatters", zap.Error(err)) return diff --git a/twitch/client.auth.go b/twitch/client.auth.go index 2c03ae5..0537b7b 100644 --- a/twitch/client.auth.go +++ b/twitch/client.auth.go @@ -26,14 +26,14 @@ func (c *Client) GetAuthorizationURL() string { }) } -func (c *Client) GetUserClient() (*helix.Client, error) { +func (c *Client) GetUserClient(forceRefresh bool) (*helix.Client, error) { var authResp AuthResponse err := c.db.GetJSON(AuthKey, &authResp) if err != nil { return nil, err } // Handle token expiration - if time.Now().After(authResp.Time.Add(time.Duration(authResp.ExpiresIn) * time.Second)) { + if forceRefresh || time.Now().After(authResp.Time.Add(time.Duration(authResp.ExpiresIn)*time.Second)) { // Refresh tokens refreshed, err := c.API.RefreshUserAccessToken(authResp.RefreshToken) if err != nil { @@ -63,7 +63,7 @@ func (c *Client) GetLoggedUser() (helix.User, error) { return c.User, nil } - client, err := c.GetUserClient() + client, err := c.GetUserClient(false) if err != nil { return helix.User{}, fmt.Errorf("failed getting API client for user: %w", err) } diff --git a/twitch/client.eventsub.go b/twitch/client.eventsub.go index 09372f6..49ecbb5 100644 --- a/twitch/client.eventsub.go +++ b/twitch/client.eventsub.go @@ -17,22 +17,23 @@ const websocketEndpoint = "wss://eventsub-beta.wss.twitch.tv/ws" func (c *Client) eventSubLoop(userClient *helix.Client) { endpoint := websocketEndpoint var err error + var connection *websocket.Conn for endpoint != "" { - endpoint, err = c.connectWebsocket(endpoint, userClient) + endpoint, connection, err = c.connectWebsocket(endpoint, connection, userClient) if err != nil { c.logger.Error("eventsub ws read error", zap.Error(err)) break } } + utils.Close(connection, c.logger) } -func (c *Client) connectWebsocket(url string, userClient *helix.Client) (string, error) { +func (c *Client) connectWebsocket(url string, oldConnection *websocket.Conn, userClient *helix.Client) (string, *websocket.Conn, error) { connection, _, err := websocket.DefaultDialer.Dial(url, nil) if err != nil { c.logger.Error("could not connect to eventsub ws", zap.Error(err)) - return "", err + return "", nil, err } - defer utils.Close(connection, c.logger) received := make(chan []byte, 10) wsErr := make(chan error, 1) @@ -60,9 +61,9 @@ func (c *Client) connectWebsocket(url string, userClient *helix.Client) (string, var messageData []byte select { case <-c.ctx.Done(): - return "", nil + return "", nil, nil case err = <-wsErr: - return "", err + return "", nil, err case messageData = <-received: } @@ -84,6 +85,10 @@ func (c *Client) connectWebsocket(url string, userClient *helix.Client) (string, break } c.logger.Info("eventsub ws connection established", zap.String("session-id", welcomeData.Session.Id)) + + if oldConnection != nil { + utils.Close(connection, c.logger) + } // Add subscription to websocket session err = c.addSubscriptionsForSession(userClient, welcomeData.Session.Id) if err != nil { @@ -99,7 +104,7 @@ func (c *Client) connectWebsocket(url string, userClient *helix.Client) (string, } c.logger.Info("eventsub ws connection reset requested", zap.String("session-id", reconnectData.Session.Id), zap.String("reconnect-url", reconnectData.Session.ReconnectUrl)) - return reconnectData.Session.ReconnectUrl, nil + return reconnectData.Session.ReconnectUrl, connection, nil case "notification": go c.processEvent(wsMessage) case "revocation": diff --git a/twitch/client.go b/twitch/client.go index 1fffffc..ef4a384 100644 --- a/twitch/client.go +++ b/twitch/client.go @@ -208,7 +208,7 @@ func newClient(config Config, db *database.LocalDBClient, server *http.Server, l client.API = api server.RegisterRoute(CallbackRoute, client) - if userClient, err := client.GetUserClient(); err == nil { + if userClient, err := client.GetUserClient(true); err == nil { users, err := userClient.GetUsers(&helix.UsersParams{}) if err != nil { client.logger.Error("failed looking up user", zap.Error(err))