Add new option to allow cross origin requests to WS endpoint

This commit is contained in:
Iwasaki Yudai 2017-08-13 15:09:22 +09:00
parent 84ec13ca19
commit 6765efbd61
3 changed files with 15 additions and 2 deletions

View File

@ -31,7 +31,6 @@ func (server *Server) generateHandleWS(ctx context.Context, cancel context.Cance
if server.options.Once { if server.options.Once {
success := atomic.CompareAndSwapInt64(once, 0, 1) success := atomic.CompareAndSwapInt64(once, 0, 1)
if !success { if !success {
http.Error(w, "Server is shutting down", http.StatusServiceUnavailable) http.Error(w, "Server is shutting down", http.StatusServiceUnavailable)
return return
} }
@ -68,7 +67,7 @@ func (server *Server) generateHandleWS(ctx context.Context, cancel context.Cance
conn, err := server.upgrader.Upgrade(w, r, nil) conn, err := server.upgrader.Upgrade(w, r, nil)
if err != nil { if err != nil {
http.Error(w, "Failed to upgrade connection: "+err.Error(), 500) closeReason = fmt.Sprintf("origin check error: %s", r.Header.Get("Origin"))
return return
} }
defer conn.Close() defer conn.Close()

View File

@ -28,6 +28,7 @@ type Options struct {
Preferences *HtermPrefernces `hcl:"preferences"` Preferences *HtermPrefernces `hcl:"preferences"`
Width int `hcl:"width" flagName:"width" flagDescribe:"Static width of the screen, 0(default) means dynamically resize" default:"0"` Width int `hcl:"width" flagName:"width" flagDescribe:"Static width of the screen, 0(default) means dynamically resize" default:"0"`
Height int `hcl:"height" flagName:"height" flagDescribe:"Static height of the screen, 0(default) means dynamically resize" default:"0"` Height int `hcl:"height" flagName:"height" flagDescribe:"Static height of the screen, 0(default) means dynamically resize" default:"0"`
WSOrigin string `hcl:"ws_origin" flagName:"ws-origin" flagDescribe:"A regular expression that matches origin URLs to be accepted by WebSocket. No cross origin requests are acceptable by default" default:""`
TitleVariables map[string]interface{} TitleVariables map[string]interface{}
} }

View File

@ -10,6 +10,7 @@ import (
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
"regexp"
noesctmpl "text/template" noesctmpl "text/template"
"time" "time"
@ -56,6 +57,17 @@ func New(factory Factory, options *Options) (*Server, error) {
return nil, errors.Wrapf(err, "failed to parse window title format `%s`", options.TitleFormat) return nil, errors.Wrapf(err, "failed to parse window title format `%s`", options.TitleFormat)
} }
var originChekcer func(r *http.Request) bool
if options.WSOrigin != "" {
matcher, err := regexp.Compile(options.WSOrigin)
if err != nil {
return nil, errors.Wrapf(err, "failed to compile regular expression of Websocket Origin: %s", options.WSOrigin)
}
originChekcer = func(r *http.Request) bool {
return matcher.MatchString(r.Header.Get("Origin"))
}
}
return &Server{ return &Server{
factory: factory, factory: factory,
options: options, options: options,
@ -64,6 +76,7 @@ func New(factory Factory, options *Options) (*Server, error) {
ReadBufferSize: 1024, ReadBufferSize: 1024,
WriteBufferSize: 1024, WriteBufferSize: 1024,
Subprotocols: webtty.Protocols, Subprotocols: webtty.Protocols,
CheckOrigin: originChekcer,
}, },
indexTemplate: indexTemplate, indexTemplate: indexTemplate,
titleTemplate: titleTemplate, titleTemplate: titleTemplate,