mirror of
https://github.com/sorenisanerd/gotty.git
synced 2024-11-22 12:24:25 +00:00
Fix possible race condition on timeout
This commit is contained in:
parent
9b8d2d5ed5
commit
2a2a034788
70
server/handler_atomic.go
Normal file
70
server/handler_atomic.go
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type counter struct {
|
||||||
|
duration time.Duration
|
||||||
|
zeroTimer *time.Timer
|
||||||
|
wg sync.WaitGroup
|
||||||
|
connections int
|
||||||
|
mutex sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func newCounter(duration time.Duration) *counter {
|
||||||
|
zeroTimer := time.NewTimer(duration)
|
||||||
|
|
||||||
|
// when duration is 0, drain the expire event here
|
||||||
|
// so that user will never get the event.
|
||||||
|
if duration == 0 {
|
||||||
|
<-zeroTimer.C
|
||||||
|
}
|
||||||
|
|
||||||
|
return &counter{
|
||||||
|
duration: duration,
|
||||||
|
zeroTimer: zeroTimer,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (counter *counter) add(n int) int {
|
||||||
|
counter.mutex.Lock()
|
||||||
|
defer counter.mutex.Unlock()
|
||||||
|
|
||||||
|
if counter.duration > 0 {
|
||||||
|
counter.zeroTimer.Stop()
|
||||||
|
}
|
||||||
|
counter.wg.Add(n)
|
||||||
|
counter.connections += n
|
||||||
|
|
||||||
|
return counter.connections
|
||||||
|
}
|
||||||
|
|
||||||
|
func (counter *counter) done() int {
|
||||||
|
counter.mutex.Lock()
|
||||||
|
defer counter.mutex.Unlock()
|
||||||
|
|
||||||
|
counter.connections--
|
||||||
|
counter.wg.Done()
|
||||||
|
if counter.connections == 0 && counter.duration > 0 {
|
||||||
|
counter.zeroTimer.Reset(counter.duration)
|
||||||
|
}
|
||||||
|
|
||||||
|
return counter.connections
|
||||||
|
}
|
||||||
|
|
||||||
|
func (counter *counter) count() int {
|
||||||
|
counter.mutex.Lock()
|
||||||
|
defer counter.mutex.Unlock()
|
||||||
|
|
||||||
|
return counter.connections
|
||||||
|
}
|
||||||
|
|
||||||
|
func (counter *counter) wait() {
|
||||||
|
counter.wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (counter *counter) timer() *time.Timer {
|
||||||
|
return counter.zeroTimer
|
||||||
|
}
|
@ -8,9 +8,7 @@ import (
|
|||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
@ -18,42 +16,32 @@ import (
|
|||||||
"github.com/yudai/gotty/webtty"
|
"github.com/yudai/gotty/webtty"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (server *Server) generateHandleWS(ctx context.Context, cancel context.CancelFunc, connections *int64, wg *sync.WaitGroup) http.HandlerFunc {
|
func (server *Server) generateHandleWS(ctx context.Context, cancel context.CancelFunc, counter *counter) http.HandlerFunc {
|
||||||
once := new(int64)
|
once := new(int64)
|
||||||
|
|
||||||
timer := time.NewTimer(time.Duration(server.options.Timeout) * time.Second)
|
go func() {
|
||||||
if server.options.Timeout > 0 {
|
select {
|
||||||
go func() {
|
case <-counter.timer().C:
|
||||||
select {
|
cancel()
|
||||||
case <-timer.C:
|
case <-ctx.Done():
|
||||||
cancel()
|
}
|
||||||
case <-ctx.Done():
|
}()
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
if server.options.Once {
|
if server.options.Once {
|
||||||
success := atomic.CompareAndSwapInt64(once, 0, 1)
|
success := atomic.CompareAndSwapInt64(once, 0, 1)
|
||||||
if !success {
|
if !success {
|
||||||
|
|
||||||
http.Error(w, "Server is shutting down", http.StatusServiceUnavailable)
|
http.Error(w, "Server is shutting down", http.StatusServiceUnavailable)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if server.options.Timeout > 0 {
|
num := counter.add(1)
|
||||||
timer.Stop()
|
|
||||||
}
|
|
||||||
wg.Add(1)
|
|
||||||
num := atomic.AddInt64(connections, 1)
|
|
||||||
closeReason := "unknown reason"
|
closeReason := "unknown reason"
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
num := atomic.AddInt64(connections, -1)
|
num := counter.done()
|
||||||
if num == 0 && server.options.Timeout > 0 {
|
|
||||||
timer.Reset(time.Duration(server.options.Timeout) * time.Second)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Printf(
|
log.Printf(
|
||||||
"Connection closed by %s: %s, connections: %d/%d",
|
"Connection closed by %s: %s, connections: %d/%d",
|
||||||
closeReason, r.RemoteAddr, num, server.options.MaxConnection,
|
closeReason, r.RemoteAddr, num, server.options.MaxConnection,
|
||||||
@ -62,12 +50,10 @@ func (server *Server) generateHandleWS(ctx context.Context, cancel context.Cance
|
|||||||
if server.options.Once {
|
if server.options.Once {
|
||||||
cancel()
|
cancel()
|
||||||
}
|
}
|
||||||
|
|
||||||
wg.Done()
|
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if int64(server.options.MaxConnection) != 0 {
|
if int64(server.options.MaxConnection) != 0 {
|
||||||
if num > int64(server.options.MaxConnection) {
|
if num > server.options.MaxConnection {
|
||||||
closeReason = "exceeding max number of connections"
|
closeReason = "exceeding max number of connections"
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -10,9 +10,8 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
|
||||||
noesctmpl "text/template"
|
noesctmpl "text/template"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/elazarl/go-bindata-assetfs"
|
"github.com/elazarl/go-bindata-assetfs"
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
@ -81,12 +80,9 @@ func (server *Server) Run(ctx context.Context, options ...RunOption) error {
|
|||||||
opt(opts)
|
opt(opts)
|
||||||
}
|
}
|
||||||
|
|
||||||
// wg and connections can be incosistent because they are handled nonatomically
|
counter := newCounter(time.Duration(server.options.Timeout) * time.Second)
|
||||||
wg := new(sync.WaitGroup) // to wait all connections to be closed
|
|
||||||
connections := new(int64) // number of active connections
|
|
||||||
|
|
||||||
url := server.setupURL()
|
url := server.setupURL()
|
||||||
handlers := server.setupHandlers(cctx, cancel, url, connections, wg)
|
handlers := server.setupHandlers(cctx, cancel, url, counter)
|
||||||
srv, err := server.setupHTTPServer(handlers, url)
|
srv, err := server.setupHTTPServer(handlers, url)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrapf(err, "failed to setup an HTTP server")
|
return errors.Wrapf(err, "failed to setup an HTTP server")
|
||||||
@ -138,11 +134,11 @@ func (server *Server) Run(ctx context.Context, options ...RunOption) error {
|
|||||||
err = cctx.Err()
|
err = cctx.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
conn := atomic.LoadInt64(connections)
|
conn := counter.count()
|
||||||
if conn > 0 {
|
if conn > 0 {
|
||||||
log.Printf("Waiting for %d connections to be closed", conn)
|
log.Printf("Waiting for %d connections to be closed", conn)
|
||||||
}
|
}
|
||||||
wg.Wait()
|
counter.wait()
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -162,7 +158,7 @@ func (server *Server) setupURL() *url.URL {
|
|||||||
return &url.URL{Scheme: scheme, Host: host, Path: path}
|
return &url.URL{Scheme: scheme, Host: host, Path: path}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (server *Server) setupHandlers(ctx context.Context, cancel context.CancelFunc, url *url.URL, connections *int64, wg *sync.WaitGroup) http.Handler {
|
func (server *Server) setupHandlers(ctx context.Context, cancel context.CancelFunc, url *url.URL, counter *counter) http.Handler {
|
||||||
staticFileHandler := http.FileServer(
|
staticFileHandler := http.FileServer(
|
||||||
&assetfs.AssetFS{Asset: Asset, AssetDir: AssetDir, Prefix: "static"},
|
&assetfs.AssetFS{Asset: Asset, AssetDir: AssetDir, Prefix: "static"},
|
||||||
)
|
)
|
||||||
@ -184,7 +180,7 @@ func (server *Server) setupHandlers(ctx context.Context, cancel context.CancelFu
|
|||||||
|
|
||||||
wsMux := http.NewServeMux()
|
wsMux := http.NewServeMux()
|
||||||
wsMux.Handle("/", siteHandler)
|
wsMux.Handle("/", siteHandler)
|
||||||
wsMux.HandleFunc(url.Path+"ws", server.generateHandleWS(ctx, cancel, connections, wg))
|
wsMux.HandleFunc(url.Path+"ws", server.generateHandleWS(ctx, cancel, counter))
|
||||||
siteHandler = http.Handler(wsMux)
|
siteHandler = http.Handler(wsMux)
|
||||||
|
|
||||||
return server.wrapLogger(siteHandler)
|
return server.wrapLogger(siteHandler)
|
||||||
|
Loading…
Reference in New Issue
Block a user