mirror of
https://github.com/sorenisanerd/gotty.git
synced 2024-11-22 12:24:25 +00:00
Reduce struct variables of server.Server
This commit is contained in:
parent
21899e638b
commit
9b8d2d5ed5
@ -8,7 +8,9 @@ import (
|
|||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
@ -16,46 +18,63 @@ import (
|
|||||||
"github.com/yudai/gotty/webtty"
|
"github.com/yudai/gotty/webtty"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (server *Server) generateHandleWS(ctx context.Context, cancel context.CancelFunc) http.HandlerFunc {
|
func (server *Server) generateHandleWS(ctx context.Context, cancel context.CancelFunc, connections *int64, wg *sync.WaitGroup) http.HandlerFunc {
|
||||||
|
once := new(int64)
|
||||||
|
|
||||||
|
timer := time.NewTimer(time.Duration(server.options.Timeout) * time.Second)
|
||||||
|
if server.options.Timeout > 0 {
|
||||||
|
go func() {
|
||||||
|
select {
|
||||||
|
case <-timer.C:
|
||||||
|
cancel()
|
||||||
|
case <-ctx.Done():
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
if server.options.Once {
|
if server.options.Once {
|
||||||
if atomic.LoadInt64(server.once) > 0 {
|
success := atomic.CompareAndSwapInt64(once, 0, 1)
|
||||||
|
if !success {
|
||||||
http.Error(w, "Server is shutting down", http.StatusServiceUnavailable)
|
http.Error(w, "Server is shutting down", http.StatusServiceUnavailable)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
atomic.AddInt64(server.once, 1)
|
|
||||||
}
|
}
|
||||||
connections := atomic.AddInt64(server.connections, 1)
|
|
||||||
server.wsWG.Add(1)
|
if server.options.Timeout > 0 {
|
||||||
server.stopTimer()
|
timer.Stop()
|
||||||
|
}
|
||||||
|
wg.Add(1)
|
||||||
|
num := atomic.AddInt64(connections, 1)
|
||||||
closeReason := "unknown reason"
|
closeReason := "unknown reason"
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
server.wsWG.Done()
|
num := atomic.AddInt64(connections, -1)
|
||||||
|
if num == 0 && server.options.Timeout > 0 {
|
||||||
connections := atomic.AddInt64(server.connections, -1)
|
timer.Reset(time.Duration(server.options.Timeout) * time.Second)
|
||||||
if connections == 0 {
|
|
||||||
server.resetTimer()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf(
|
log.Printf(
|
||||||
"Connection closed by %s: %s, connections: %d/%d",
|
"Connection closed by %s: %s, connections: %d/%d",
|
||||||
closeReason, r.RemoteAddr, connections, server.options.MaxConnection,
|
closeReason, r.RemoteAddr, num, server.options.MaxConnection,
|
||||||
)
|
)
|
||||||
|
|
||||||
if server.options.Once {
|
if server.options.Once {
|
||||||
cancel()
|
cancel()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
wg.Done()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
log.Printf("New client connected: %s", r.RemoteAddr)
|
|
||||||
if int64(server.options.MaxConnection) != 0 {
|
if int64(server.options.MaxConnection) != 0 {
|
||||||
if connections > int64(server.options.MaxConnection) {
|
if num > int64(server.options.MaxConnection) {
|
||||||
closeReason = "exceeding max number of connections"
|
closeReason = "exceeding max number of connections"
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.Printf("New client connected: %s, connections: %d/%d", r.RemoteAddr, num, server.options.MaxConnection)
|
||||||
|
|
||||||
if r.Method != "GET" {
|
if r.Method != "GET" {
|
||||||
http.Error(w, "Method not allowed", 405)
|
http.Error(w, "Method not allowed", 405)
|
||||||
return
|
return
|
||||||
|
@ -13,7 +13,6 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
noesctmpl "text/template"
|
noesctmpl "text/template"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/elazarl/go-bindata-assetfs"
|
"github.com/elazarl/go-bindata-assetfs"
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
@ -29,18 +28,9 @@ type Server struct {
|
|||||||
factory Factory
|
factory Factory
|
||||||
options *Options
|
options *Options
|
||||||
|
|
||||||
srv *http.Server
|
upgrader *websocket.Upgrader
|
||||||
|
|
||||||
upgrader *websocket.Upgrader
|
|
||||||
|
|
||||||
indexTemplate *template.Template
|
indexTemplate *template.Template
|
||||||
titleTemplate *noesctmpl.Template
|
titleTemplate *noesctmpl.Template
|
||||||
titleVars map[string]interface{}
|
|
||||||
timer *time.Timer
|
|
||||||
wsWG sync.WaitGroup
|
|
||||||
url *url.URL // use URL()
|
|
||||||
connections *int64 // Use atomic operations
|
|
||||||
once *int64 // use atomic operations
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a new instance of Server.
|
// New creates a new instance of Server.
|
||||||
@ -51,7 +41,6 @@ func New(factory Factory, options *Options) (*Server, error) {
|
|||||||
panic("index not found") // must be in bindata
|
panic("index not found") // must be in bindata
|
||||||
}
|
}
|
||||||
if options.IndexFile != "" {
|
if options.IndexFile != "" {
|
||||||
log.Printf("Using index file at " + options.IndexFile)
|
|
||||||
path := homedir.Expand(options.IndexFile)
|
path := homedir.Expand(options.IndexFile)
|
||||||
indexData, err = ioutil.ReadFile(path)
|
indexData, err = ioutil.ReadFile(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -68,9 +57,6 @@ func New(factory Factory, options *Options) (*Server, error) {
|
|||||||
return nil, errors.Wrapf(err, "failed to parse window title format `%s`", options.TitleFormat)
|
return nil, errors.Wrapf(err, "failed to parse window title format `%s`", options.TitleFormat)
|
||||||
}
|
}
|
||||||
|
|
||||||
connections := int64(0)
|
|
||||||
once := int64(0)
|
|
||||||
|
|
||||||
return &Server{
|
return &Server{
|
||||||
factory: factory,
|
factory: factory,
|
||||||
options: options,
|
options: options,
|
||||||
@ -82,8 +68,6 @@ func New(factory Factory, options *Options) (*Server, error) {
|
|||||||
},
|
},
|
||||||
indexTemplate: indexTemplate,
|
indexTemplate: indexTemplate,
|
||||||
titleTemplate: titleTemplate,
|
titleTemplate: titleTemplate,
|
||||||
connections: &connections,
|
|
||||||
once: &once,
|
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -97,12 +81,18 @@ func (server *Server) Run(ctx context.Context, options ...RunOption) error {
|
|||||||
opt(opts)
|
opt(opts)
|
||||||
}
|
}
|
||||||
|
|
||||||
handlers := server.setupHandlers(cctx, cancel)
|
// wg and connections can be incosistent because they are handled nonatomically
|
||||||
srv, err := server.setupHTTPServer(handlers)
|
wg := new(sync.WaitGroup) // to wait all connections to be closed
|
||||||
|
connections := new(int64) // number of active connections
|
||||||
|
|
||||||
|
url := server.setupURL()
|
||||||
|
handlers := server.setupHandlers(cctx, cancel, url, connections, wg)
|
||||||
|
srv, err := server.setupHTTPServer(handlers, url)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrapf(err, "failed to setup an HTTP server")
|
return errors.Wrapf(err, "failed to setup an HTTP server")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.Printf("URL: %s", url.String())
|
||||||
if server.options.PermitWrite {
|
if server.options.PermitWrite {
|
||||||
log.Printf("Permitting clients to write input to the PTY.")
|
log.Printf("Permitting clients to write input to the PTY.")
|
||||||
}
|
}
|
||||||
@ -111,19 +101,6 @@ func (server *Server) Run(ctx context.Context, options ...RunOption) error {
|
|||||||
log.Printf("Once option is provided, accepting only one client")
|
log.Printf("Once option is provided, accepting only one client")
|
||||||
}
|
}
|
||||||
|
|
||||||
server.srv = srv
|
|
||||||
|
|
||||||
if server.options.Timeout > 0 {
|
|
||||||
server.timer = time.NewTimer(time.Duration(server.options.Timeout) * time.Second)
|
|
||||||
go func() {
|
|
||||||
select {
|
|
||||||
case <-server.timer.C:
|
|
||||||
cancel()
|
|
||||||
case <-cctx.Done():
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
listenErr := make(chan error, 1)
|
listenErr := make(chan error, 1)
|
||||||
go func() {
|
go func() {
|
||||||
if server.options.EnableTLS {
|
if server.options.EnableTLS {
|
||||||
@ -161,21 +138,35 @@ func (server *Server) Run(ctx context.Context, options ...RunOption) error {
|
|||||||
err = cctx.Err()
|
err = cctx.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
conn := atomic.LoadInt64(server.connections)
|
conn := atomic.LoadInt64(connections)
|
||||||
if conn > 0 {
|
if conn > 0 {
|
||||||
log.Printf("Waiting for %d connections to be closed", conn)
|
log.Printf("Waiting for %d connections to be closed", conn)
|
||||||
}
|
}
|
||||||
server.wsWG.Wait()
|
wg.Wait()
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (server *Server) setupHandlers(ctx context.Context, cancel context.CancelFunc) http.Handler {
|
func (server *Server) setupURL() *url.URL {
|
||||||
|
host := net.JoinHostPort(server.options.Address, server.options.Port)
|
||||||
|
scheme := "http"
|
||||||
|
path := "/"
|
||||||
|
|
||||||
|
if server.options.EnableRandomUrl {
|
||||||
|
path = "/" + randomstring.Generate(server.options.RandomUrlLength) + "/"
|
||||||
|
}
|
||||||
|
|
||||||
|
if server.options.EnableTLS {
|
||||||
|
scheme = "https"
|
||||||
|
}
|
||||||
|
return &url.URL{Scheme: scheme, Host: host, Path: path}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (server *Server) setupHandlers(ctx context.Context, cancel context.CancelFunc, url *url.URL, connections *int64, wg *sync.WaitGroup) http.Handler {
|
||||||
staticFileHandler := http.FileServer(
|
staticFileHandler := http.FileServer(
|
||||||
&assetfs.AssetFS{Asset: Asset, AssetDir: AssetDir, Prefix: "static"},
|
&assetfs.AssetFS{Asset: Asset, AssetDir: AssetDir, Prefix: "static"},
|
||||||
)
|
)
|
||||||
|
|
||||||
url := server.URL()
|
|
||||||
var siteMux = http.NewServeMux()
|
var siteMux = http.NewServeMux()
|
||||||
siteMux.HandleFunc(url.Path, server.handleIndex)
|
siteMux.HandleFunc(url.Path, server.handleIndex)
|
||||||
siteMux.Handle(url.Path+"js/", http.StripPrefix(url.Path, staticFileHandler))
|
siteMux.Handle(url.Path+"js/", http.StripPrefix(url.Path, staticFileHandler))
|
||||||
@ -193,16 +184,13 @@ func (server *Server) setupHandlers(ctx context.Context, cancel context.CancelFu
|
|||||||
|
|
||||||
wsMux := http.NewServeMux()
|
wsMux := http.NewServeMux()
|
||||||
wsMux.Handle("/", siteHandler)
|
wsMux.Handle("/", siteHandler)
|
||||||
wsMux.HandleFunc(url.Path+"ws", server.generateHandleWS(ctx, cancel))
|
wsMux.HandleFunc(url.Path+"ws", server.generateHandleWS(ctx, cancel, connections, wg))
|
||||||
siteHandler = http.Handler(wsMux)
|
siteHandler = http.Handler(wsMux)
|
||||||
|
|
||||||
return server.wrapLogger(siteHandler)
|
return server.wrapLogger(siteHandler)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (server *Server) setupHTTPServer(handler http.Handler) (*http.Server, error) {
|
func (server *Server) setupHTTPServer(handler http.Handler, url *url.URL) (*http.Server, error) {
|
||||||
url := server.URL()
|
|
||||||
log.Printf("URL: %s", url.String())
|
|
||||||
|
|
||||||
srv := &http.Server{
|
srv := &http.Server{
|
||||||
Addr: url.Host,
|
Addr: url.Host,
|
||||||
Handler: handler,
|
Handler: handler,
|
||||||
@ -219,22 +207,6 @@ func (server *Server) setupHTTPServer(handler http.Handler) (*http.Server, error
|
|||||||
return srv, nil
|
return srv, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (server *Server) URL() *url.URL {
|
|
||||||
if server.url == nil {
|
|
||||||
host := net.JoinHostPort(server.options.Address, server.options.Port)
|
|
||||||
path := ""
|
|
||||||
if server.options.EnableRandomUrl {
|
|
||||||
path += "/" + randomstring.Generate(server.options.RandomUrlLength)
|
|
||||||
}
|
|
||||||
scheme := "http"
|
|
||||||
if server.options.EnableTLS {
|
|
||||||
scheme = "https"
|
|
||||||
}
|
|
||||||
server.url = &url.URL{Scheme: scheme, Host: host, Path: path + "/"}
|
|
||||||
}
|
|
||||||
return server.url
|
|
||||||
}
|
|
||||||
|
|
||||||
func (server *Server) tlsConfig() (*tls.Config, error) {
|
func (server *Server) tlsConfig() (*tls.Config, error) {
|
||||||
caFile := homedir.Expand(server.options.TLSCACrtFile)
|
caFile := homedir.Expand(server.options.TLSCACrtFile)
|
||||||
caCert, err := ioutil.ReadFile(caFile)
|
caCert, err := ioutil.ReadFile(caFile)
|
||||||
|
@ -1,17 +0,0 @@
|
|||||||
package server
|
|
||||||
|
|
||||||
import (
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (server *Server) stopTimer() {
|
|
||||||
if server.options.Timeout > 0 {
|
|
||||||
server.timer.Stop()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (server *Server) resetTimer() {
|
|
||||||
if server.options.Timeout > 0 {
|
|
||||||
server.timer.Reset(time.Duration(server.options.Timeout) * time.Second)
|
|
||||||
}
|
|
||||||
}
|
|
Loading…
Reference in New Issue
Block a user