diff --git a/server/handler_atomic.go b/server/handler_atomic.go new file mode 100644 index 0000000..326d1bd --- /dev/null +++ b/server/handler_atomic.go @@ -0,0 +1,70 @@ +package server + +import ( + "sync" + "time" +) + +type counter struct { + duration time.Duration + zeroTimer *time.Timer + wg sync.WaitGroup + connections int + mutex sync.Mutex +} + +func newCounter(duration time.Duration) *counter { + zeroTimer := time.NewTimer(duration) + + // when duration is 0, drain the expire event here + // so that user will never get the event. + if duration == 0 { + <-zeroTimer.C + } + + return &counter{ + duration: duration, + zeroTimer: zeroTimer, + } +} + +func (counter *counter) add(n int) int { + counter.mutex.Lock() + defer counter.mutex.Unlock() + + if counter.duration > 0 { + counter.zeroTimer.Stop() + } + counter.wg.Add(n) + counter.connections += n + + return counter.connections +} + +func (counter *counter) done() int { + counter.mutex.Lock() + defer counter.mutex.Unlock() + + counter.connections-- + counter.wg.Done() + if counter.connections == 0 && counter.duration > 0 { + counter.zeroTimer.Reset(counter.duration) + } + + return counter.connections +} + +func (counter *counter) count() int { + counter.mutex.Lock() + defer counter.mutex.Unlock() + + return counter.connections +} + +func (counter *counter) wait() { + counter.wg.Wait() +} + +func (counter *counter) timer() *time.Timer { + return counter.zeroTimer +} diff --git a/server/handlers.go b/server/handlers.go index 2ca51a1..90e6429 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -8,9 +8,7 @@ import ( "log" "net/http" "net/url" - "sync" "sync/atomic" - "time" "github.com/gorilla/websocket" "github.com/pkg/errors" @@ -18,42 +16,32 @@ import ( "github.com/yudai/gotty/webtty" ) -func (server *Server) generateHandleWS(ctx context.Context, cancel context.CancelFunc, connections *int64, wg *sync.WaitGroup) http.HandlerFunc { +func (server *Server) generateHandleWS(ctx context.Context, cancel context.CancelFunc, counter *counter) http.HandlerFunc { once := new(int64) - timer := time.NewTimer(time.Duration(server.options.Timeout) * time.Second) - if server.options.Timeout > 0 { - go func() { - select { - case <-timer.C: - cancel() - case <-ctx.Done(): - } - }() - } + go func() { + select { + case <-counter.timer().C: + cancel() + case <-ctx.Done(): + } + }() return func(w http.ResponseWriter, r *http.Request) { if server.options.Once { success := atomic.CompareAndSwapInt64(once, 0, 1) if !success { + http.Error(w, "Server is shutting down", http.StatusServiceUnavailable) return } } - if server.options.Timeout > 0 { - timer.Stop() - } - wg.Add(1) - num := atomic.AddInt64(connections, 1) + num := counter.add(1) closeReason := "unknown reason" defer func() { - num := atomic.AddInt64(connections, -1) - if num == 0 && server.options.Timeout > 0 { - timer.Reset(time.Duration(server.options.Timeout) * time.Second) - } - + num := counter.done() log.Printf( "Connection closed by %s: %s, connections: %d/%d", closeReason, r.RemoteAddr, num, server.options.MaxConnection, @@ -62,12 +50,10 @@ func (server *Server) generateHandleWS(ctx context.Context, cancel context.Cance if server.options.Once { cancel() } - - wg.Done() }() if int64(server.options.MaxConnection) != 0 { - if num > int64(server.options.MaxConnection) { + if num > server.options.MaxConnection { closeReason = "exceeding max number of connections" return } diff --git a/server/server.go b/server/server.go index 3b2c454..48d752f 100644 --- a/server/server.go +++ b/server/server.go @@ -10,9 +10,8 @@ import ( "net" "net/http" "net/url" - "sync" - "sync/atomic" noesctmpl "text/template" + "time" "github.com/elazarl/go-bindata-assetfs" "github.com/gorilla/websocket" @@ -81,12 +80,9 @@ func (server *Server) Run(ctx context.Context, options ...RunOption) error { opt(opts) } - // wg and connections can be incosistent because they are handled nonatomically - wg := new(sync.WaitGroup) // to wait all connections to be closed - connections := new(int64) // number of active connections - + counter := newCounter(time.Duration(server.options.Timeout) * time.Second) url := server.setupURL() - handlers := server.setupHandlers(cctx, cancel, url, connections, wg) + handlers := server.setupHandlers(cctx, cancel, url, counter) srv, err := server.setupHTTPServer(handlers, url) if err != nil { return errors.Wrapf(err, "failed to setup an HTTP server") @@ -138,11 +134,11 @@ func (server *Server) Run(ctx context.Context, options ...RunOption) error { err = cctx.Err() } - conn := atomic.LoadInt64(connections) + conn := counter.count() if conn > 0 { log.Printf("Waiting for %d connections to be closed", conn) } - wg.Wait() + counter.wait() return err } @@ -162,7 +158,7 @@ func (server *Server) setupURL() *url.URL { return &url.URL{Scheme: scheme, Host: host, Path: path} } -func (server *Server) setupHandlers(ctx context.Context, cancel context.CancelFunc, url *url.URL, connections *int64, wg *sync.WaitGroup) http.Handler { +func (server *Server) setupHandlers(ctx context.Context, cancel context.CancelFunc, url *url.URL, counter *counter) http.Handler { staticFileHandler := http.FileServer( &assetfs.AssetFS{Asset: Asset, AssetDir: AssetDir, Prefix: "static"}, ) @@ -184,7 +180,7 @@ func (server *Server) setupHandlers(ctx context.Context, cancel context.CancelFu wsMux := http.NewServeMux() wsMux.Handle("/", siteHandler) - wsMux.HandleFunc(url.Path+"ws", server.generateHandleWS(ctx, cancel, connections, wg)) + wsMux.HandleFunc(url.Path+"ws", server.generateHandleWS(ctx, cancel, counter)) siteHandler = http.Handler(wsMux) return server.wrapLogger(siteHandler)