diff --git a/app/app.go b/app/app.go index fe62ec7..a7335ab 100644 --- a/app/app.go +++ b/app/app.go @@ -3,21 +3,16 @@ package app import ( "crypto/rand" "encoding/base64" - "encoding/json" "log" "math/big" "net/http" - "os" "os/exec" "strconv" "strings" - "syscall" - "unsafe" "github.com/elazarl/go-bindata-assetfs" "github.com/gorilla/websocket" "github.com/kr/pty" - "github.com/yudai/utf8reader" ) type App struct { @@ -126,108 +121,14 @@ func (app *App) handler(w http.ResponseWriter, r *http.Request) { log.Printf("Command is running for client %s with PID %d", r.RemoteAddr, cmd.Process.Pid) context := &clientContext{ + app: app, request: r, connection: conn, command: cmd, pty: ptyIo, } - app.goHandleConnection(context) -} - -func (app *App) goHandleConnection(context *clientContext) { - exit := make(chan bool, 2) - - go func() { - defer func() { exit <- true }() - - app.processSend(context) - }() - - go func() { - defer func() { exit <- true }() - - app.processReceive(context) - }() - - go func() { - <-exit - context.command.Wait() - context.connection.Close() - log.Printf("Connection closed: %s", context.request.RemoteAddr) - }() -} - -func (app *App) processSend(context *clientContext) { - buf := make([]byte, 1024) - utf8f := utf8reader.New(context.pty) - - for { - size, err := utf8f.Read(buf) - if err != nil { - log.Printf("Command exited for: %s", context.request.RemoteAddr) - return - } - - writer, err := context.connection.NextWriter(websocket.TextMessage) - if err != nil { - return - } - - writer.Write(buf[:size]) - writer.Close() - } -} - -func (app *App) processReceive(context *clientContext) { - for { - _, data, err := context.connection.ReadMessage() - if err != nil { - return - } - - switch data[0] { - case Input: - if !app.options.PermitWrite { - break - } - - _, err := context.pty.Write(data[1:]) - if err != nil { - return - } - - case ResizeTerminal: - var args argResizeTerminal - err = json.Unmarshal(data[1:], &args) - if err != nil { - log.Print("Malformed remote command") - return - } - - window := struct { - row uint16 - col uint16 - x uint16 - y uint16 - }{ - uint16(args.Rows), - uint16(args.Columns), - 0, - 0, - } - syscall.Syscall( - syscall.SYS_IOCTL, - context.pty.Fd(), - syscall.TIOCSWINSZ, - uintptr(unsafe.Pointer(&window)), - ) - - default: - log.Print("Unknown message type") - return - } - } + context.goHandleClient() } func generateRandomString(length int) string { @@ -240,20 +141,3 @@ func generateRandomString(length int) string { } return string(n) } - -const ( - Input = '0' - ResizeTerminal = '1' -) - -type argResizeTerminal struct { - Columns float64 - Rows float64 -} - -type clientContext struct { - request *http.Request - connection *websocket.Conn - command *exec.Cmd - pty *os.File -} diff --git a/app/client_context.go b/app/client_context.go new file mode 100644 index 0000000..c66e093 --- /dev/null +++ b/app/client_context.go @@ -0,0 +1,127 @@ +package app + +import ( + "encoding/json" + "log" + "net/http" + "os" + "os/exec" + "syscall" + "unsafe" + + "github.com/gorilla/websocket" + "github.com/yudai/utf8reader" +) + +type clientContext struct { + app *App + request *http.Request + connection *websocket.Conn + command *exec.Cmd + pty *os.File +} + +const ( + Input = '0' + ResizeTerminal = '1' +) + +type argResizeTerminal struct { + Columns float64 + Rows float64 +} + +func (context *clientContext) goHandleClient() { + exit := make(chan bool, 2) + + go func() { + defer func() { exit <- true }() + + context.processSend() + }() + + go func() { + defer func() { exit <- true }() + + context.processReceive() + }() + + go func() { + <-exit + context.command.Wait() + context.connection.Close() + log.Printf("Connection closed: %s", context.request.RemoteAddr) + }() +} + +func (context *clientContext) processSend() { + buf := make([]byte, 1024) + utf8f := utf8reader.New(context.pty) + + for { + size, err := utf8f.Read(buf) + if err != nil { + log.Printf("Command exited for: %s", context.request.RemoteAddr) + return + } + + writer, err := context.connection.NextWriter(websocket.TextMessage) + if err != nil { + return + } + + writer.Write(buf[:size]) + writer.Close() + } +} + +func (context *clientContext) processReceive() { + for { + _, data, err := context.connection.ReadMessage() + if err != nil { + return + } + + switch data[0] { + case Input: + if !context.app.options.PermitWrite { + break + } + + _, err := context.pty.Write(data[1:]) + if err != nil { + return + } + + case ResizeTerminal: + var args argResizeTerminal + err = json.Unmarshal(data[1:], &args) + if err != nil { + log.Print("Malformed remote command") + return + } + + window := struct { + row uint16 + col uint16 + x uint16 + y uint16 + }{ + uint16(args.Rows), + uint16(args.Columns), + 0, + 0, + } + syscall.Syscall( + syscall.SYS_IOCTL, + context.pty.Fd(), + syscall.TIOCSWINSZ, + uintptr(unsafe.Pointer(&window)), + ) + + default: + log.Print("Unknown message type") + return + } + } +}