From a99f944f335b7e6a60c9c222dffb987aa8b797a3 Mon Sep 17 00:00:00 2001 From: Ash Keel Date: Sun, 4 Dec 2022 00:36:13 +0100 Subject: [PATCH] fix: server config doesn't get half-changed in case of error --- http/server.go | 50 +++++++++++++++++++++++++++++--------------------- 1 file changed, 29 insertions(+), 21 deletions(-) diff --git a/http/server.go b/http/server.go index e73fd2b..6dcaab9 100644 --- a/http/server.go +++ b/http/server.go @@ -8,6 +8,8 @@ import ( "net/http" "net/http/pprof" + "github.com/strimertul/strimertul/utils" + "git.sr.ht/~hamcha/containers" jsoniter "github.com/json-iterator/go" "github.com/strimertul/strimertul/database" @@ -20,14 +22,14 @@ import ( var json = jsoniter.ConfigFastest type Server struct { - Config ServerConfig + Config *containers.RWSync[ServerConfig] db *database.LocalDBClient logger *zap.Logger server *http.Server frontend fs.FS hub *kv.Hub mux *http.ServeMux - requestedRoutes map[string]http.Handler + requestedRoutes *containers.SyncMap[string, http.Handler] cancelConfigSub database.CancelFunc } @@ -36,17 +38,18 @@ func NewServer(db *database.LocalDBClient, logger *zap.Logger) (*Server, error) logger: logger, db: db, server: &http.Server{}, - requestedRoutes: make(map[string]http.Handler), + requestedRoutes: containers.NewSyncMap[string, http.Handler](), + Config: containers.NewRWSync(ServerConfig{}), } - err := db.GetJSON(ServerConfigKey, &server.Config) + err := utils.LoadJSONToWrapped[ServerConfig](ServerConfigKey, server.Config) if err != nil { // Initialize with default config - server.Config = ServerConfig{ + server.Config.Set(ServerConfig{ Bind: "localhost:4337", EnableStaticServer: false, KVPassword: "", - } + }) // Save err = db.PutJSON(ServerConfigKey, server.Config) if err != nil { @@ -59,7 +62,7 @@ func NewServer(db *database.LocalDBClient, logger *zap.Logger) (*Server, error) // Set password server.hub.SetOptions(kv.HubOptions{ - Password: server.Config.KVPassword, + Password: server.Config.Get().KVPassword, }) return server, nil @@ -100,10 +103,11 @@ func (s *Server) makeMux() *http.ServeMux { kv.ServeWs(s.hub, w, r) }) } - if s.Config.EnableStaticServer { - mux.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.Dir(s.Config.Path)))) + config := s.Config.Get() + if config.EnableStaticServer { + mux.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.Dir(config.Path)))) } - for route, handler := range s.requestedRoutes { + for route, handler := range s.requestedRoutes.Copy() { mux.Handle(route, handler) } @@ -111,12 +115,12 @@ func (s *Server) makeMux() *http.ServeMux { } func (s *Server) RegisterRoute(route string, handler http.Handler) { - s.requestedRoutes[route] = handler + s.requestedRoutes.SetKey(route, handler) s.mux = s.makeMux() } func (s *Server) UnregisterRoute(route string) { - delete(s.requestedRoutes, route) + s.requestedRoutes.DeleteKey(route) s.mux = s.makeMux() } @@ -126,22 +130,25 @@ func (s *Server) Listen() error { exit := make(chan error) var err error err, s.cancelConfigSub = s.db.SubscribeKey(ServerConfigKey, func(value string) { - oldBind := s.Config.Bind - oldPassword := s.Config.KVPassword - err := json.Unmarshal([]byte(value), &s.Config) + oldConfig := s.Config.Get() + + var config ServerConfig + err := json.Unmarshal([]byte(value), &config) if err != nil { s.logger.Error("Failed to unmarshal config", zap.Error(err)) return } + + s.Config.Set(config) s.mux = s.makeMux() // Restart hub if password changed - if oldPassword != s.Config.KVPassword { + if oldConfig.KVPassword != config.KVPassword { s.hub.SetOptions(kv.HubOptions{ - Password: s.Config.KVPassword, + Password: config.KVPassword, }) } // Restart server if bind changed - if oldBind != s.Config.Bind { + if oldConfig.Bind != config.Bind { restart.Set(true) err = s.server.Shutdown(context.Background()) if err != nil { @@ -155,13 +162,14 @@ func (s *Server) Listen() error { } go func() { for { - s.logger.Info("Starting HTTP server", zap.String("bind", s.Config.Bind)) + config := s.Config.Get() + s.logger.Info("Starting HTTP server", zap.String("bind", config.Bind)) s.mux = s.makeMux() s.server = &http.Server{ Handler: s, - Addr: s.Config.Bind, + Addr: config.Bind, } - s.logger.Info("HTTP server started", zap.String("bind", s.Config.Bind)) + s.logger.Info("HTTP server started", zap.String("bind", config.Bind)) err := s.server.ListenAndServe() s.logger.Debug("HTTP server died", zap.Error(err)) if err != nil && !errors.Is(err, http.ErrServerClosed) {