Added mutex to avoid concurrent writes

This commit is contained in:
Quentin Perez 2015-09-30 16:48:34 +02:00
parent 8f9d5ba582
commit 6500449916
2 changed files with 32 additions and 15 deletions

View File

@ -14,6 +14,7 @@ import (
"os/exec" "os/exec"
"strconv" "strconv"
"strings" "strings"
"sync"
"text/template" "text/template"
"github.com/braintree/manners" "github.com/braintree/manners"
@ -243,6 +244,7 @@ func (app *App) handleWS(w http.ResponseWriter, r *http.Request) {
connection: conn, connection: conn,
command: cmd, command: cmd,
pty: ptyIo, pty: ptyIo,
writeMutex: &sync.Mutex{},
} }
context.goHandleClient() context.goHandleClient()

View File

@ -8,6 +8,7 @@ import (
"os" "os"
"os/exec" "os/exec"
"strings" "strings"
"sync"
"syscall" "syscall"
"unsafe" "unsafe"
@ -20,6 +21,7 @@ type clientContext struct {
connection *websocket.Conn connection *websocket.Conn
command *exec.Cmd command *exec.Cmd
pty *os.File pty *os.File
writeMutex *sync.Mutex
} }
const ( const (
@ -96,19 +98,24 @@ func (context *clientContext) processSend() {
for { for {
size, err := context.pty.Read(buf) size, err := context.pty.Read(buf)
safeMessage := base64.StdEncoding.EncodeToString([]byte(buf[:size]))
if err != nil { if err != nil {
log.Printf("Command exited for: %s", context.request.RemoteAddr) log.Printf("Command exited for: %s", context.request.RemoteAddr)
return return
} }
safeMessage := base64.StdEncoding.EncodeToString([]byte(buf[:size]))
err = context.connection.WriteMessage(websocket.TextMessage, append([]byte{Output}, []byte(safeMessage)...)) if err = context.write(append([]byte{Output}, []byte(safeMessage)...)); err != nil {
if err != nil { log.Printf(err.Error())
return return
} }
} }
} }
func (context *clientContext) write(data []byte) error {
context.writeMutex.Lock()
defer context.writeMutex.Unlock()
return context.connection.WriteMessage(websocket.TextMessage, data)
}
func (context *clientContext) sendInitialize() error { func (context *clientContext) sendInitialize() error {
hostname, _ := os.Hostname() hostname, _ := os.Hostname()
titleVars := ContextVars{ titleVars := ContextVars{
@ -118,34 +125,34 @@ func (context *clientContext) sendInitialize() error {
RemoteAddr: context.request.RemoteAddr, RemoteAddr: context.request.RemoteAddr,
} }
context.writeMutex.Lock()
writer, err := context.connection.NextWriter(websocket.TextMessage) writer, err := context.connection.NextWriter(websocket.TextMessage)
if err != nil { if err != nil {
context.writeMutex.Unlock()
return err return err
} }
writer.Write([]byte{SetWindowTitle}) writer.Write([]byte{SetWindowTitle})
if err = context.app.titleTemplate.Execute(writer, titleVars); err != nil { if err = context.app.titleTemplate.Execute(writer, titleVars); err != nil {
context.writeMutex.Unlock()
return err return err
} }
writer.Close() writer.Close()
context.writeMutex.Unlock()
htermPrefs := make(map[string]interface{}) htermPrefs := make(map[string]interface{})
for key, value := range context.app.options.Preferences { for key, value := range context.app.options.Preferences {
htermPrefs[strings.Replace(key, "_", "-", -1)] = value htermPrefs[strings.Replace(key, "_", "-", -1)] = value
} }
prefs, _ := json.Marshal(htermPrefs) prefs, _ := json.Marshal(htermPrefs)
context.connection.WriteMessage( if err := context.write(append([]byte{SetPreferences}, prefs...)); err != nil {
websocket.TextMessage, return err
append([]byte{SetPreferences}, prefs...), }
)
if context.app.options.EnableReconnect { if context.app.options.EnableReconnect {
reconnect, _ := json.Marshal(context.app.options.ReconnectTime) reconnect, _ := json.Marshal(context.app.options.ReconnectTime)
context.connection.WriteMessage( if err := context.write(append([]byte{SetReconnect}, reconnect...)); err != nil {
websocket.TextMessage, return err
append([]byte{SetReconnect}, reconnect...), }
)
} }
return nil return nil
} }
@ -153,6 +160,11 @@ func (context *clientContext) processReceive() {
for { for {
_, data, err := context.connection.ReadMessage() _, data, err := context.connection.ReadMessage()
if err != nil { if err != nil {
log.Print(err.Error())
return
}
if len(data) == 0 {
log.Print("An error has occured")
return return
} }
@ -168,7 +180,10 @@ func (context *clientContext) processReceive() {
} }
case Ping: case Ping:
context.connection.WriteMessage(websocket.TextMessage, []byte{Pong}) if err := context.write([]byte{Pong}); err != nil {
log.Print(err.Error())
return
}
case ResizeTerminal: case ResizeTerminal:
var args argResizeTerminal var args argResizeTerminal
err = json.Unmarshal(data[1:], &args) err = json.Unmarshal(data[1:], &args)