From 26a3d25dafda19e737ec738894cb6e4e73b65a2e Mon Sep 17 00:00:00 2001 From: Chris Bednarski Date: Wed, 26 Feb 2020 20:14:07 -0800 Subject: [PATCH] Add the library back. oops. --- hostess/hostfile.go | 204 +++++++++++++++ hostess/hostfile_test.go | 202 +++++++++++++++ hostess/hostlist.go | 434 ++++++++++++++++++++++++++++++++ hostess/hostlist_test.go | 247 ++++++++++++++++++ hostess/hostname.go | 101 ++++++++ hostess/hostname_test.go | 115 +++++++++ hostess/test-fixtures/hostfile1 | 13 + hostess/test-fixtures/hostfile2 | 2 + 8 files changed, 1318 insertions(+) create mode 100644 hostess/hostfile.go create mode 100644 hostess/hostfile_test.go create mode 100644 hostess/hostlist.go create mode 100644 hostess/hostlist_test.go create mode 100644 hostess/hostname.go create mode 100644 hostess/hostname_test.go create mode 100644 hostess/test-fixtures/hostfile1 create mode 100644 hostess/test-fixtures/hostfile2 diff --git a/hostess/hostfile.go b/hostess/hostfile.go new file mode 100644 index 0000000..04ca657 --- /dev/null +++ b/hostess/hostfile.go @@ -0,0 +1,204 @@ +package hostess + +import ( + "fmt" + "io/ioutil" + "os" + "runtime" + "strings" +) + +const defaultOSX = ` +## +# Host Database +# +# localhost is used to configure the loopback interface +# when the system is booting. Do not change this entry. +## + +127.0.0.1 localhost +255.255.255.255 broadcasthost +::1 localhost +fe80::1%lo0 localhost +` + +const defaultLinux = ` +127.0.0.1 localhost +127.0.1.1 HOSTNAME + +# The following lines are desirable for IPv6 capable hosts +::1 localhost ip6-localhost ip6-loopback +fe00::0 ip6-localnet +ff00::0 ip6-mcastprefix +ff02::1 ip6-allnodes +ff02::2 ip6-allrouters +ff02::3 ip6-allhosts +` + +// Hostfile represents /etc/hosts (or a similar file, depending on OS), and +// includes a list of Hostnames. Hostfile includes +type Hostfile struct { + Path string + Hosts Hostlist + data []byte +} + +// NewHostfile creates a new Hostfile object from the specified file. +func NewHostfile() *Hostfile { + return &Hostfile{GetHostsPath(), Hostlist{}, []byte{}} +} + +// GetHostsPath returns the location of the hostfile; either env HOSTESS_PATH +// or /etc/hosts if HOSTESS_PATH is not set. +func GetHostsPath() string { + path := os.Getenv("HOSTESS_PATH") + if path == "" { + if runtime.GOOS == "windows" { + path = "C:\\Windows\\System32\\drivers\\etc\\hosts" + } else { + path = "/etc/hosts" + } + } + return path +} + +// TrimWS (Trim Whitespace) removes space, newline, and tabs from a string +// using strings.Trim() +func TrimWS(s string) string { + return strings.TrimSpace(s) +} + +// ParseLine parses an individual line in a hostfile, which may contain one +// (un)commented ip and one or more hostnames. For example +// +// 127.0.0.1 localhost mysite1 mysite2 +func ParseLine(line string) (Hostlist, error) { + var hostnames Hostlist + + if len(line) == 0 { + return hostnames, fmt.Errorf("line is blank") + } + + // Parse leading # for disabled lines + enabled := true + if line[0:1] == "#" { + enabled = false + line = TrimWS(line[1:]) + } + + // Parse other #s for actual comments + line = strings.Split(line, "#")[0] + + // Replace tabs and multispaces with single spaces throughout + line = strings.Replace(line, "\t", " ", -1) + for strings.Contains(line, " ") { + line = strings.Replace(line, " ", " ", -1) + } + line = TrimWS(line) + + // Break line into words + words := strings.Split(line, " ") + for idx, word := range words { + words[idx] = TrimWS(word) + } + + // Separate the first bit (the ip) from the other bits (the domains) + ip := words[0] + domains := words[1:] + + // if LooksLikeIPv4(ip) || LooksLikeIPv6(ip) { + for _, v := range domains { + hostname, err := NewHostname(v, ip, enabled) + if err != nil { + return nil, err + } + hostnames = append(hostnames, hostname) + } + // } + + return hostnames, nil +} + +// MustParseLine is like ParseLine but panics instead of errors. +func MustParseLine(line string) Hostlist { + hostlist, err := ParseLine(line) + if err != nil { + panic(err) + } + return hostlist +} + +// Parse reads +func (h *Hostfile) Parse() []error { + var errs []error + var line = 1 + for _, v := range strings.Split(string(h.data), "\n") { + hostnames, _ := ParseLine(v) + // if err != nil { + // log.Printf("Error parsing line %d: %s\n", line, err) + // } + for _, hostname := range hostnames { + err := h.Hosts.Add(hostname) + if err != nil { + errs = append(errs, err) + } + } + line++ + } + return errs +} + +// Read the contents of the hostfile from disk +func (h *Hostfile) Read() error { + data, err := ioutil.ReadFile(h.Path) + if err == nil { + h.data = data + } + return err +} + +// LoadHostfile creates a new Hostfile struct and tries to populate it from +// disk. Read and/or parse errors are returned as a slice. +func LoadHostfile() (hostfile *Hostfile, errs []error) { + hostfile = NewHostfile() + readErr := hostfile.Read() + if readErr != nil { + errs = []error{readErr} + return + } + errs = hostfile.Parse() + hostfile.Hosts.Sort() + return +} + +// GetData returns the internal snapshot of the hostfile we read when we loaded +// this hostfile from disk (if we ever did that). This is implemented for +// testing and you probably won't need to use it. +func (h *Hostfile) GetData() []byte { + return h.data +} + +// Format takes the current list of Hostnames in this Hostfile and turns it +// into a string suitable for use as an /etc/hosts file. +// Sorting uses the following logic: +// 1. List is sorted by IP address +// 2. Commented items are left in place +// 3. 127.* appears at the top of the list (so boot resolvers don't break) +// 4. When present, localhost will always appear first in the domain list +func (h *Hostfile) Format() []byte { + return h.Hosts.Format() +} + +// Save writes the Hostfile to disk to /etc/hosts or to the location specified +// by the HOSTESS_PATH environment variable (if set). +func (h *Hostfile) Save() error { + file, err := os.OpenFile(h.Path, os.O_RDWR|os.O_APPEND|os.O_TRUNC, 0644) + if err != nil { + return err + } + + defer file.Close() + _, err = file.Write(h.Format()) + + return err +} diff --git a/hostess/hostfile_test.go b/hostess/hostfile_test.go new file mode 100644 index 0000000..667343e --- /dev/null +++ b/hostess/hostfile_test.go @@ -0,0 +1,202 @@ +package hostess_test + +import ( + "fmt" + "runtime" + "strings" + "testing" + + hostess2 "github.com/cbednarski/hostess/hostess" +) + +const ipv4Pass = ` +127.0.0.1 +127.0.1.1 +10.200.30.50 +99.99.99.99 +999.999.999.999 +0.1.1.0 +` + +const ipv4Fail = ` +1234.1.1.1 +123.5.6 +12.12 +76.76.67.67.45 +` + +const ipv6 = `` +const domain = "localhost" +const ip = "127.0.0.1" +const enabled = true + +func Diff(expected, actual string) string { + return fmt.Sprintf(` +---- Expected ---- +%s +----- Actual ----- +%s +`, expected, actual) +} + +func TestGetHostsPath(t *testing.T) { + path := hostess2.GetHostsPath() + var expected string + if runtime.GOOS == "windows" { + expected = "C:\\Windows\\System32\\drivers\\etc\\hosts" + } else { + expected = "/etc/hosts" + } + if path != expected { + t.Error("Hosts path should be " + expected) + } +} + +func TestFormatHostfile(t *testing.T) { + // The sort order here is a bit weird. + // 1. We want localhost entries at the top + // 2. The rest are sorted by IP as STRINGS, not numeric values, so 10 + // precedes 8 + const expected = `127.0.0.1 localhost devsite +127.0.1.1 ip-10-37-12-18 +# 8.8.8.8 google.com +10.37.12.18 devsite.com m.devsite.com +` + + hostfile := hostess2.NewHostfile() + hostfile.Path = "./hosts" + hostfile.Hosts.Add(hostess2.MustHostname("localhost", "127.0.0.1", true)) + hostfile.Hosts.Add(hostess2.MustHostname("ip-10-37-12-18", "127.0.1.1", true)) + hostfile.Hosts.Add(hostess2.MustHostname("devsite", "127.0.0.1", true)) + hostfile.Hosts.Add(hostess2.MustHostname("google.com", "8.8.8.8", false)) + hostfile.Hosts.Add(hostess2.MustHostname("devsite.com", "10.37.12.18", true)) + hostfile.Hosts.Add(hostess2.MustHostname("m.devsite.com", "10.37.12.18", true)) + f := string(hostfile.Format()) + if f != expected { + t.Errorf("Hostfile output is not formatted correctly: %s", Diff(expected, f)) + } +} + +func TestTrimWS(t *testing.T) { + const expected = ` candy + + ` + actual := hostess2.TrimWS(expected) + if actual != "candy" { + t.Errorf("Output was not trimmed correctly: %s", Diff(expected, actual)) + } +} + +func TestParseLineBlank(t *testing.T) { + // Blank line + hosts, err := hostess2.ParseLine("") + expected := "line is blank" + if err.Error() != expected { + t.Errorf("Expected error %q; found %q", expected, err.Error()) + } + if len(hosts) > 0 { + t.Error("Expected to find zero hostnames") + } +} + +func TestParseLineComment(t *testing.T) { + // Comment + hosts, err := hostess2.ParseLine("# The following lines are desirable for IPv6 capable hosts") + if err == nil { + t.Error(err) + } + if len(hosts) > 0 { + t.Error("Expected to find zero hostnames") + } +} + +func TestParseLineOneWordComment(t *testing.T) { + // Single word comment + hosts, err := hostess2.ParseLine("#blah") + if err != nil { + t.Error(err) + } + if len(hosts) > 0 { + t.Error("Expected to find zero hostnames") + } +} + +func TestParseLineBasicHostnameComment(t *testing.T) { + hosts, err := hostess2.ParseLine("#66.33.99.11 test.domain.com") + if err != nil { + t.Error(err) + } + if !hosts.Contains(hostess2.MustHostname("test.domain.com", "66.33.99.11", false)) || + len(hosts) != 1 { + t.Error("Expected to find test.domain.com (disabled)") + } +} + +func TestParseLineMultiHostnameComment(t *testing.T) { + hosts, err := hostess2.ParseLine("# 66.33.99.11 test.domain.com domain.com") + if err != nil { + t.Error(err) + } + if !hosts.Contains(hostess2.MustHostname("test.domain.com", "66.33.99.11", false)) || + !hosts.Contains(hostess2.MustHostname("domain.com", "66.33.99.11", false)) || + len(hosts) != 2 { + t.Error("Expected to find domain.com and test.domain.com (disabled)") + t.Errorf("Found %+v", hosts) + } +} + +func TestParseLineMultiHostname(t *testing.T) { + // Not Commented stuff + hosts, err := hostess2.ParseLine("255.255.255.255 broadcasthost test.domain.com domain.com") + if err != nil { + t.Error(err) + } + if !hosts.Contains(hostess2.MustHostname("broadcasthost", "255.255.255.255", true)) || + !hosts.Contains(hostess2.MustHostname("test.domain.com", "255.255.255.255", true)) || + !hosts.Contains(hostess2.MustHostname("domain.com", "255.255.255.255", true)) || + len(hosts) != 3 { + t.Error("Expected to find broadcasthost, domain.com, and test.domain.com (enabled)") + } +} + +func TestParseLineIPv6A(t *testing.T) { + // Ipv6 stuff + hosts, err := hostess2.ParseLine("::1 localhost") + if err != nil { + t.Error(err) + } + if !hosts.Contains(hostess2.MustHostname("localhost", "::1", true)) || + len(hosts) != 1 { + t.Error("Expected to find localhost ipv6 (enabled)") + } +} + +func TestParseLineIPv6B(t *testing.T) { + hosts, err := hostess2.ParseLine("ff02::1 ip6-allnodes") + if err != nil { + t.Error(err) + } + if !hosts.Contains(hostess2.MustHostname("ip6-allnodes", "ff02::1", true)) || + len(hosts) != 1 { + t.Error("Expected to find ip6-allnodes ipv6 (enabled)") + } +} + +func TestLoadHostfile(t *testing.T) { + hostfile := hostess2.NewHostfile() + hostfile.Read() + if !strings.Contains(string(hostfile.GetData()), domain) { + t.Errorf("Expected to find %s", domain) + } + + hostfile.Parse() + on := enabled + if runtime.GOOS == "windows" { + on = false + } + hostname := hostess2.MustHostname(domain, ip, on) + found := hostfile.Hosts.Contains(hostname) + if !found { + t.Errorf("Expected to find %#v", hostname) + } +} diff --git a/hostess/hostlist.go b/hostess/hostlist.go new file mode 100644 index 0000000..392bd56 --- /dev/null +++ b/hostess/hostlist.go @@ -0,0 +1,434 @@ +package hostess + +import ( + "encoding/json" + "errors" + "fmt" + "net" + "sort" + "strings" +) + +// ErrInvalidVersionArg is raised when a function expects IPv 4 or 6 but is +// passed a value not 4 or 6. +var ErrInvalidVersionArg = errors.New("version argument must be 4 or 6") +var ErrHostnameNotFound = errors.New("hostname not found") + +// Hostlist is a sortable set of Hostnames. When in a Hostlist, Hostnames must +// follow some rules: +// +// - Hostlist may contain IPv4 AND IPv6 ("IP version" or "IPv") Hostnames. +// - Names are only allowed to overlap if IP version is different. +// - Adding a Hostname for an existing name will replace the old one. +// +// The Hostlist uses a deterministic Sort order designed to make a hostfile +// output look a particular way. Generally you don't need to worry about this +// as Sort will be called automatically before Format. However, the Hostlist +// may or may not be sorted at any particular time during runtime. +// +// See the docs and implementation in Sort and Add for more details. +type Hostlist []*Hostname + +// NewHostlist initializes a new Hostlist +func NewHostlist() *Hostlist { + return &Hostlist{} +} + +// Len returns the number of Hostnames in the list, part of sort.Interface +func (h Hostlist) Len() int { + return len(h) +} + +// MakeSurrogateIP takes an IP like 127.0.0.1 and munges it to 0.0.0.1 so we can +// sort it more easily. Note that we don't actually want to change the value, +// so we use value copies here (not pointers). +func MakeSurrogateIP(IP net.IP) net.IP { + if len(IP.String()) > 3 && IP.String()[0:3] == "127" { + return net.ParseIP("0" + IP.String()[3:]) + } + return IP +} + +// Less determines the sort order of two Hostnames, part of sort.Interface +func (h Hostlist) Less(A, B int) bool { + // Sort IPv4 before IPv6 + // A is IPv4 and B is IPv6. A wins! + if !h[A].IPv6 && h[B].IPv6 { + return true + } + // A is IPv6 but B is IPv4. A loses! + if h[A].IPv6 && !h[B].IPv6 { + return false + } + + // Sort "localhost" at the top + if h[A].Domain == "localhost" { + return true + } + if h[B].Domain == "localhost" { + return false + } + + // Compare the the IP addresses (byte array) + // We want to push 127. to the top so we're going to mark it zero. + surrogateA := MakeSurrogateIP(h[A].IP) + surrogateB := MakeSurrogateIP(h[B].IP) + if !surrogateA.Equal(surrogateB) { + for charIndex := range surrogateA { + // A and B's IPs differ at this index, and A is less. A wins! + if surrogateA[charIndex] < surrogateB[charIndex] { + return true + } + // A and B's IPs differ at this index, and B is less. A loses! + if surrogateA[charIndex] > surrogateB[charIndex] { + return false + } + } + // If we got here then the IPs are the same and we want to continue on + // to the domain sorting section. + } + + // Prep for sorting by domain name + aLength := len(h[A].Domain) + bLength := len(h[B].Domain) + max := aLength + if bLength > max { + max = bLength + } + + // Sort domains alphabetically + // TODO: This works best if domains are lowercased. However, we do not + // enforce lowercase because of UTF-8 domain names, which may be broken by + // case folding. There is a way to do this correctly but it's complicated + // so I'm not going to do it right now. + for charIndex := 0; charIndex < max; charIndex++ { + // This index is longer than A, so A is shorter. A wins! + if charIndex >= aLength { + return true + } + // This index is longer than B, so B is shorter. A loses! + if charIndex >= bLength { + return false + } + // A and B differ at this index and A is less. A wins! + if h[A].Domain[charIndex] < h[B].Domain[charIndex] { + return true + } + // A and B differ at this index and B is less. A loses! + if h[A].Domain[charIndex] > h[B].Domain[charIndex] { + return false + } + } + + // If we got here then A and B are the same -- by definition A is not Less + // than B so we return false. Technically we shouldn't get here since Add + // should not allow duplicates, but we'll guard anyway. + return false +} + +// Swap changes the position of two Hostnames, part of sort.Interface +func (h Hostlist) Swap(i, j int) { + h[i], h[j] = h[j], h[i] +} + +// Sort this list of Hostnames, according to Hostlist sorting rules: +// +// 1. localhost comes before other hostnames +// 2. IPv4 comes before IPv6 +// 3. IPs are sorted in numerical order +// 4. The remaining hostnames are sorted in lexicographical order +func (h *Hostlist) Sort() { + sort.Sort(*h) +} + +// Contains returns true if this Hostlist has the specified Hostname +func (h *Hostlist) Contains(b *Hostname) bool { + for _, a := range *h { + if a.Equal(b) { + return true + } + } + return false +} + +// ContainsDomain returns true if a Hostname in this Hostlist matches domain +func (h *Hostlist) ContainsDomain(domain string) bool { + for _, hostname := range *h { + if hostname.Domain == domain { + return true + } + } + return false +} + +// ContainsIP returns true if a Hostname in this Hostlist matches IP +func (h *Hostlist) ContainsIP(IP net.IP) bool { + for _, hostname := range *h { + if hostname.EqualIP(IP) { + return true + } + } + return false +} + +// Add a new Hostname to this hostlist. Add uses some merging logic in the +// event it finds duplicated hostnames. In the case of a conflict (incompatible +// entries) the last write wins. In the case of duplicates, duplicates will be +// removed and the remaining entry will be enabled if any of the duplicates was +// enabled. +// +// Both duplicate and conflicts return errors so you are aware of them, but you +// don't necessarily need to do anything about the error. +func (h *Hostlist) Add(input *Hostname) error { + newHostname, err := NewHostname(input.Domain, input.IP.String(), input.Enabled) + if err != nil { + return err + } + for index, found := range *h { + if found.Equal(newHostname) { + // If either hostname is enabled we will set the existing one to + // enabled state. That way if we add a hostname from the end of a + // hosts file it will take over, and if we later add a disabled one + // the original one will stick. We still error in this case so the + // user can see that there is a duplicate. + (*h)[index].Enabled = found.Enabled || newHostname.Enabled + return fmt.Errorf("duplicate hostname entry for %s -> %s", + newHostname.Domain, newHostname.IP) + } else if found.Domain == newHostname.Domain && found.IPv6 == newHostname.IPv6 { + (*h)[index] = newHostname + return fmt.Errorf("conflicting hostname entries for %s -> %s and -> %s", + newHostname.Domain, newHostname.IP, found.IP) + } + } + *h = append(*h, newHostname) + return nil +} + +// IndexOf will indicate the index of a Hostname in Hostlist, or -1 if it is +// not found. +func (h *Hostlist) IndexOf(host *Hostname) int { + for index, found := range *h { + if found.Equal(host) { + return index + } + } + return -1 +} + +// IndexOfDomainV will indicate the index of a Hostname in Hostlist that has +// the same domain and IP version, or -1 if it is not found. +// +// This function will panic if IP version is not 4 or 6. +func (h *Hostlist) IndexOfDomainV(domain string, version int) int { + if version != 4 && version != 6 { + panic(ErrInvalidVersionArg) + } + for index, hostname := range *h { + if hostname.Domain == domain && hostname.IPv6 == (version == 6) { + return index + } + } + return -1 +} + +// Remove will delete the Hostname at the specified index. If index is out of +// bounds (i.e. -1), Remove silently no-ops. Remove returns the number of items +// removed (0 or 1). +func (h *Hostlist) Remove(index int) int { + if index > -1 && index < len(*h) { + *h = append((*h)[:index], (*h)[index+1:]...) + return 1 + } + return 0 +} + +// RemoveDomain removes both IPv4 and IPv6 Hostname entries matching domain. +// Returns the number of entries removed. +func (h *Hostlist) RemoveDomain(domain string) int { + return h.RemoveDomainV(domain, 4) + h.RemoveDomainV(domain, 6) +} + +// RemoveDomainV removes a Hostname entry matching the domain and IP version. +func (h *Hostlist) RemoveDomainV(domain string, version int) int { + return h.Remove(h.IndexOfDomainV(domain, version)) +} + +// Enable will change any Hostnames matching name to be enabled. +func (h *Hostlist) Enable(name string) error { + for _, hostname := range *h { + if hostname.Domain == name { + hostname.Enabled = true + return nil + } + } + return ErrHostnameNotFound +} + +// EnableV will change a Hostname matching domain and IP version to be enabled. +// +// This function will panic if IP version is not 4 or 6. +func (h *Hostlist) EnableV(domain string, version int) error { + if version != 4 && version != 6 { + return ErrInvalidVersionArg + } + for _, hostname := range *h { + if hostname.Domain == domain && hostname.IPv6 == (version == 6) { + hostname.Enabled = true + return nil + } + } + return ErrHostnameNotFound +} + +// Disable will change any Hostnames matching name to be disabled. +func (h *Hostlist) Disable(name string) error { + for _, hostname := range *h { + if hostname.Domain == name { + hostname.Enabled = false + return nil + } + } + return ErrHostnameNotFound +} + +// DisableV will change any Hostnames matching domain and IP version to be disabled. +// +// This function will panic if IP version is not 4 or 6. +func (h *Hostlist) DisableV(domain string, version int) error { + if version != 4 && version != 6 { + return ErrInvalidVersionArg + } + for _, hostname := range *h { + if hostname.Domain == domain && hostname.IPv6 == (version == 6) { + hostname.Enabled = false + return nil + } + } + return ErrHostnameNotFound +} + +// FilterByIP filters the list of hostnames by IP address. +func (h *Hostlist) FilterByIP(IP net.IP) (hostnames []*Hostname) { + for _, hostname := range *h { + if hostname.IP.Equal(IP) { + hostnames = append(hostnames, hostname) + } + } + return +} + +// FilterByDomain filters the list of hostnames by Domain. +func (h *Hostlist) FilterByDomain(domain string) (hostnames []*Hostname) { + for _, hostname := range *h { + if hostname.Domain == domain { + hostnames = append(hostnames, hostname) + } + } + return +} + +// FilterByDomainV filters the list of hostnames by domain and IPv4 or IPv6. +// This should never contain more than one item, but returns a list for +// consistency with other filter functions. +// +// This function will panic if IP version is not 4 or 6. +func (h *Hostlist) FilterByDomainV(domain string, version int) (hostnames []*Hostname) { + if version != 4 && version != 6 { + panic(ErrInvalidVersionArg) + } + for _, hostname := range *h { + if hostname.Domain == domain && hostname.IPv6 == (version == 6) { + hostnames = append(hostnames, hostname) + } + } + return +} + +// GetUniqueIPs extracts an ordered list of unique IPs from the Hostlist. +// This calls Sort() internally. +func (h *Hostlist) GetUniqueIPs() []net.IP { + h.Sort() + // A map doesn't preserve order so we're going to use the map to check + // whether we've seen something and use the list to keep track of the + // order. + seen := make(map[string]bool) + inOrder := []net.IP{} + + for _, hostname := range *h { + key := (*hostname).IP.String() + if !seen[key] { + seen[key] = true + inOrder = append(inOrder, (*hostname).IP) + } + } + return inOrder +} + +// Format takes the current list of Hostnames in this Hostfile and turns it +// into a string suitable for use as an /etc/hosts file. +// Sorting uses the following logic: +// +// 1. List is sorted by IP address +// 2. Commented items are sorted displayed +// 3. 127.* appears at the top of the list (so boot resolvers don't break) +// 4. When present, "localhost" will always appear first in the domain list +func (h *Hostlist) Format() []byte { + h.Sort() + out := []byte{} + + // We want to output one line of hostnames per IP, so first we get that + // list of IPs and iterate. + for _, IP := range h.GetUniqueIPs() { + // Technically if an IP has some disabled hostnames we'll show two + // lines, one starting with a comment (#). + enabledIPs := []string{} + disabledIPs := []string{} + + // For this IP, get all hostnames that match and iterate over them. + for _, hostname := range h.FilterByIP(IP) { + // If it's enabled, put it in the enabled bucket (likewise for + // disabled hostnames) + if hostname.Enabled { + enabledIPs = append(enabledIPs, hostname.Domain) + } else { + disabledIPs = append(disabledIPs, hostname.Domain) + } + } + + // Finally, if the bucket contains anything, concatenate it all + // together and append it to the output. Also add a newline. + if len(enabledIPs) > 0 { + concat := fmt.Sprintf("%s %s", IP.String(), strings.Join(enabledIPs, " ")) + out = append(out, []byte(concat)...) + out = append(out, []byte("\n")...) + } + + if len(disabledIPs) > 0 { + concat := fmt.Sprintf("# %s %s", IP.String(), strings.Join(disabledIPs, " ")) + out = append(out, []byte(concat)...) + out = append(out, []byte("\n")...) + } + } + + return out +} + +// Dump exports all entries in the Hostlist as JSON +func (h *Hostlist) Dump() ([]byte, error) { + return json.MarshalIndent(h, "", " ") +} + +// Apply imports all entries from the JSON input to this Hostlist +func (h *Hostlist) Apply(jsonbytes []byte) error { + var hostnames Hostlist + err := json.Unmarshal(jsonbytes, &hostnames) + if err != nil { + return err + } + + for _, hostname := range hostnames { + h.Add(hostname) + } + + return nil +} diff --git a/hostess/hostlist_test.go b/hostess/hostlist_test.go new file mode 100644 index 0000000..48a050e --- /dev/null +++ b/hostess/hostlist_test.go @@ -0,0 +1,247 @@ +package hostess_test + +import ( + "bytes" + "fmt" + "net" + "testing" + + "github.com/cbednarski/hostess/hostess" +) + +func TestAddDuplicate(t *testing.T) { + list := hostess.NewHostlist() + + hostname := hostess.MustHostname("mysite", "1.2.3.4", false) + if err := list.Add(hostname); err != nil { + t.Error(err) + } + + hostname.Enabled = true + if err := list.Add(hostname); err == nil { + t.Error("Expected error because of duplicate entry") + } + + if !(*list)[0].Enabled { + t.Error("Expected hostname to be in enabled state") + } +} + +func TestAddConflict(t *testing.T) { + hostnameA := hostess.MustHostname("mysite", "1.2.3.4", true) + hostnameB := hostess.MustHostname("mysite", "5.2.3.4", false) + + list := hostess.NewHostlist() + list.Add(hostnameA) + if err := list.Add(hostnameB); err == nil { + t.Errorf("Expected conflict error") + } + + if !(*list)[0].Equal(hostnameB) { + t.Error("Expected second hostname to overwrite") + } + if (*list)[0].Enabled { + t.Error("Expected second hostname to be disabled") + } +} + +func TestMakeSurrogateIP(t *testing.T) { + original := net.ParseIP("127.0.0.1") + expected1 := net.ParseIP("0.0.0.1") + IP1 := hostess.MakeSurrogateIP(original) + if !IP1.Equal(expected1) { + t.Errorf("Expected %s to convert to %s; got %s", original, expected1, IP1) + } + + expected2 := net.ParseIP("10.20.30.40") + IP2 := hostess.MakeSurrogateIP(expected2) + if !IP2.Equal(expected2) { + t.Errorf("Expected %s to remain unchanged; got %s", expected2, IP2) + } +} + +func TestContainsDomainIp(t *testing.T) { + hosts := hostess.NewHostlist() + hosts.Add(hostess.MustHostname(domain, ip, false)) + hosts.Add(hostess.MustHostname("google.com", "8.8.8.8", true)) + + if !hosts.ContainsDomain(domain) { + t.Errorf("Expected to find %s", domain) + } + + const extraneousDomain = "yahoo.com" + if hosts.ContainsDomain(extraneousDomain) { + t.Errorf("Did not expect to find %s", extraneousDomain) + } + + var expectedIP = net.ParseIP(ip) + if !hosts.ContainsIP(expectedIP) { + t.Errorf("Expected to find %s", ip) + } + + var extraneousIP = net.ParseIP("1.2.3.4") + if hosts.ContainsIP(extraneousIP) { + t.Errorf("Did not expect to find %s", extraneousIP) + } + + expectedHostname := hostess.MustHostname(domain, ip, true) + if !hosts.Contains(expectedHostname) { + t.Errorf("Expected to find %+v", expectedHostname) + } + + extraneousHostname := hostess.MustHostname("yahoo.com", "4.3.2.1", false) + if hosts.Contains(extraneousHostname) { + t.Errorf("Did not expect to find %+v", extraneousHostname) + } +} + +func TestFormat(t *testing.T) { + hosts := hostess.NewHostlist() + hosts.Add(hostess.MustHostname(domain, ip, false)) + hosts.Add(hostess.MustHostname("google.com", "8.8.8.8", true)) + + expected := `# 127.0.0.1 localhost +8.8.8.8 google.com +` + if string(hosts.Format()) != expected { + t.Error("Formatted hosts list is not formatted correctly") + } +} + +func TestRemove(t *testing.T) { + hosts := hostess.NewHostlist() + hosts.Add(hostess.MustHostname(domain, ip, false)) + hosts.Add(hostess.MustHostname("google.com", "8.8.8.8", true)) + + removed := hosts.Remove(1) + if removed != 1 { + t.Error("Expected to remove 1 item") + } + if len(*hosts) > 1 { + t.Errorf("Expected hostlist to have 1 item, found %d", len(*hosts)) + } + if hosts.ContainsDomain("google.com") { + t.Errorf("Expected not to find google.com") + } + + hosts.Add(hostess.MustHostname(domain, "::1", enabled)) + removed = hosts.RemoveDomain(domain) + if removed != 2 { + t.Error("Expected to remove 2 items") + } +} + +func TestRemoveDomain(t *testing.T) { + hosts := hostess.NewHostlist() + h1 := hostess.MustHostname("google.com", "127.0.0.1", false) + h2 := hostess.MustHostname("google.com", "::1", true) + hosts.Add(h1) + hosts.Add(h2) + + hosts.RemoveDomainV("google.com", 4) + if hosts.Contains(h1) { + t.Error("Should not contain ipv4 hostname") + } + if !hosts.Contains(h2) { + t.Error("Should still contain ipv6 hostname") + } + + hosts.RemoveDomainV("google.com", 6) + if len(*hosts) != 0 { + t.Error("Should no longer contain any hostnames") + } +} + +func CheckIndexDomain(t *testing.T, index int, domain string, hosts *hostess.Hostlist) { + if (*hosts)[index].Domain != domain { + t.Errorf("Expected %s to be in position %d. Found: %s", domain, index, (*hosts)[index].FormatHuman()) + } +} + +func TestSort(t *testing.T) { + // Getting 100% coverage on this is kinda tricky. It's pretty close and + // this is already too long. + + hosts := hostess.NewHostlist() + hosts.Add(hostess.MustHostname("google.com", "8.8.8.8", true)) + hosts.Add(hostess.MustHostname("google3.com", "::1", true)) + hosts.Add(hostess.MustHostname(domain, ip, false)) + hosts.Add(hostess.MustHostname("google2.com", "8.8.4.4", true)) + hosts.Add(hostess.MustHostname("blah2", "10.20.1.1", true)) + hosts.Add(hostess.MustHostname("blah3", "10.20.1.1", true)) + hosts.Add(hostess.MustHostname("blah33", "10.20.1.1", true)) + hosts.Add(hostess.MustHostname("blah", "10.20.1.1", true)) + hosts.Add(hostess.MustHostname("hostname", "127.0.1.1", true)) + hosts.Add(hostess.MustHostname("devsite", "127.0.0.1", true)) + + hosts.Sort() + + CheckIndexDomain(t, 0, "localhost", hosts) + CheckIndexDomain(t, 1, "devsite", hosts) + CheckIndexDomain(t, 2, "hostname", hosts) + CheckIndexDomain(t, 3, "google2.com", hosts) + CheckIndexDomain(t, 4, "google.com", hosts) + CheckIndexDomain(t, 5, "blah", hosts) + CheckIndexDomain(t, 6, "blah2", hosts) + CheckIndexDomain(t, 7, "blah3", hosts) + CheckIndexDomain(t, 8, "blah33", hosts) + CheckIndexDomain(t, 9, "google3.com", hosts) +} + +func ExampleHostlist() { + hosts := hostess.NewHostlist() + hosts.Add(hostess.MustHostname("google.com", "127.0.0.1", false)) + hosts.Add(hostess.MustHostname("google.com", "::1", true)) + + fmt.Printf("%s\n", hosts.Format()) + // Output: + // # 127.0.0.1 google.com + // ::1 google.com +} + +const hostsjson = `[ + { + "domain": "google.com", + "ip": "127.0.0.1", + "enabled": false + }, + { + "domain": "google.com", + "ip": "::1", + "enabled": true + } +]` + +func TestDump(t *testing.T) { + hosts := hostess.NewHostlist() + hosts.Add(hostess.MustHostname("google.com", "127.0.0.1", false)) + hosts.Add(hostess.MustHostname("google.com", "::1", true)) + + expected := []byte(hostsjson) + actual, _ := hosts.Dump() + + if !bytes.Equal(actual, expected) { + t.Errorf("JSON output did not match expected output: %s", Diff(string(expected), string(actual))) + } + +} + +func TestApply(t *testing.T) { + hosts := hostess.NewHostlist() + hosts.Apply([]byte(hostsjson)) + + hostnameA := hostess.MustHostname("google.com", "127.0.0.1", false) + if !hosts.Contains(hostnameA) { + t.Errorf("Expected to find %s", hostnameA.Format()) + } + + hostnameB := hostess.MustHostname("google.com", "::1", true) + if !hosts.Contains(hostnameB) { + t.Errorf("Expected to find %s", hostnameB.Format()) + } + + hosts.Apply([]byte(hostsjson)) + if hosts.Len() != 2 { + t.Error("Hostslist contains the wrong number of items, expected 2") + } +} diff --git a/hostess/hostname.go b/hostess/hostname.go new file mode 100644 index 0000000..687f3b2 --- /dev/null +++ b/hostess/hostname.go @@ -0,0 +1,101 @@ +package hostess + +import ( + "fmt" + "net" + "regexp" + "strings" +) + +var ipv4Pattern = regexp.MustCompile(`^(?:[0-9]{1,3}\.){3}[0-9]{1,3}$`) +var ipv6Pattern = regexp.MustCompile(`^[(a-fA-F0-9){1-4}:]+$`) + +// LooksLikeIPv4 returns true if the IP looks like it's IPv4. This does not +// validate whether the string is a valid IP address. +func LooksLikeIPv4(ip string) bool { + return ipv4Pattern.MatchString(ip) +} + +// LooksLikeIPv6 returns true if the IP looks like it's IPv6. This does not +// validate whether the string is a valid IP address. +func LooksLikeIPv6(ip string) bool { + if !strings.Contains(ip, ":") { + return false + } + return ipv6Pattern.MatchString(ip) +} + +// Hostname represents a hosts file entry, including a Domain, IP, whether the +// Hostname is enabled (uncommented in the hosts file), and whether the IP is +// in the IPv6 format. You should always create these with NewHostname(). Note: +// when using Hostnames in the context of a Hostlist, you should not change the +// Hostname fields except through the Hostlist's aggregate methods. Doing so +// can cause unexpected behavior. Instead, use Hostlist's Add, Remove, Enable, +// and Disable methods. +type Hostname struct { + Domain string `json:"domain"` + IP net.IP `json:"ip"` + Enabled bool `json:"enabled"` + IPv6 bool `json:"-"` +} + +// NewHostname creates a new Hostname struct and automatically sets the IPv6 +// field based on the IP you pass in. +func NewHostname(domain, ip string, enabled bool) (*Hostname, error) { + if !LooksLikeIPv4(ip) && !LooksLikeIPv6(ip) { + return nil, fmt.Errorf("Unable to parse IP address %q", ip) + } + IP := net.ParseIP(ip) + return &Hostname{domain, IP, enabled, LooksLikeIPv6(ip)}, nil +} + +// MustHostname calls NewHostname but panics if there is an error parsing it. +func MustHostname(domain, ip string, enabled bool) *Hostname { + hostname, err := NewHostname(domain, ip, enabled) + if err != nil { + panic(err) + } + return hostname +} + +// Equal compares two Hostnames. Note that only the Domain and IP fields are +// compared because Enabled is transient state, and IPv6 should be set +// automatically based on IP. +func (h *Hostname) Equal(n *Hostname) bool { + return h.Domain == n.Domain && h.IP.Equal(n.IP) +} + +// EqualIP compares an IP against this Hostname. +func (h *Hostname) EqualIP(ip net.IP) bool { + return h.IP.Equal(ip) +} + +// IsValid does a spot-check on the domain and IP to make sure they aren't blank +func (h *Hostname) IsValid() bool { + return h.Domain != "" && h.IP != nil +} + +// Format outputs the Hostname as you'd see it in a hosts file, with a comment +// if it is disabled. E.g. +// # 127.0.0.1 blah.example.com +func (h *Hostname) Format() string { + r := fmt.Sprintf("%s %s", h.IP.String(), h.Domain) + if !h.Enabled { + r = "# " + r + } + return r +} + +// FormatEnabled displays Hostname.Enabled as (On) or (Off) +func (h *Hostname) FormatEnabled() string { + if h.Enabled { + return "(On)" + } + return "(Off)" +} + +// FormatHuman outputs the Hostname in a more human-readable format: +// blah.example.com -> 127.0.0.1 (Off) +func (h *Hostname) FormatHuman() string { + return fmt.Sprintf("%s -> %s %s", h.Domain, h.IP, h.FormatEnabled()) +} diff --git a/hostess/hostname_test.go b/hostess/hostname_test.go new file mode 100644 index 0000000..9a134f6 --- /dev/null +++ b/hostess/hostname_test.go @@ -0,0 +1,115 @@ +package hostess_test + +import ( + "net" + "testing" + + hostess2 "github.com/cbednarski/hostess/hostess" +) + +func TestHostname(t *testing.T) { + h := hostess2.MustHostname(domain, ip, enabled) + + if h.Domain != domain { + t.Errorf("Domain should be %s", domain) + } + if !h.IP.Equal(net.ParseIP(ip)) { + t.Errorf("IP should be %s", ip) + } + if h.Enabled != enabled { + t.Errorf("Enabled should be %t", enabled) + } +} + +func TestEqual(t *testing.T) { + a := hostess2.MustHostname("localhost", "127.0.0.1", true) + b := hostess2.MustHostname("localhost", "127.0.0.1", false) + c := hostess2.MustHostname("localhost", "127.0.1.1", false) + + if !a.Equal(b) { + t.Errorf("%+v and %+v should be equal", a, b) + } + if a.Equal(c) { + t.Errorf("%+v and %+v should not be equal", a, c) + } +} + +func TestEqualIP(t *testing.T) { + a := hostess2.MustHostname("localhost", "127.0.0.1", true) + c := hostess2.MustHostname("localhost", "127.0.1.1", false) + ip := net.ParseIP("127.0.0.1") + + if !a.EqualIP(ip) { + t.Errorf("%s and %s should be equal", a.IP, ip) + } + if a.EqualIP(c.IP) { + t.Errorf("%s and %s should not be equal", a.IP, c.IP) + } +} + +func TestIsValid(t *testing.T) { + hostname := &hostess2.Hostname{ + Domain: "localhost", + IP: net.ParseIP("127.0.0.1"), + Enabled: true, + IPv6: true, + } + if !hostname.IsValid() { + t.Fatalf("%+v should be a valid hostname", hostname) + } +} + +func TestIsValidBlank(t *testing.T) { + hostname := &hostess2.Hostname{ + Domain: "", + IP: net.ParseIP("127.0.0.1"), + Enabled: true, + IPv6: true, + } + if hostname.IsValid() { + t.Errorf("%+v should be invalid because the name is blank", hostname) + } +} +func TestIsValidBadIP(t *testing.T) { + hostname := &hostess2.Hostname{ + Domain: "localhost", + IP: net.ParseIP("localhost"), + Enabled: true, + IPv6: true, + } + if hostname.IsValid() { + t.Errorf("%+v should be invalid because the ip is malformed", hostname) + } +} + +func TestFormatHostname(t *testing.T) { + hostname := hostess2.MustHostname(domain, ip, enabled) + + const exp_enabled = "127.0.0.1 localhost" + if hostname.Format() != exp_enabled { + t.Errorf("Hostname format doesn't match desired output: %s", Diff(hostname.Format(), exp_enabled)) + } + + hostname.Enabled = false + const exp_disabled = "# 127.0.0.1 localhost" + if hostname.Format() != exp_disabled { + t.Errorf("Hostname format doesn't match desired output: %s", Diff(hostname.Format(), exp_disabled)) + } +} + +func TestFormatEnabled(t *testing.T) { + hostname := hostess2.MustHostname(domain, ip, enabled) + const expectedOn = "(On)" + if hostname.FormatEnabled() != expectedOn { + t.Errorf("Expected hostname to be turned %s", expectedOn) + } + const expectedHumanOn = "localhost -> 127.0.0.1 (On)" + if hostname.FormatHuman() != expectedHumanOn { + t.Errorf("Unexpected output%s", Diff(expectedHumanOn, hostname.FormatHuman())) + } + + hostname.Enabled = false + if hostname.FormatEnabled() != "(Off)" { + t.Error("Expected hostname to be turned (Off)") + } +} diff --git a/hostess/test-fixtures/hostfile1 b/hostess/test-fixtures/hostfile1 new file mode 100644 index 0000000..1172aa2 --- /dev/null +++ b/hostess/test-fixtures/hostfile1 @@ -0,0 +1,13 @@ +#192.168.0.1 pie.dev.example.com +192.168.0.2 cookie.example.com +::1 hostname.pie hostname.candy cake.example.com +fe:23b3:890e:342e::ef strawberry.pie.example.com +# fe:23b3:890e:342e::ef dev.strawberry.pie.example.com +192.168.1.3 pie.example.com +127.0.1.1 robobrain +# fe:23b3:890e:342e::ef chocolate.cake.example.com chocolate.ru.example.com chocolate.tr.example.com chocolate.cookie.example.com +fe:23b3:890e:342e::ef chocolate.pie.example.com +::1 localhost +127.0.0.1 localhost +192.168.1.1 pie.example.com +192.168.1.1 strawberry.pie.example.com \ No newline at end of file diff --git a/hostess/test-fixtures/hostfile2 b/hostess/test-fixtures/hostfile2 new file mode 100644 index 0000000..eb20dbf --- /dev/null +++ b/hostess/test-fixtures/hostfile2 @@ -0,0 +1,2 @@ +# entries: 2361 +0.0.0.0 101com.com