From 650044991651008c83ef2317c12ae36648e1d148 Mon Sep 17 00:00:00 2001 From: Quentin Perez Date: Wed, 30 Sep 2015 16:48:34 +0200 Subject: [PATCH] Added mutex to avoid concurrent writes --- app/app.go | 2 ++ app/client_context.go | 45 ++++++++++++++++++++++++++++--------------- 2 files changed, 32 insertions(+), 15 deletions(-) diff --git a/app/app.go b/app/app.go index 790eee6..314566c 100644 --- a/app/app.go +++ b/app/app.go @@ -14,6 +14,7 @@ import ( "os/exec" "strconv" "strings" + "sync" "text/template" "github.com/braintree/manners" @@ -243,6 +244,7 @@ func (app *App) handleWS(w http.ResponseWriter, r *http.Request) { connection: conn, command: cmd, pty: ptyIo, + writeMutex: &sync.Mutex{}, } context.goHandleClient() diff --git a/app/client_context.go b/app/client_context.go index 63ad67e..ca38d23 100644 --- a/app/client_context.go +++ b/app/client_context.go @@ -8,6 +8,7 @@ import ( "os" "os/exec" "strings" + "sync" "syscall" "unsafe" @@ -20,6 +21,7 @@ type clientContext struct { connection *websocket.Conn command *exec.Cmd pty *os.File + writeMutex *sync.Mutex } const ( @@ -96,19 +98,24 @@ func (context *clientContext) processSend() { for { size, err := context.pty.Read(buf) - safeMessage := base64.StdEncoding.EncodeToString([]byte(buf[:size])) if err != nil { log.Printf("Command exited for: %s", context.request.RemoteAddr) return } - - err = context.connection.WriteMessage(websocket.TextMessage, append([]byte{Output}, []byte(safeMessage)...)) - if err != nil { + safeMessage := base64.StdEncoding.EncodeToString([]byte(buf[:size])) + if err = context.write(append([]byte{Output}, []byte(safeMessage)...)); err != nil { + log.Printf(err.Error()) return } } } +func (context *clientContext) write(data []byte) error { + context.writeMutex.Lock() + defer context.writeMutex.Unlock() + return context.connection.WriteMessage(websocket.TextMessage, data) +} + func (context *clientContext) sendInitialize() error { hostname, _ := os.Hostname() titleVars := ContextVars{ @@ -118,34 +125,34 @@ func (context *clientContext) sendInitialize() error { RemoteAddr: context.request.RemoteAddr, } + context.writeMutex.Lock() writer, err := context.connection.NextWriter(websocket.TextMessage) if err != nil { + context.writeMutex.Unlock() return err } writer.Write([]byte{SetWindowTitle}) if err = context.app.titleTemplate.Execute(writer, titleVars); err != nil { + context.writeMutex.Unlock() return err } writer.Close() + context.writeMutex.Unlock() htermPrefs := make(map[string]interface{}) for key, value := range context.app.options.Preferences { htermPrefs[strings.Replace(key, "_", "-", -1)] = value } prefs, _ := json.Marshal(htermPrefs) - context.connection.WriteMessage( - websocket.TextMessage, - append([]byte{SetPreferences}, prefs...), - ) - + if err := context.write(append([]byte{SetPreferences}, prefs...)); err != nil { + return err + } if context.app.options.EnableReconnect { reconnect, _ := json.Marshal(context.app.options.ReconnectTime) - context.connection.WriteMessage( - websocket.TextMessage, - append([]byte{SetReconnect}, reconnect...), - ) + if err := context.write(append([]byte{SetReconnect}, reconnect...)); err != nil { + return err + } } - return nil } @@ -153,6 +160,11 @@ func (context *clientContext) processReceive() { for { _, data, err := context.connection.ReadMessage() if err != nil { + log.Print(err.Error()) + return + } + if len(data) == 0 { + log.Print("An error has occured") return } @@ -168,7 +180,10 @@ func (context *clientContext) processReceive() { } case Ping: - context.connection.WriteMessage(websocket.TextMessage, []byte{Pong}) + if err := context.write([]byte{Pong}); err != nil { + log.Print(err.Error()) + return + } case ResizeTerminal: var args argResizeTerminal err = json.Unmarshal(data[1:], &args)