afl-transmit/net/listener.go

107 lines
2.7 KiB
Go

package net
import (
"flag"
"fmt"
"github.com/maride/afl-transmit/logistic"
"github.com/maride/afl-transmit/stats"
"io"
"io/ioutil"
"log"
"net"
"strings"
)
var (
port int
restrictToPeers bool
)
// Registers the flags required for the listener
func RegisterListenFlags() {
flag.IntVar(&port, "port", ServerPort, "Port to bind server component to")
flag.BoolVar(&restrictToPeers, "restrict-to-peers", false, "Only allow connections from peers")
}
// Sets up a listener and listens forever for packets on the given port, storing their contents in the outputDirectory
func Listen(outputDirectory string) error {
// Create listener
addrStr := fmt.Sprintf(":%v", port)
listener, listenErr := net.Listen("tcp", addrStr)
if listenErr != nil {
return listenErr
}
// Prepare output directory path
outputDirectory = strings.TrimRight(outputDirectory, "/")
// Listen forever
for {
// Accept connection
conn, connErr := listener.Accept()
if connErr != nil {
log.Printf("Encountered error while accepting from %s: %s", conn.RemoteAddr().String(), connErr)
continue
}
// Check if we should restrict connections from peers
handleConnection := true
if restrictToPeers {
found := false
// Loop over peers
for _, p := range peers {
// Check if we found the remote address in our peers list
if p.Address == conn.RemoteAddr().String() {
found = true
break
}
}
// Handle connection only if its a peer
handleConnection = found
}
if handleConnection {
// Handle in a separate thread
go handle(conn, outputDirectory)
}
}
}
// Handles a single connection, and unpacks the received data into outputDirectory
func handle(conn net.Conn, outputDirectory string) {
// Make sure to close connection on return
defer conn.Close()
// Read raw content
cont, contErr := ioutil.ReadAll(conn) // bufio.NewReader(conn).ReadString('\x00')
// Check if we are able to decrypt
if CryptApplicable() {
// Decrypt packet
var decryptErr error
cont, decryptErr = Decrypt(cont)
if decryptErr != nil {
log.Printf("Failed to decrypt packet from %s: %s", conn.RemoteAddr().String(), decryptErr)
return
}
}
if contErr == nil || contErr == io.EOF {
// We received the whole content, time to process it
unpackErr := logistic.UnpackInto(cont, outputDirectory)
if unpackErr != nil {
log.Printf("Encountered error processing packet from %s: %s", conn.RemoteAddr().String(), unpackErr)
}
// Push read bytes to stats
stats.PushStat(stats.Stat{ReceivedBytes: uint64(len(cont))})
return
} else {
// We encountered an error on that connection
log.Printf("Encountered error while reading from %s: %s", conn.RemoteAddr().String(), contErr)
return
}
}