diff --git a/app/app.go b/app/app.go index 660c9ad..84f3d11 100644 --- a/app/app.go +++ b/app/app.go @@ -17,6 +17,7 @@ import ( "strings" "text/template" + "github.com/braintree/manners" "github.com/elazarl/go-bindata-assetfs" "github.com/gorilla/websocket" "github.com/kr/pty" @@ -26,6 +27,7 @@ type App struct { options Options upgrader *websocket.Upgrader + server *manners.GracefulServer preferences map[string]interface{} titleTemplate *template.Template @@ -157,16 +159,21 @@ func (app *App) Run() error { } var err error + app.server = manners.NewWithServer( + &http.Server{Addr: endpoint, Handler: siteHandler}, + ) if app.options.EnableTLS { cert, key := app.loadTLSFiles() - err = http.ListenAndServeTLS(endpoint, cert, key, siteHandler) + err = app.server.ListenAndServeTLS(cert, key) } else { - err = http.ListenAndServe(endpoint, siteHandler) + err = app.server.ListenAndServe() } if err != nil { return err } + log.Printf("Exiting...") + return nil } @@ -217,6 +224,14 @@ func (app *App) handleWS(w http.ResponseWriter, r *http.Request) { context.goHandleClient() } +func (app *App) Exit() (firstCall bool) { + if app.server != nil { + log.Printf("Received Exit command, waiting for all clients to close sessions...") + return app.server.Close() + } + return true +} + func wrapLogger(handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { log.Printf("%s %s", r.Method, r.URL.Path) diff --git a/app/client_context.go b/app/client_context.go index 5ddae7e..c36b948 100644 --- a/app/client_context.go +++ b/app/client_context.go @@ -61,7 +61,10 @@ func (context *clientContext) goHandleClient() { context.processReceive() }() + context.app.server.StartRoutine() go func() { + defer context.app.server.FinishRoutine() + <-exit context.pty.Close() context.command.Wait() diff --git a/main.go b/main.go index afa9fd5..88d9ce1 100644 --- a/main.go +++ b/main.go @@ -7,6 +7,8 @@ import ( "github.com/codegangsta/cli" "github.com/yudai/gotty/app" + "os/signal" + "syscall" ) func main() { @@ -107,6 +109,8 @@ func main() { os.Exit(2) } + registerSignals(app) + err = app.Run() if err != nil { fmt.Println(err) @@ -118,3 +122,24 @@ func main() { cmd.Run(os.Args) } + +func registerSignals(app *app.App) { + sigChan := make(chan os.Signal, 1) + signal.Notify( + sigChan, + syscall.SIGINT, + syscall.SIGTERM, + ) + + go func() { + for { + s := <-sigChan + switch s { + case syscall.SIGINT, syscall.SIGTERM: + if !app.Exit() { + os.Exit(4) + } + } + } + }() +}