diff --git a/server/handlers.go b/server/handlers.go index 61b06dc..2ca51a1 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -8,7 +8,9 @@ import ( "log" "net/http" "net/url" + "sync" "sync/atomic" + "time" "github.com/gorilla/websocket" "github.com/pkg/errors" @@ -16,46 +18,63 @@ import ( "github.com/yudai/gotty/webtty" ) -func (server *Server) generateHandleWS(ctx context.Context, cancel context.CancelFunc) http.HandlerFunc { +func (server *Server) generateHandleWS(ctx context.Context, cancel context.CancelFunc, connections *int64, wg *sync.WaitGroup) http.HandlerFunc { + once := new(int64) + + timer := time.NewTimer(time.Duration(server.options.Timeout) * time.Second) + if server.options.Timeout > 0 { + go func() { + select { + case <-timer.C: + cancel() + case <-ctx.Done(): + } + }() + } + return func(w http.ResponseWriter, r *http.Request) { if server.options.Once { - if atomic.LoadInt64(server.once) > 0 { + success := atomic.CompareAndSwapInt64(once, 0, 1) + if !success { http.Error(w, "Server is shutting down", http.StatusServiceUnavailable) return } - atomic.AddInt64(server.once, 1) } - connections := atomic.AddInt64(server.connections, 1) - server.wsWG.Add(1) - server.stopTimer() + + if server.options.Timeout > 0 { + timer.Stop() + } + wg.Add(1) + num := atomic.AddInt64(connections, 1) closeReason := "unknown reason" defer func() { - server.wsWG.Done() - - connections := atomic.AddInt64(server.connections, -1) - if connections == 0 { - server.resetTimer() + num := atomic.AddInt64(connections, -1) + if num == 0 && server.options.Timeout > 0 { + timer.Reset(time.Duration(server.options.Timeout) * time.Second) } log.Printf( "Connection closed by %s: %s, connections: %d/%d", - closeReason, r.RemoteAddr, connections, server.options.MaxConnection, + closeReason, r.RemoteAddr, num, server.options.MaxConnection, ) if server.options.Once { cancel() } + + wg.Done() }() - log.Printf("New client connected: %s", r.RemoteAddr) if int64(server.options.MaxConnection) != 0 { - if connections > int64(server.options.MaxConnection) { + if num > int64(server.options.MaxConnection) { closeReason = "exceeding max number of connections" return } } + log.Printf("New client connected: %s, connections: %d/%d", r.RemoteAddr, num, server.options.MaxConnection) + if r.Method != "GET" { http.Error(w, "Method not allowed", 405) return diff --git a/server/server.go b/server/server.go index a29fe6b..3b2c454 100644 --- a/server/server.go +++ b/server/server.go @@ -13,7 +13,6 @@ import ( "sync" "sync/atomic" noesctmpl "text/template" - "time" "github.com/elazarl/go-bindata-assetfs" "github.com/gorilla/websocket" @@ -29,18 +28,9 @@ type Server struct { factory Factory options *Options - srv *http.Server - - upgrader *websocket.Upgrader - + upgrader *websocket.Upgrader indexTemplate *template.Template titleTemplate *noesctmpl.Template - titleVars map[string]interface{} - timer *time.Timer - wsWG sync.WaitGroup - url *url.URL // use URL() - connections *int64 // Use atomic operations - once *int64 // use atomic operations } // New creates a new instance of Server. @@ -51,7 +41,6 @@ func New(factory Factory, options *Options) (*Server, error) { panic("index not found") // must be in bindata } if options.IndexFile != "" { - log.Printf("Using index file at " + options.IndexFile) path := homedir.Expand(options.IndexFile) indexData, err = ioutil.ReadFile(path) if err != nil { @@ -68,9 +57,6 @@ func New(factory Factory, options *Options) (*Server, error) { return nil, errors.Wrapf(err, "failed to parse window title format `%s`", options.TitleFormat) } - connections := int64(0) - once := int64(0) - return &Server{ factory: factory, options: options, @@ -82,8 +68,6 @@ func New(factory Factory, options *Options) (*Server, error) { }, indexTemplate: indexTemplate, titleTemplate: titleTemplate, - connections: &connections, - once: &once, }, nil } @@ -97,12 +81,18 @@ func (server *Server) Run(ctx context.Context, options ...RunOption) error { opt(opts) } - handlers := server.setupHandlers(cctx, cancel) - srv, err := server.setupHTTPServer(handlers) + // wg and connections can be incosistent because they are handled nonatomically + wg := new(sync.WaitGroup) // to wait all connections to be closed + connections := new(int64) // number of active connections + + url := server.setupURL() + handlers := server.setupHandlers(cctx, cancel, url, connections, wg) + srv, err := server.setupHTTPServer(handlers, url) if err != nil { return errors.Wrapf(err, "failed to setup an HTTP server") } + log.Printf("URL: %s", url.String()) if server.options.PermitWrite { log.Printf("Permitting clients to write input to the PTY.") } @@ -111,19 +101,6 @@ func (server *Server) Run(ctx context.Context, options ...RunOption) error { log.Printf("Once option is provided, accepting only one client") } - server.srv = srv - - if server.options.Timeout > 0 { - server.timer = time.NewTimer(time.Duration(server.options.Timeout) * time.Second) - go func() { - select { - case <-server.timer.C: - cancel() - case <-cctx.Done(): - } - }() - } - listenErr := make(chan error, 1) go func() { if server.options.EnableTLS { @@ -161,21 +138,35 @@ func (server *Server) Run(ctx context.Context, options ...RunOption) error { err = cctx.Err() } - conn := atomic.LoadInt64(server.connections) + conn := atomic.LoadInt64(connections) if conn > 0 { log.Printf("Waiting for %d connections to be closed", conn) } - server.wsWG.Wait() + wg.Wait() return err } -func (server *Server) setupHandlers(ctx context.Context, cancel context.CancelFunc) http.Handler { +func (server *Server) setupURL() *url.URL { + host := net.JoinHostPort(server.options.Address, server.options.Port) + scheme := "http" + path := "/" + + if server.options.EnableRandomUrl { + path = "/" + randomstring.Generate(server.options.RandomUrlLength) + "/" + } + + if server.options.EnableTLS { + scheme = "https" + } + return &url.URL{Scheme: scheme, Host: host, Path: path} +} + +func (server *Server) setupHandlers(ctx context.Context, cancel context.CancelFunc, url *url.URL, connections *int64, wg *sync.WaitGroup) http.Handler { staticFileHandler := http.FileServer( &assetfs.AssetFS{Asset: Asset, AssetDir: AssetDir, Prefix: "static"}, ) - url := server.URL() var siteMux = http.NewServeMux() siteMux.HandleFunc(url.Path, server.handleIndex) siteMux.Handle(url.Path+"js/", http.StripPrefix(url.Path, staticFileHandler)) @@ -193,16 +184,13 @@ func (server *Server) setupHandlers(ctx context.Context, cancel context.CancelFu wsMux := http.NewServeMux() wsMux.Handle("/", siteHandler) - wsMux.HandleFunc(url.Path+"ws", server.generateHandleWS(ctx, cancel)) + wsMux.HandleFunc(url.Path+"ws", server.generateHandleWS(ctx, cancel, connections, wg)) siteHandler = http.Handler(wsMux) return server.wrapLogger(siteHandler) } -func (server *Server) setupHTTPServer(handler http.Handler) (*http.Server, error) { - url := server.URL() - log.Printf("URL: %s", url.String()) - +func (server *Server) setupHTTPServer(handler http.Handler, url *url.URL) (*http.Server, error) { srv := &http.Server{ Addr: url.Host, Handler: handler, @@ -219,22 +207,6 @@ func (server *Server) setupHTTPServer(handler http.Handler) (*http.Server, error return srv, nil } -func (server *Server) URL() *url.URL { - if server.url == nil { - host := net.JoinHostPort(server.options.Address, server.options.Port) - path := "" - if server.options.EnableRandomUrl { - path += "/" + randomstring.Generate(server.options.RandomUrlLength) - } - scheme := "http" - if server.options.EnableTLS { - scheme = "https" - } - server.url = &url.URL{Scheme: scheme, Host: host, Path: path + "/"} - } - return server.url -} - func (server *Server) tlsConfig() (*tls.Config, error) { caFile := homedir.Expand(server.options.TLSCACrtFile) caCert, err := ioutil.ReadFile(caFile) diff --git a/server/timer.go b/server/timer.go deleted file mode 100644 index 7a522da..0000000 --- a/server/timer.go +++ /dev/null @@ -1,17 +0,0 @@ -package server - -import ( - "time" -) - -func (server *Server) stopTimer() { - if server.options.Timeout > 0 { - server.timer.Stop() - } -} - -func (server *Server) resetTimer() { - if server.options.Timeout > 0 { - server.timer.Reset(time.Duration(server.options.Timeout) * time.Second) - } -}