137 lines
3.4 KiB
Go
137 lines
3.4 KiB
Go
package main
|
|
|
|
import (
|
|
"fmt"
|
|
"net"
|
|
"os"
|
|
"regexp"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
func main() {
|
|
targetIP, domains := parseArgs()
|
|
motd()
|
|
process(targetIP, domains)
|
|
}
|
|
|
|
func parseArgs() (net.IP, []string) {
|
|
// Check for arguments
|
|
if len(os.Args) < 3 {
|
|
usage()
|
|
os.Exit(1)
|
|
}
|
|
|
|
// Parse IP
|
|
rawIP := os.Args[1]
|
|
ip := net.ParseIP(rawIP)
|
|
if ip == nil {
|
|
fmt.Fprintf(os.Stderr, "Not an IP address: %s\n", rawIP)
|
|
usage()
|
|
os.Exit(1)
|
|
}
|
|
|
|
// Check for formatting of domain names
|
|
domainExpr, _ := regexp.Compile("^([a-zA-Z0-9]+[a-zA-Z0-9\\-]*[a-zA-Z0-9]+.{0,1})+$")
|
|
for _, a := range os.Args[2:] {
|
|
if !domainExpr.MatchString(a) {
|
|
fmt.Fprintf(os.Stderr, "Doesn't look like a valid domain: %s\n", a)
|
|
usage()
|
|
os.Exit(1)
|
|
}
|
|
}
|
|
|
|
// All parsed and correct
|
|
return ip, os.Args[2:]
|
|
}
|
|
|
|
// process maps the given domain names to ip in /etc/hosts.
//
// It reads /etc/hosts, lets mergeHosts compute the new content, prints a
// one-line status message to stdout and writes the file back only when the
// content actually changed. A read or write failure is reported on stderr and
// terminates the process with status 1.
func process(ip net.IP, domains []string) {
	// Read file
	raw, readErr := os.ReadFile("/etc/hosts")
	if readErr != nil {
		fmt.Fprintf(os.Stderr, "Failed to read /etc/hosts: %s\n", readErr.Error())
		os.Exit(1)
	}

	updated, msg, changed := mergeHosts(string(raw), ip, domains, time.Now())

	// Inform user
	fmt.Println(msg)
	if !changed {
		// Nothing to add, so leave the file untouched.
		return
	}

	// Write out again
	if writeErr := os.WriteFile("/etc/hosts", []byte(updated), 0o644); writeErr != nil {
		fmt.Fprintf(os.Stderr, "Failed to write /etc/hosts: %s\n", writeErr.Error())
		os.Exit(1)
	}
}

// mergeHosts returns hosts-file content in which domains are mapped to ip.
//
// The first line whose address field equals ip is extended with those domains
// it does not list yet; its trailing comment, if any, is kept as is. When no
// line carries ip, a new line stamped with now (formatted DD-MM-YYYY) is
// appended and the result ends with a newline.
//
// It returns the resulting content, a status message for the user, and
// whether the content differs from the input.
func mergeHosts(content string, ip net.IP, domains []string, now time.Time) (updated, msg string, changed bool) {
	ipString := ip.String()

	// Drop names repeated within the request itself, keeping their order.
	requested := make([]string, 0, len(domains))
	seen := make(map[string]struct{}, len(domains))
	for _, d := range domains {
		if _, dup := seen[d]; dup {
			continue
		}
		seen[d] = struct{}{}
		requested = append(requested, d)
	}
	if len(requested) == 0 {
		return content, fmt.Sprintf("No domains given for '%s', nothing to do", ipString), false
	}

	lines := strings.Split(content, "\n")
	for i, line := range lines {
		// Only the part before '#' holds the address and its names.
		entry, comment, hasComment := strings.Cut(line, "#")
		fields := strings.Fields(entry)
		if len(fields) == 0 {
			continue
		}
		// Compare parsed addresses rather than string prefixes, so that
		// 10.0.1.1 does not match a line for 10.0.1.10.
		if lineIP := net.ParseIP(fields[0]); lineIP == nil || !lineIP.Equal(ip) {
			continue
		}

		// Avoid duplicates: keep only the names this line does not list yet.
		present := make(map[string]struct{}, len(fields)-1)
		for _, f := range fields[1:] {
			present[f] = struct{}{}
		}
		var missing []string
		for _, d := range requested {
			if _, ok := present[d]; !ok {
				missing = append(missing, d)
			}
		}
		if len(missing) == 0 {
			return content, fmt.Sprintf("The line for '%s' already includes '%s'", ipString, strings.Join(requested, " ")), false
		}

		// Rebuild the line; trimming first keeps the spacing stable across
		// repeated runs, and the comment text is re-attached untouched.
		added := strings.Join(missing, " ")
		rebuilt := strings.TrimRight(entry, " \t\r") + " " + added
		if hasComment {
			rebuilt += " #" + comment
		}
		lines[i] = rebuilt
		return strings.Join(lines, "\n"), fmt.Sprintf("Updated the line for '%s' to include '%s'", ipString, added), true
	}

	// No line for this address yet: append a new one.
	added := strings.Join(requested, " ")
	newLine := fmt.Sprintf("%s\t%s # Added by ghost on %s", ipString, added, now.Format("02-01-2006"))
	if content != "" && !strings.HasSuffix(content, "\n") {
		content += "\n"
	}
	return content + newLine + "\n", fmt.Sprintf("Appended line for '%s' to include '%s'", ipString, added), true
}
|
|
|
|
// usage writes a one-line invocation example to stderr.
func usage() {
	const text = "Usage: ghost 10.0.1.1 server1.gh <foo.server2.gh> <www.server2.gh> ...\n"
	fmt.Fprint(os.Stderr, text)
}
|
|
|
|
// motd prints the start-up banner to stdout. During October the title and
// tagline are swapped for a Halloween variant; the frame and the last text
// line are the same all year.
func motd() {
	const border = "+-------------------------+"

	title := " G H O S T "
	tagline := " your friendly helper with "
	if time.Now().Month() == time.October {
		// I mean, it's called 'ghost' after all...
		title = "👻 🕷️ 🎃 G H O S T 🎃 🕷️ 👻"
		tagline = " your spooooky helper with "
	}

	fmt.Println(border)
	fmt.Println(title)
	fmt.Println(tagline)
	fmt.Println(" /etc/hosts, written in Go!")
	fmt.Println(border)
}
|