diff --git a/knockr.go b/knockr.go index 7442b02..69bf161 100644 --- a/knockr.go +++ b/knockr.go @@ -47,6 +47,11 @@ func stats() { } } +func get_address_from_conn(c net.Conn) string { + host, _, _ := net.SplitHostPort(c.RemoteAddr().String()) + return host +} + func listener(port int, listen_func func(c net.Conn)) { // Set up listening sockets on specified port and hand over to specified listen_func ln, err := net.Listen("tcp", fmt.Sprintf(":%d", port)) @@ -69,7 +74,7 @@ func listener(port int, listen_func func(c net.Conn)) { func whitelist_handler(c net.Conn) { // Handler function for whitelist socket connections, whitelisting the connecting host - host, _, _ := net.SplitHostPort(c.RemoteAddr().String()) + host := get_address_from_conn(c) if is_blacklisted(host) { if arguments.Verbose { @@ -85,7 +90,7 @@ func whitelist_handler(c net.Conn) { func blacklist_handler(c net.Conn) { // Handler which blocks every host connecting to it. // Useful to place it on port (whitelistPort-1) to crash port scanners. - host, _, _ := net.SplitHostPort(c.RemoteAddr().String()) + host := get_address_from_conn(c) if ! is_whitelisted(host) { if arguments.Verbose { @@ -102,7 +107,7 @@ func blacklist_handler(c net.Conn) { func gateway_handler(c net.Conn) { // Filter connections whether or not the connecting host is whitelisted - host, _, _ := net.SplitHostPort(c.RemoteAddr().String()) + host := get_address_from_conn(c) if is_blacklisted(host) { if arguments.Verbose { @@ -188,17 +193,19 @@ func proxy(c net.Conn) { fmt.Println("[ERR] Proxy connection to server failed") fmt.Println(" Error is ", err) } else { - go proxy_writefunc(c, ln, &traffic_in) - proxy_writefunc(ln, c, &traffic_out) + host := get_address_from_conn(c) + go proxy_writefunc(c, ln, &traffic_in, host) + proxy_writefunc(ln, c, &traffic_out, host) } } -func proxy_writefunc(a net.Conn, b net.Conn, written_bytes *int64) { +func proxy_writefunc(a net.Conn, b net.Conn, written_bytes *int64, addr string) { var delta int64 = 0 var err error for err == nil { delta, err = io.CopyN(a, b, 1024) *written_bytes += delta + update_whitelist_time(addr) } }