diff --git a/backend/localcommand/factory.go b/backend/localcommand/factory.go index 9d0c916..be6da37 100644 --- a/backend/localcommand/factory.go +++ b/backend/localcommand/factory.go @@ -37,12 +37,12 @@ func (factory *Factory) Name() string { return "local command" } -func (factory *Factory) New(params map[string][]string) (server.Slave, error) { +func (factory *Factory) New(params map[string][]string, headers map[string][]string) (server.Slave, error) { argv := make([]string, len(factory.argv)) copy(argv, factory.argv) if params["arg"] != nil && len(params["arg"]) > 0 { argv = append(argv, params["arg"]...) } - return New(factory.command, argv, factory.opts...) + return New(factory.command, argv, headers, factory.opts...) } diff --git a/backend/localcommand/local_command.go b/backend/localcommand/local_command.go index 71b6c18..d0e0753 100644 --- a/backend/localcommand/local_command.go +++ b/backend/localcommand/local_command.go @@ -3,6 +3,7 @@ package localcommand import ( "os" "os/exec" + "strings" "syscall" "time" @@ -27,11 +28,22 @@ type LocalCommand struct { ptyClosed chan struct{} } -func New(command string, argv []string, options ...Option) (*LocalCommand, error) { +func New(command string, argv []string, headers map[string][]string, options ...Option) (*LocalCommand, error) { cmd := exec.Command(command, argv...) cmd.Env = append(os.Environ(), "TERM=xterm-256color") + // Combine headers into key=value pairs to set as env vars + // Prefix the headers with "http_" so we don't overwrite any other env vars + // which potentially has the same name and to bring these closer to what + // a (F)CGI server would proxy to a backend service + // Replace hyphen with underscore and make them all upper case + for key, values := range headers { + h := "HTTP_" + strings.Replace(strings.ToUpper(key), "-", "_", -1) + "=" + strings.Join(values, ",") + // log.Printf("Adding header: %s", h) + cmd.Env = append(cmd.Env, h) + } + pty, err := pty.Start(cmd) if err != nil { // todo close cmd? diff --git a/backend/localcommand/local_command_test.go b/backend/localcommand/local_command_test.go index ddda029..220cb73 100644 --- a/backend/localcommand/local_command_test.go +++ b/backend/localcommand/local_command_test.go @@ -23,7 +23,7 @@ func TestNewFactory(t *testing.T) { t.Errorf("factory.options = %v, expected %v", factory.options, &Options{}) } - slave, _ := factory.New(nil) + slave, _ := factory.New(nil, nil) lcmd := slave.(*LocalCommand) if lcmd.closeSignal != 123 { t.Errorf("lcmd.closeSignal = %v, expected %v", lcmd.closeSignal, 123) @@ -40,7 +40,7 @@ func TestFactoryNew(t *testing.T) { return } - slave, err := factory.New(nil) + slave, err := factory.New(nil, nil) if err != nil { t.Errorf("factory.New() returned error") return diff --git a/server/handlers.go b/server/handlers.go index 5347ab6..a0acec1 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -72,7 +72,7 @@ func (server *Server) generateHandleWS(ctx context.Context, cancel context.Cance } defer conn.Close() - err = server.processWSConn(ctx, conn) + err = server.processWSConn(ctx, conn, r.Header) switch err { case ctx.Err(): @@ -87,7 +87,7 @@ func (server *Server) generateHandleWS(ctx context.Context, cancel context.Cance } } -func (server *Server) processWSConn(ctx context.Context, conn *websocket.Conn) error { +func (server *Server) processWSConn(ctx context.Context, conn *websocket.Conn, headers map[string][]string) error { typ, initLine, err := conn.ReadMessage() if err != nil { return errors.Wrapf(err, "failed to authenticate websocket connection") @@ -116,7 +116,7 @@ func (server *Server) processWSConn(ctx context.Context, conn *websocket.Conn) e } params := query.Query() var slave Slave - slave, err = server.factory.New(params) + slave, err = server.factory.New(params, headers) if err != nil { return errors.Wrapf(err, "failed to create backend") } diff --git a/server/slave.go b/server/slave.go index 52cd9fe..db9d731 100644 --- a/server/slave.go +++ b/server/slave.go @@ -13,5 +13,5 @@ type Slave interface { type Factory interface { Name() string - New(params map[string][]string) (Slave, error) + New(params map[string][]string, headers map[string][]string) (Slave, error) }