From 6765efbd6148008fb37306dc0dcb30fd7b405811 Mon Sep 17 00:00:00 2001 From: Iwasaki Yudai Date: Sun, 13 Aug 2017 15:09:22 +0900 Subject: [PATCH] Add new option to allow cross origin requests to WS endpoint --- server/handlers.go | 3 +-- server/options.go | 1 + server/server.go | 13 +++++++++++++ 3 files changed, 15 insertions(+), 2 deletions(-) diff --git a/server/handlers.go b/server/handlers.go index 90e6429..1a48980 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -31,7 +31,6 @@ func (server *Server) generateHandleWS(ctx context.Context, cancel context.Cance if server.options.Once { success := atomic.CompareAndSwapInt64(once, 0, 1) if !success { - http.Error(w, "Server is shutting down", http.StatusServiceUnavailable) return } @@ -68,7 +67,7 @@ func (server *Server) generateHandleWS(ctx context.Context, cancel context.Cance conn, err := server.upgrader.Upgrade(w, r, nil) if err != nil { - http.Error(w, "Failed to upgrade connection: "+err.Error(), 500) + closeReason = fmt.Sprintf("origin check error: %s", r.Header.Get("Origin")) return } defer conn.Close() diff --git a/server/options.go b/server/options.go index 977c00d..897bcb2 100644 --- a/server/options.go +++ b/server/options.go @@ -28,6 +28,7 @@ type Options struct { Preferences *HtermPrefernces `hcl:"preferences"` Width int `hcl:"width" flagName:"width" flagDescribe:"Static width of the screen, 0(default) means dynamically resize" default:"0"` Height int `hcl:"height" flagName:"height" flagDescribe:"Static height of the screen, 0(default) means dynamically resize" default:"0"` + WSOrigin string `hcl:"ws_origin" flagName:"ws-origin" flagDescribe:"A regular expression that matches origin URLs to be accepted by WebSocket. No cross origin requests are acceptable by default" default:""` TitleVariables map[string]interface{} } diff --git a/server/server.go b/server/server.go index cc1cb4b..1442603 100644 --- a/server/server.go +++ b/server/server.go @@ -10,6 +10,7 @@ import ( "net" "net/http" "net/url" + "regexp" noesctmpl "text/template" "time" @@ -56,6 +57,17 @@ func New(factory Factory, options *Options) (*Server, error) { return nil, errors.Wrapf(err, "failed to parse window title format `%s`", options.TitleFormat) } + var originChekcer func(r *http.Request) bool + if options.WSOrigin != "" { + matcher, err := regexp.Compile(options.WSOrigin) + if err != nil { + return nil, errors.Wrapf(err, "failed to compile regular expression of Websocket Origin: %s", options.WSOrigin) + } + originChekcer = func(r *http.Request) bool { + return matcher.MatchString(r.Header.Get("Origin")) + } + } + return &Server{ factory: factory, options: options, @@ -64,6 +76,7 @@ func New(factory Factory, options *Options) (*Server, error) { ReadBufferSize: 1024, WriteBufferSize: 1024, Subprotocols: webtty.Protocols, + CheckOrigin: originChekcer, }, indexTemplate: indexTemplate, titleTemplate: titleTemplate,