mirror of https://git.sr.ht/~ashkeel/strimertul
refactor: rename http to webserver, add testing
This commit is contained in:
parent
4fbac65c1b
commit
dcf8392346
6
app.go
6
app.go
|
@ -24,9 +24,9 @@ import (
|
|||
|
||||
"github.com/strimertul/strimertul/database"
|
||||
"github.com/strimertul/strimertul/docs"
|
||||
"github.com/strimertul/strimertul/http"
|
||||
"github.com/strimertul/strimertul/loyalty"
|
||||
"github.com/strimertul/strimertul/twitch"
|
||||
"github.com/strimertul/strimertul/webserver"
|
||||
)
|
||||
|
||||
// App struct
|
||||
|
@ -41,7 +41,7 @@ type App struct {
|
|||
|
||||
db *database.LocalDBClient
|
||||
twitchManager *twitch.Manager
|
||||
httpServer *http.Server
|
||||
httpServer *webserver.WebServer
|
||||
loyaltyManager *loyalty.Manager
|
||||
}
|
||||
|
||||
|
@ -153,7 +153,7 @@ func (a *App) initializeComponents() error {
|
|||
var err error
|
||||
|
||||
// Create logger and endpoints
|
||||
a.httpServer, err = http.NewServer(a.db, logger)
|
||||
a.httpServer, err = webserver.NewServer(a.db, logger, webserver.DefaultServerFactory)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not initialize http server: %w", err)
|
||||
}
|
||||
|
|
|
@ -1,16 +1,16 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
jsoniter "github.com/json-iterator/go"
|
||||
kv "github.com/strimertul/kilovolt/v10"
|
||||
"go.uber.org/zap/zaptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
jsoniter "github.com/json-iterator/go"
|
||||
kv "github.com/strimertul/kilovolt/v10"
|
||||
)
|
||||
|
||||
func TestLocalDBClientPutKey(t *testing.T) {
|
||||
client, store := createLocalClient(t)
|
||||
defer cleanupClient(client)
|
||||
client, store := CreateInMemoryLocalClient(t)
|
||||
defer CleanupLocalClient(client)
|
||||
|
||||
// Store a key using the local client
|
||||
key := "test"
|
||||
|
@ -31,8 +31,8 @@ func TestLocalDBClientPutKey(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestLocalDBClientPutJSON(t *testing.T) {
|
||||
client, store := createLocalClient(t)
|
||||
defer cleanupClient(client)
|
||||
client, store := CreateInMemoryLocalClient(t)
|
||||
defer CleanupLocalClient(client)
|
||||
|
||||
type test struct {
|
||||
A string
|
||||
|
@ -72,8 +72,8 @@ func TestLocalDBClientPutJSON(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestLocalDBClientPutJSONBulk(t *testing.T) {
|
||||
client, store := createLocalClient(t)
|
||||
defer cleanupClient(client)
|
||||
client, store := CreateInMemoryLocalClient(t)
|
||||
defer CleanupLocalClient(client)
|
||||
|
||||
type test struct {
|
||||
A string
|
||||
|
@ -131,8 +131,8 @@ func TestLocalDBClientPutJSONBulk(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestLocalDBClientGetKey(t *testing.T) {
|
||||
client, store := createLocalClient(t)
|
||||
defer cleanupClient(client)
|
||||
client, store := CreateInMemoryLocalClient(t)
|
||||
defer CleanupLocalClient(client)
|
||||
|
||||
// Store a key directly in the store
|
||||
key := "test"
|
||||
|
@ -154,8 +154,8 @@ func TestLocalDBClientGetKey(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestLocalDBClientGetJSON(t *testing.T) {
|
||||
client, store := createLocalClient(t)
|
||||
defer cleanupClient(client)
|
||||
client, store := CreateInMemoryLocalClient(t)
|
||||
defer CleanupLocalClient(client)
|
||||
|
||||
type test struct {
|
||||
A string
|
||||
|
@ -195,8 +195,8 @@ func TestLocalDBClientGetJSON(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestLocalDBClientGetAll(t *testing.T) {
|
||||
client, store := createLocalClient(t)
|
||||
defer cleanupClient(client)
|
||||
client, store := CreateInMemoryLocalClient(t)
|
||||
defer CleanupLocalClient(client)
|
||||
|
||||
// Store a bunch of keys directly in the store
|
||||
keys := map[string]string{
|
||||
|
@ -231,8 +231,8 @@ func TestLocalDBClientGetAll(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestLocalDBClientRemoveKey(t *testing.T) {
|
||||
client, store := createLocalClient(t)
|
||||
defer cleanupClient(client)
|
||||
client, store := CreateInMemoryLocalClient(t)
|
||||
defer CleanupLocalClient(client)
|
||||
|
||||
// Store a key directly in the store
|
||||
key := "test"
|
||||
|
@ -257,8 +257,8 @@ func TestLocalDBClientRemoveKey(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestLocalDBClientSubscribeKey(t *testing.T) {
|
||||
client, _ := createLocalClient(t)
|
||||
defer cleanupClient(client)
|
||||
client, _ := CreateInMemoryLocalClient(t)
|
||||
defer CleanupLocalClient(client)
|
||||
|
||||
// Subscribe to a key using the local client
|
||||
key := "test"
|
||||
|
@ -288,8 +288,8 @@ func TestLocalDBClientSubscribeKey(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestLocalDBClientSubscribePrefix(t *testing.T) {
|
||||
client, _ := createLocalClient(t)
|
||||
defer cleanupClient(client)
|
||||
client, _ := CreateInMemoryLocalClient(t)
|
||||
defer CleanupLocalClient(client)
|
||||
|
||||
// Subscribe to a prefix using the local client
|
||||
prefix := "test"
|
||||
|
@ -317,30 +317,3 @@ func TestLocalDBClientSubscribePrefix(t *testing.T) {
|
|||
t.Fatal("expected value to be received")
|
||||
}
|
||||
}
|
||||
|
||||
func createLocalClient(t *testing.T) (*LocalDBClient, kv.Driver) {
|
||||
logger := zaptest.NewLogger(t)
|
||||
|
||||
// Create in-memory store and hub
|
||||
inmemStore := kv.MakeBackend()
|
||||
hub, err := kv.NewHub(inmemStore, kv.HubOptions{}, logger)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
go hub.Run()
|
||||
|
||||
// Create local client
|
||||
client, err := NewLocalClient(hub, logger)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return client, inmemStore
|
||||
}
|
||||
|
||||
func cleanupClient(client *LocalDBClient) {
|
||||
if client.hub != nil {
|
||||
_ = client.Close()
|
||||
client.hub.Close()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,35 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
kv "github.com/strimertul/kilovolt/v10"
|
||||
"go.uber.org/zap/zaptest"
|
||||
)
|
||||
|
||||
func CreateInMemoryLocalClient(t *testing.T) (*LocalDBClient, kv.Driver) {
|
||||
logger := zaptest.NewLogger(t)
|
||||
|
||||
// Create in-memory store and hub
|
||||
inMemoryStore := kv.MakeBackend()
|
||||
hub, err := kv.NewHub(inMemoryStore, kv.HubOptions{}, logger)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
go hub.Run()
|
||||
|
||||
// Create local client
|
||||
client, err := NewLocalClient(hub, logger)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return client, inMemoryStore
|
||||
}
|
||||
|
||||
func CleanupLocalClient(client *LocalDBClient) {
|
||||
if client.hub != nil {
|
||||
_ = client.Close()
|
||||
client.hub.Close()
|
||||
}
|
||||
}
|
|
@ -2,10 +2,10 @@ package docs
|
|||
|
||||
import (
|
||||
"github.com/strimertul/strimertul/docs/interfaces"
|
||||
"github.com/strimertul/strimertul/http"
|
||||
"github.com/strimertul/strimertul/loyalty"
|
||||
"github.com/strimertul/strimertul/twitch"
|
||||
"github.com/strimertul/strimertul/utils"
|
||||
"github.com/strimertul/strimertul/webserver"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -30,5 +30,5 @@ func init() {
|
|||
// Put all keys here
|
||||
addKeys(twitch.Keys)
|
||||
addKeys(loyalty.Keys)
|
||||
addKeys(http.Keys)
|
||||
addKeys(webserver.Keys)
|
||||
}
|
||||
|
|
|
@ -13,7 +13,7 @@ import (
|
|||
"go.uber.org/zap"
|
||||
|
||||
"github.com/strimertul/strimertul/database"
|
||||
"github.com/strimertul/strimertul/http"
|
||||
"github.com/strimertul/strimertul/webserver"
|
||||
)
|
||||
|
||||
var json = jsoniter.ConfigFastest
|
||||
|
@ -23,7 +23,7 @@ type Manager struct {
|
|||
cancelSubs func()
|
||||
}
|
||||
|
||||
func NewManager(db *database.LocalDBClient, server *http.Server, logger *zap.Logger) (*Manager, error) {
|
||||
func NewManager(db *database.LocalDBClient, server *webserver.WebServer, logger *zap.Logger) (*Manager, error) {
|
||||
// Get Twitch config
|
||||
var config Config
|
||||
if err := db.GetJSON(ConfigKey, &config); err != nil {
|
||||
|
@ -150,7 +150,7 @@ type Client struct {
|
|||
User helix.User
|
||||
logger *zap.Logger
|
||||
eventCache *lru.Cache[string, time.Time]
|
||||
server *http.Server
|
||||
server *webserver.WebServer
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
|
@ -173,7 +173,7 @@ func (c *Client) ensureRoute() {
|
|||
}
|
||||
}
|
||||
|
||||
func newClient(config Config, db *database.LocalDBClient, server *http.Server, logger *zap.Logger) (*Client, error) {
|
||||
func newClient(config Config, db *database.LocalDBClient, server *webserver.WebServer, logger *zap.Logger) (*Client, error) {
|
||||
eventCache, err := lru.New[string, time.Time](128)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not create LRU cache for events: %w", err)
|
||||
|
|
|
@ -0,0 +1,28 @@
|
|||
package twitch
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/strimertul/strimertul/webserver"
|
||||
|
||||
"go.uber.org/zap/zaptest"
|
||||
|
||||
"github.com/strimertul/strimertul/database"
|
||||
)
|
||||
|
||||
func TestNewClient(t *testing.T) {
|
||||
logger := zaptest.NewLogger(t)
|
||||
client, _ := database.CreateInMemoryLocalClient(t)
|
||||
defer database.CleanupLocalClient(client)
|
||||
|
||||
server, err := webserver.NewServer(client, logger, webserver.DefaultServerFactory)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
config := Config{}
|
||||
_, err = newClient(config, client, server, logger)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
package http
|
||||
package webserver
|
||||
|
||||
const ServerConfigKey = "http/config"
|
||||
|
|
@ -1,8 +1,9 @@
|
|||
package http
|
||||
package webserver
|
||||
|
||||
import (
|
||||
"github.com/strimertul/strimertul/docs/interfaces"
|
||||
"reflect"
|
||||
|
||||
"github.com/strimertul/strimertul/docs/interfaces"
|
||||
)
|
||||
|
||||
// Documentation stuff, keep updated at all times
|
|
@ -0,0 +1,29 @@
|
|||
package webserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type ServerFactory = func(h http.Handler, addr string) (Server, error)
|
||||
|
||||
func DefaultServerFactory(h http.Handler, addr string) (Server, error) {
|
||||
return &HTTPServer{http.Server{
|
||||
Addr: addr,
|
||||
Handler: h,
|
||||
}}, nil
|
||||
}
|
||||
|
||||
type Server interface {
|
||||
Start() error
|
||||
Close() error
|
||||
Shutdown(ctx context.Context) error
|
||||
}
|
||||
|
||||
type HTTPServer struct {
|
||||
http.Server
|
||||
}
|
||||
|
||||
func (s *HTTPServer) Start() error {
|
||||
return s.ListenAndServe()
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
package http
|
||||
package webserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
@ -21,27 +21,29 @@ import (
|
|||
|
||||
var json = jsoniter.ConfigFastest
|
||||
|
||||
type Server struct {
|
||||
type WebServer struct {
|
||||
Config *sync.RWSync[ServerConfig]
|
||||
db *database.LocalDBClient
|
||||
logger *zap.Logger
|
||||
server *http.Server
|
||||
server Server
|
||||
frontend fs.FS
|
||||
hub *kv.Hub
|
||||
mux *http.ServeMux
|
||||
requestedRoutes *sync.Map[string, http.Handler]
|
||||
restart *sync.RWSync[bool]
|
||||
cancelConfigSub database.CancelFunc
|
||||
factory ServerFactory
|
||||
}
|
||||
|
||||
func NewServer(db *database.LocalDBClient, logger *zap.Logger) (*Server, error) {
|
||||
server := &Server{
|
||||
func NewServer(db *database.LocalDBClient, logger *zap.Logger, serverFactory ServerFactory) (*WebServer, error) {
|
||||
server := &WebServer{
|
||||
logger: logger,
|
||||
db: db,
|
||||
server: &http.Server{},
|
||||
server: nil,
|
||||
requestedRoutes: sync.NewMap[string, http.Handler](),
|
||||
restart: sync.NewRWSync(false),
|
||||
Config: sync.NewRWSync(ServerConfig{}),
|
||||
factory: serverFactory,
|
||||
}
|
||||
|
||||
var config ServerConfig
|
||||
|
@ -86,19 +88,26 @@ type StatusData struct {
|
|||
Bind string
|
||||
}
|
||||
|
||||
func (s *Server) Close() error {
|
||||
func (s *WebServer) Close() error {
|
||||
if s.cancelConfigSub != nil {
|
||||
s.cancelConfigSub()
|
||||
}
|
||||
|
||||
return s.server.Close()
|
||||
if s.server != nil {
|
||||
err := s.server.Close()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) SetFrontend(files fs.FS) {
|
||||
func (s *WebServer) SetFrontend(files fs.FS) {
|
||||
s.frontend = files
|
||||
}
|
||||
|
||||
func (s *Server) makeMux() *http.ServeMux {
|
||||
func (s *WebServer) makeMux() *http.ServeMux {
|
||||
mux := http.NewServeMux()
|
||||
|
||||
// Register pprof
|
||||
|
@ -107,6 +116,7 @@ func (s *Server) makeMux() *http.ServeMux {
|
|||
mux.HandleFunc("/debug/pprof/profile", pprof.Profile)
|
||||
mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol)
|
||||
mux.HandleFunc("/debug/pprof/trace", pprof.Trace)
|
||||
mux.HandleFunc("/health", healthFunc)
|
||||
|
||||
if s.frontend != nil {
|
||||
mux.Handle("/ui/", http.StripPrefix("/ui/", FileServerWithDefault(http.FS(s.frontend))))
|
||||
|
@ -127,35 +137,50 @@ func (s *Server) makeMux() *http.ServeMux {
|
|||
return mux
|
||||
}
|
||||
|
||||
func (s *Server) RegisterRoute(route string, handler http.Handler) {
|
||||
func healthFunc(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprint(w, "OK")
|
||||
}
|
||||
|
||||
func (s *WebServer) RegisterRoute(route string, handler http.Handler) {
|
||||
s.requestedRoutes.SetKey(route, handler)
|
||||
s.mux = s.makeMux()
|
||||
}
|
||||
|
||||
func (s *Server) UnregisterRoute(route string) {
|
||||
func (s *WebServer) UnregisterRoute(route string) {
|
||||
s.requestedRoutes.DeleteKey(route)
|
||||
s.mux = s.makeMux()
|
||||
}
|
||||
|
||||
func (s *Server) Listen() error {
|
||||
func (s *WebServer) Listen() error {
|
||||
// Start HTTP server
|
||||
exit := make(chan error)
|
||||
go func() {
|
||||
for {
|
||||
// Read config and make http request mux
|
||||
config := s.Config.Get()
|
||||
s.logger.Info("Starting HTTP server", zap.String("bind", config.Bind))
|
||||
s.mux = s.makeMux()
|
||||
s.server = &http.Server{
|
||||
Handler: s,
|
||||
Addr: config.Bind,
|
||||
|
||||
// Make HTTP server instance
|
||||
var err error
|
||||
s.server, err = s.factory(s, config.Bind)
|
||||
if err != nil {
|
||||
exit <- err
|
||||
return
|
||||
}
|
||||
|
||||
// Start HTTP server
|
||||
s.logger.Info("HTTP server started", zap.String("bind", config.Bind))
|
||||
err := s.server.ListenAndServe()
|
||||
err = s.server.Start()
|
||||
|
||||
// If the server died, we need to see what to do
|
||||
s.logger.Debug("HTTP server died", zap.Error(err))
|
||||
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
exit <- err
|
||||
return
|
||||
}
|
||||
|
||||
// Are we trying to close or restart?
|
||||
s.logger.Debug("HTTP server stopped", zap.Bool("restart", s.restart.Get()))
|
||||
if s.restart.Get() {
|
||||
|
@ -171,7 +196,7 @@ func (s *Server) Listen() error {
|
|||
return <-exit
|
||||
}
|
||||
|
||||
func (s *Server) onConfigUpdate(value string) {
|
||||
func (s *WebServer) onConfigUpdate(value string) {
|
||||
oldConfig := s.Config.Get()
|
||||
|
||||
var config ServerConfig
|
||||
|
@ -200,7 +225,7 @@ func (s *Server) onConfigUpdate(value string) {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
func (s *WebServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// Redirect to /ui/ if root
|
||||
if r.URL.Path == "/" {
|
||||
http.Redirect(w, r, "/ui/", http.StatusFound)
|
|
@ -0,0 +1,98 @@
|
|||
package webserver
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap/zaptest"
|
||||
|
||||
"github.com/strimertul/strimertul/database"
|
||||
)
|
||||
|
||||
func TestNewServer(t *testing.T) {
|
||||
logger := zaptest.NewLogger(t)
|
||||
client, _ := database.CreateInMemoryLocalClient(t)
|
||||
defer database.CleanupLocalClient(client)
|
||||
|
||||
_, err := NewServer(client, logger, DefaultServerFactory)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewServerWithTestFactory(t *testing.T) {
|
||||
logger := zaptest.NewLogger(t)
|
||||
client, _ := database.CreateInMemoryLocalClient(t)
|
||||
defer database.CleanupLocalClient(client)
|
||||
|
||||
testServer := NewTestServer()
|
||||
_, err := NewServer(client, logger, testServer.Factory())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListen(t *testing.T) {
|
||||
logger := zaptest.NewLogger(t)
|
||||
client, _ := database.CreateInMemoryLocalClient(t)
|
||||
defer database.CleanupLocalClient(client)
|
||||
|
||||
// Create a test server
|
||||
testServer := NewTestServer()
|
||||
server, err := NewServer(client, logger, testServer.Factory())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Start the server
|
||||
finished := make(chan struct{})
|
||||
go func() {
|
||||
err := server.Listen()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
finished <- struct{}{}
|
||||
}()
|
||||
|
||||
// Wait for the server to start up so we can get the client
|
||||
testServer.Wait()
|
||||
|
||||
// Make a request to the server
|
||||
httpClient := testServer.Client()
|
||||
resp, err := httpClient.Get(fmt.Sprintf("%s/health", testServer.URL()))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.StatusCode != 200 {
|
||||
t.Fatalf("Expected 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Close the server
|
||||
err = server.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Make sure the related goroutines have terminated
|
||||
select {
|
||||
case <-finished:
|
||||
case <-time.After(time.Second * 2):
|
||||
t.Fatal("Server did not shut down in time")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeneratePassword(t *testing.T) {
|
||||
// Generate a bunch of passwords and make sure they are different and sufficiently long
|
||||
passwords := make(map[string]bool)
|
||||
for i := 0; i < 100; i++ {
|
||||
password := generatePassword()
|
||||
if len(password) < 8 {
|
||||
t.Fatalf("Password '%s' is empty or too short", password)
|
||||
}
|
||||
if passwords[password] {
|
||||
t.Fatal("Duplicate password")
|
||||
}
|
||||
passwords[password] = true
|
||||
}
|
||||
}
|
|
@ -1,6 +1,6 @@
|
|||
// FROM https://gist.github.com/lummie/91cd1c18b2e32fa9f316862221a6fd5c
|
||||
|
||||
package http
|
||||
package webserver
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
@ -13,7 +13,7 @@ func FileServerWithDefault(root http.FileSystem) http.Handler {
|
|||
fs := http.FileServer(root)
|
||||
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
//make sure the url path starts with /
|
||||
// make sure the url path starts with /
|
||||
upath := r.URL.Path
|
||||
if !strings.HasPrefix(upath, "/") {
|
||||
upath = "/" + upath
|
|
@ -0,0 +1,62 @@
|
|||
package webserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
|
||||
"git.sr.ht/~hamcha/containers/sync"
|
||||
)
|
||||
|
||||
type TestServer struct {
|
||||
server *sync.Sync[*httptest.Server]
|
||||
start chan struct{}
|
||||
close chan error
|
||||
}
|
||||
|
||||
func NewTestServer() *TestServer {
|
||||
return &TestServer{
|
||||
server: sync.NewSync[*httptest.Server](nil),
|
||||
close: make(chan error),
|
||||
start: make(chan struct{}, 10),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *TestServer) Start() error {
|
||||
server := t.server.Get()
|
||||
server.Start()
|
||||
t.start <- struct{}{}
|
||||
return <-t.close
|
||||
}
|
||||
|
||||
func (t *TestServer) Close() error {
|
||||
t.server.Get().Close()
|
||||
t.close <- nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *TestServer) Shutdown(_ context.Context) error {
|
||||
return t.Close()
|
||||
}
|
||||
|
||||
func (t *TestServer) Factory() ServerFactory {
|
||||
return func(h http.Handler, addr string) (Server, error) {
|
||||
s := httptest.NewUnstartedServer(h)
|
||||
t.server.Set(s)
|
||||
return t, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (t *TestServer) Wait() {
|
||||
if t.server.Get() == nil {
|
||||
<-t.start
|
||||
}
|
||||
}
|
||||
|
||||
func (t *TestServer) Client() *http.Client {
|
||||
return t.server.Get().Client()
|
||||
}
|
||||
|
||||
func (t *TestServer) URL() string {
|
||||
return t.server.Get().URL
|
||||
}
|
Loading…
Reference in New Issue