ghost/main.go

137 lines
3.4 KiB
Go

package main
import (
"fmt"
"net"
"os"
"regexp"
"strings"
"time"
)
// main validates the command line, shows the banner and then merges the
// requested domain names into /etc/hosts.
func main() {
	addr, names := parseArgs()
	motd()
	process(addr, names)
}
// hostnamePattern matches an RFC 1123 style hostname: one or more
// dot-separated labels of 1-63 letters, digits or hyphens that neither start
// nor end with a hyphen, optionally followed by a single trailing dot.
// The dots are escaped on purpose: an unescaped '.' would accept any
// character, including whitespace and '#', which would corrupt /etc/hosts.
var hostnamePattern = regexp.MustCompile(
	`^[a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?(\.[a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)*\.?$`,
)

// parseArgs validates os.Args and returns the target IP (os.Args[1]) and the
// domain names to map to it (os.Args[2:]).
//
// If fewer than two arguments are given, the IP does not parse, or any domain
// fails hostnamePattern, it prints a diagnostic and the usage text to stderr
// and exits with status 1.
func parseArgs() (net.IP, []string) {
	if len(os.Args) < 3 {
		usage()
		os.Exit(1)
	}

	rawIP := os.Args[1]
	ip := net.ParseIP(rawIP)
	if ip == nil {
		fmt.Fprintf(os.Stderr, "Not an IP address: %s\n", rawIP)
		usage()
		os.Exit(1)
	}

	domains := os.Args[2:]
	for _, d := range domains {
		if !hostnamePattern.MatchString(d) {
			fmt.Fprintf(os.Stderr, "Doesn't look like a valid domain: %s\n", d)
			usage()
			os.Exit(1)
		}
	}
	return ip, domains
}
// process merges domains into the /etc/hosts entry for ip and writes the file
// back.
//
// The first line whose address field equals ip gets every domain it does not
// already list appended; its trailing comment, if any, is kept. If no line
// matches, a new "ip<TAB>domains # Added by ghost on DD-MM-YYYY" line is
// appended. A summary is printed to stdout. Read or write failures are
// reported on stderr and the program exits with status 1.
func process(ip net.IP, domains []string) {
	const hostsPath = "/etc/hosts"

	raw, err := os.ReadFile(hostsPath)
	if err != nil {
		fmt.Fprintf(os.Stderr, "Failed to read /etc/hosts: %s\n", err)
		os.Exit(1)
	}

	updated, summary, changed := updateHosts(string(raw), ip, domains, time.Now())
	if changed {
		// os.WriteFile only applies the mode when creating the file; an
		// existing /etc/hosts keeps its permissions.
		if err := os.WriteFile(hostsPath, []byte(updated), 0o644); err != nil {
			fmt.Fprintf(os.Stderr, "Failed to write /etc/hosts: %s\n", err)
			os.Exit(1)
		}
	}
	// Report only after the write succeeded, so the message is never a lie.
	fmt.Println(summary)
}

// updateHosts returns hosts with domains merged into the entry for ip, a
// human-readable summary, and whether the content changed. It does no I/O.
func updateHosts(hosts string, ip net.IP, domains []string, now time.Time) (content, summary string, changed bool) {
	ipString := ip.String()
	lines := strings.Split(hosts, "\n")

	for i, line := range lines {
		entry, comment, hasComment := strings.Cut(line, "#")
		fields := strings.Fields(entry)
		// Compare parsed addresses instead of string prefixes: a prefix test
		// lets 10.0.1.1 match a 10.0.1.10 line and misses non-canonical
		// spellings of the same address.
		if len(fields) == 0 || !ip.Equal(net.ParseIP(fields[0])) {
			continue
		}

		// Only the names before any '#' count as already present.
		missing := missingDomains(fields[1:], domains)
		if len(missing) == 0 {
			return hosts, fmt.Sprintf("The line for '%s' already includes all given domains", ipString), false
		}
		newDomains := strings.Join(missing, " ")
		rebuilt := strings.TrimRight(entry, " \t") + " " + newDomains
		if hasComment {
			rebuilt += " #" + comment
		}
		lines[i] = rebuilt
		return strings.Join(lines, "\n"), fmt.Sprintf("Updated the line for '%s' to include '%s'", ipString, newDomains), true
	}

	// No entry for ip yet: append one.
	newDomains := strings.Join(missingDomains(nil, domains), " ")
	newLine := fmt.Sprintf("%s\t%s # Added by ghost on %s", ipString, newDomains, now.Format("02-01-2006"))
	// A newline-terminated file splits into a final empty element. Insert
	// before it so the file keeps its trailing newline and gains no blank line.
	if n := len(lines); lines[n-1] == "" {
		lines = append(lines[:n-1], newLine, "")
	} else {
		lines = append(lines, newLine)
	}
	return strings.Join(lines, "\n"), fmt.Sprintf("Appended line for '%s' to include '%s'", ipString, newDomains), true
}

// missingDomains returns, in their original order, the entries of domains
// that are not in existing and not repeated earlier in domains. Hostnames are
// case-insensitive, so the comparison is too.
func missingDomains(existing, domains []string) []string {
	seen := make(map[string]struct{}, len(existing)+len(domains))
	for _, name := range existing {
		seen[strings.ToLower(name)] = struct{}{}
	}
	var missing []string
	for _, d := range domains {
		key := strings.ToLower(d)
		if _, dup := seen[key]; dup {
			continue
		}
		seen[key] = struct{}{}
		missing = append(missing, d)
	}
	return missing
}
// usage writes the command-line synopsis to standard error.
func usage() {
	const synopsis = "Usage: ghost 10.0.1.1 server1.gh <foo.server2.gh> <www.server2.gh> ...\n"
	fmt.Fprint(os.Stderr, synopsis)
}
// motd prints the start-up banner to stdout. In October the title and
// tagline switch to a Halloween theme.
func motd() {
	const border = "+-------------------------+"

	title, tagline := " G H O S T ", " your friendly helper with "
	if time.Now().Month() == time.October {
		// I mean, it's called 'ghost' after all...
		title, tagline = "👻 🕷️ 🎃 G H O S T 🎃 🕷️ 👻", " your spooooky helper with "
	}

	for _, row := range []string{border, title, tagline, " /etc/hosts, written in Go!", border} {
		fmt.Println(row)
	}
}