diff --git a/commands.go b/commands.go index a0f47bf..1987a7e 100644 --- a/commands.go +++ b/commands.go @@ -66,8 +66,11 @@ func Add(c *cli.Context) { } hostsfile := MaybeLoadHostFile(c) - hostname := Hostname{c.Args()[0], c.Args()[1], true} - var err error + hostname, err := NewHostname(c.Args()[0], c.Args()[1], true) + if err != nil { + MaybeError(c, fmt.Sprintf("%s is not a valid ip address", c.Args()[1])) + } + if !hostsfile.Contains(hostname) { err = hostsfile.Add(hostname) } @@ -76,7 +79,7 @@ func Add(c *cli.Context) { if c.Bool("n") { fmt.Println(hostsfile.Format()) } else { - MaybePrintln(c, fmt.Sprintf("Added %s", ShowHostname(hostname))) + MaybePrintln(c, fmt.Sprintf("Added %s", ShowHostname(*hostname))) hostsfile.Save() } } else { @@ -153,7 +156,7 @@ func Ls(c *cli.Context) { hostname := hostsfile.Hosts[domain] fmt.Printf("%s -> %s %s\n", StrPadRight(hostname.Domain, maxdomain), - StrPadRight(hostname.Ip, maxip), + StrPadRight(hostname.Ip.String(), maxip), ShowEnabled(hostname.Enabled)) } } diff --git a/hostfile.go b/hostfile.go index 4a4238d..cec2090 100644 --- a/hostfile.go +++ b/hostfile.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "io/ioutil" + "net" "os" "sort" "strings" @@ -83,8 +84,8 @@ func TrimWS(s string) string { return strings.Trim(s, " \n\t") } -func ParseLine(line string) []Hostname { - var hostnames []Hostname +func ParseLine(line string) []*Hostname { + var hostnames []*Hostname if len(line) == 0 { return hostnames @@ -115,7 +116,10 @@ func ParseLine(line string) []Hostname { if LooksLikeIpv4(ip) || LooksLikeIpv6(ip) { for _, v := range domains { - hostnames = append(hostnames, Hostname{v, ip, enabled}) + hostname, err := NewHostname(v, ip, enabled) + if err == nil { + hostnames = append(hostnames, hostname) + } } } @@ -148,17 +152,17 @@ func MoveToFront(list []string, search string) []string { // ListDomainsByIp will look through Hostfile to find domains that match the // specified Ip and return them in a sorted slice. -func (h *Hostfile) ListDomainsByIp(ip string) []string { +func (h *Hostfile) ListDomainsByIp(ip net.IP) []string { var names []string for _, v := range h.Hosts { - if v.Ip == ip { + if v.Ip.Equal(ip) { names = append(names, v.Domain) } } sort.Strings(names) // Magic for localhost only, to make sure it's the first domain on its line - if ip == "127.0.0.1" { + if ip.Equal(net.ParseIP("127.0.0.1")) { names = MoveToFront(names, "localhost") } @@ -192,10 +196,10 @@ func (h *Hostfile) Format() string { // 127.0.0.1 = [localhost, blah, blah2] // 2.2.2.3 = [domain1, domain2] for _, hostname := range h.Hosts { - if hostname.Ip[0:4] == "127." { - localhosts[hostname.Ip] = append(localhosts[hostname.Ip], hostname.Domain) + if hostname.Ip.String()[0:4] == "127." { + localhosts[hostname.Ip.String()] = append(localhosts[hostname.Ip.String()], hostname.Domain) } else { - ips[hostname.Ip] = append(ips[hostname.Ip], hostname.Domain) + ips[hostname.Ip.String()] = append(ips[hostname.Ip.String()], hostname.Domain) } } @@ -208,9 +212,10 @@ func (h *Hostfile) Format() string { enabled_b := false disabled := "# " + ip disabled_b := false - for _, domain := range h.ListDomainsByIp(ip) { + IP := net.ParseIP(ip) + for _, domain := range h.ListDomainsByIp(IP) { hostname := *h.Hosts[domain] - if hostname.Ip == ip { + if hostname.Ip.Equal(IP) { if hostname.Enabled { enabled += " " + hostname.Domain enabled_b = true @@ -233,9 +238,10 @@ func (h *Hostfile) Format() string { enabled_b := false disabled := "# " + ip disabled_b := false - for _, domain := range h.ListDomainsByIp(ip) { + IP := net.ParseIP(ip) + for _, domain := range h.ListDomainsByIp(IP) { hostname := *h.Hosts[domain] - if hostname.Ip == ip { + if hostname.Ip.Equal(IP) { if hostname.Enabled { enabled += " " + hostname.Domain enabled_b = true @@ -261,9 +267,9 @@ func (h *Hostfile) Save() error { return nil } -func (h *Hostfile) Contains(b Hostname) bool { +func (h *Hostfile) Contains(b *Hostname) bool { for _, a := range h.Hosts { - if a.Equals(b) { + if a.Equals(*b) { return true } } @@ -279,10 +285,10 @@ func (h *Hostfile) ContainsDomain(search string) bool { return false } -func (h *Hostfile) Add(host Hostname) error { +func (h *Hostfile) Add(host *Hostname) error { host_f, found := h.Hosts[host.Domain] if found { - if host_f.Ip == host.Ip { + if host_f.Ip.Equal(host.Ip) { return errors.New(fmt.Sprintf("Duplicate hostname entry for %s -> %s", host.Domain, host.Ip)) } else { @@ -290,7 +296,7 @@ func (h *Hostfile) Add(host Hostname) error { host.Domain, host.Ip, host_f.Ip)) } } else { - h.Hosts[host.Domain] = &host + h.Hosts[host.Domain] = host } return nil } diff --git a/hostfile_test.go b/hostfile_test.go index 26f42d2..060668d 100644 --- a/hostfile_test.go +++ b/hostfile_test.go @@ -2,6 +2,7 @@ package hostess_test import ( "github.com/cbednarski/hostess" + "net" "strings" "testing" ) @@ -29,11 +30,11 @@ const ipv4_fail = ` ` const ipv6 = `` - const domain = "localhost" -const ip = "127.0.0.1" const enabled = true +var ip = net.ParseIP("127.0.0.1") + func TestGetHostsPath(t *testing.T) { path := hostess.GetHostsPath() const expected = "/etc/hosts" @@ -44,8 +45,8 @@ func TestGetHostsPath(t *testing.T) { func TestHostfile(t *testing.T) { hostfile := hostess.NewHostfile("./hosts") - hostfile.Add(hostess.Hostname{domain, ip, true}) - if hostfile.Hosts[domain].Ip != ip { + hostfile.Add(&hostess.Hostname{domain, ip, true, false}) + if !hostfile.Hosts[domain].Ip.Equal(ip) { t.Errorf("Hostsfile should have %s pointing to %s", domain, ip) } @@ -69,14 +70,14 @@ func TestHostFileDuplicates(t *testing.T) { hostfile := hostess.NewHostfile("./hosts") const exp_duplicate = "Duplicate hostname entry for localhost -> 127.0.0.1" - hostfile.Add(hostess.Hostname{domain, ip, true}) - err := hostfile.Add(hostess.Hostname{domain, ip, true}) + hostfile.Add(&hostess.Hostname{domain, ip, true, false}) + err := hostfile.Add(&hostess.Hostname{domain, ip, true, false}) if err.Error() != exp_duplicate { t.Errorf(asserts, exp_duplicate, err) } const exp_conflict = "Conflicting hostname entries for localhost -> 127.0.1.1 and -> 127.0.0.1" - err2 := hostfile.Add(hostess.Hostname{domain, "127.0.1.1", true}) + err2 := hostfile.Add(&hostess.Hostname{domain, net.ParseIP("127.0.1.1"), true, false}) if err2.Error() != exp_conflict { t.Errorf(asserts, exp_conflict, err2) } @@ -96,12 +97,12 @@ func TestFormatHostfile(t *testing.T) { # 8.8.8.8 google.com` hostfile := hostess.NewHostfile("./hosts") - hostfile.Add(hostess.Hostname{"localhost", "127.0.0.1", true}) - hostfile.Add(hostess.Hostname{"ip-10-37-12-18", "127.0.1.1", true}) - hostfile.Add(hostess.Hostname{"devsite", "127.0.0.1", true}) - hostfile.Add(hostess.Hostname{"google.com", "8.8.8.8", false}) - hostfile.Add(hostess.Hostname{"devsite.com", "10.37.12.18", true}) - hostfile.Add(hostess.Hostname{"m.devsite.com", "10.37.12.18", true}) + hostfile.Add(&hostess.Hostname{"localhost", net.ParseIP("127.0.0.1"), true, false}) + hostfile.Add(&hostess.Hostname{"ip-10-37-12-18", net.ParseIP("127.0.1.1"), true, false}) + hostfile.Add(&hostess.Hostname{"devsite", net.ParseIP("127.0.0.1"), true, false}) + hostfile.Add(&hostess.Hostname{"google.com", net.ParseIP("8.8.8.8"), false, false}) + hostfile.Add(&hostess.Hostname{"devsite.com", net.ParseIP("10.37.12.18"), true, false}) + hostfile.Add(&hostess.Hostname{"m.devsite.com", net.ParseIP("10.37.12.18"), true, false}) f := hostfile.Format() if f != expected { t.Errorf(asserts, expected, f) @@ -120,28 +121,28 @@ func TestTrimWS(t *testing.T) { func TestListDomainsByIp(t *testing.T) { hostfile := hostess.NewHostfile("./hosts") - hostfile.Add(hostess.Hostname{"devsite.com", "10.37.12.18", true}) - hostfile.Add(hostess.Hostname{"m.devsite.com", "10.37.12.18", true}) - hostfile.Add(hostess.Hostname{"google.com", "8.8.8.8", false}) + hostfile.Add(&hostess.Hostname{"devsite.com", net.ParseIP("10.37.12.18"), true, false}) + hostfile.Add(&hostess.Hostname{"m.devsite.com", net.ParseIP("10.37.12.18"), true, false}) + hostfile.Add(&hostess.Hostname{"google.com", net.ParseIP("8.8.8.8"), false, false}) - names := hostfile.ListDomainsByIp("10.37.12.18") + names := hostfile.ListDomainsByIp(net.ParseIP("10.37.12.18")) if !(names[0] == "devsite.com" && names[1] == "m.devsite.com") { t.Errorf("Expected devsite.com and m.devsite.com. Got %s", names) } hostfile2 := hostess.NewHostfile("./hosts") - hostfile2.Add(hostess.Hostname{"localhost", "127.0.0.1", true}) - hostfile2.Add(hostess.Hostname{"ip-10-37-12-18", "127.0.1.1", true}) - hostfile2.Add(hostess.Hostname{"devsite", "127.0.0.1", true}) + hostfile2.Add(&hostess.Hostname{"localhost", net.ParseIP("127.0.0.1"), true, false}) + hostfile2.Add(&hostess.Hostname{"ip-10-37-12-18", net.ParseIP("127.0.1.1"), true, false}) + hostfile2.Add(&hostess.Hostname{"devsite", net.ParseIP("127.0.0.1"), true, false}) - names2 := hostfile2.ListDomainsByIp("127.0.0.1") + names2 := hostfile2.ListDomainsByIp(net.ParseIP("127.0.0.1")) if !(names2[0] == "localhost" && names2[1] == "devsite") { t.Errorf("Expected localhost and devsite. Got %s", names2) } } func TestParseLine(t *testing.T) { - var hosts = []hostess.Hostname{} + var hosts = []*hostess.Hostname{} // Blank line hosts = hostess.ParseLine("") @@ -162,14 +163,14 @@ func TestParseLine(t *testing.T) { } hosts = hostess.ParseLine("#66.33.99.11 test.domain.com") - if !hostess.ContainsHostname(hosts, hostess.Hostname{"test.domain.com", "66.33.99.11", false}) || + if !hostess.ContainsHostname(hosts, &hostess.Hostname{"test.domain.com", net.ParseIP("66.33.99.11"), false, false}) || len(hosts) != 1 { t.Error("Expected to find test.domain.com (disabled)") } hosts = hostess.ParseLine("# 66.33.99.11 test.domain.com domain.com") - if !hostess.ContainsHostname(hosts, hostess.Hostname{"test.domain.com", "66.33.99.11", false}) || - !hostess.ContainsHostname(hosts, hostess.Hostname{"domain.com", "66.33.99.11", false}) || + if !hostess.ContainsHostname(hosts, &hostess.Hostname{"test.domain.com", net.ParseIP("66.33.99.11"), false, false}) || + !hostess.ContainsHostname(hosts, &hostess.Hostname{"domain.com", net.ParseIP("66.33.99.11"), false, false}) || len(hosts) != 2 { t.Error("Expected to find domain.com and test.domain.com (disabled)") t.Errorf("Found %s", hosts) @@ -177,22 +178,22 @@ func TestParseLine(t *testing.T) { // Not Commented stuff hosts = hostess.ParseLine("255.255.255.255 broadcasthost test.domain.com domain.com") - if !hostess.ContainsHostname(hosts, hostess.Hostname{"broadcasthost", "255.255.255.255", true}) || - !hostess.ContainsHostname(hosts, hostess.Hostname{"test.domain.com", "255.255.255.255", true}) || - !hostess.ContainsHostname(hosts, hostess.Hostname{"domain.com", "255.255.255.255", true}) || + if !hostess.ContainsHostname(hosts, &hostess.Hostname{"broadcasthost", net.ParseIP("255.255.255.255"), true, false}) || + !hostess.ContainsHostname(hosts, &hostess.Hostname{"test.domain.com", net.ParseIP("255.255.255.255"), true, false}) || + !hostess.ContainsHostname(hosts, &hostess.Hostname{"domain.com", net.ParseIP("255.255.255.255"), true, false}) || len(hosts) != 3 { t.Error("Expected to find broadcasthost, domain.com, and test.domain.com (enabled)") } // Ipv6 stuff hosts = hostess.ParseLine("::1 localhost") - if !hostess.ContainsHostname(hosts, hostess.Hostname{"localhost", "::1", true}) || + if !hostess.ContainsHostname(hosts, &hostess.Hostname{"localhost", net.ParseIP("::1"), true, true}) || len(hosts) != 1 { t.Error("Expected to find localhost ipv6 (enabled)") } hosts = hostess.ParseLine("ff02::1 ip6-allnodes") - if !hostess.ContainsHostname(hosts, hostess.Hostname{"ip6-allnodes", "ff02::1", true}) || + if !hostess.ContainsHostname(hosts, &hostess.Hostname{"ip6-allnodes", net.ParseIP("ff02::1"), true, true}) || len(hosts) != 1 { t.Error("Expected to find ip6-allnodes ipv6 (enabled)") } @@ -206,7 +207,7 @@ func TestLoadHostfile(t *testing.T) { } hostfile.Parse() - hostname := hostess.Hostname{domain, ip, enabled} + hostname := hostess.Hostname{domain, ip, enabled, false} _, found := hostfile.Hosts[hostname.Domain] if !found { t.Errorf("Expected to find %s", hostname) diff --git a/hostlist.go b/hostlist.go index f8786aa..14bc30a 100644 --- a/hostlist.go +++ b/hostlist.go @@ -1,15 +1,19 @@ package hostess -func ContainsHostname(hostnames []Hostname, hostname Hostname) bool { +import ( + "net" +) + +func ContainsHostname(hostnames []*Hostname, hostname *Hostname) bool { for _, v := range hostnames { - if v.Ip == hostname.Ip && v.Domain == hostname.Domain { + if v.Ip.Equal(hostname.Ip) && v.Domain == hostname.Domain { return true } } return false } -func ContainsDomain(hostnames []Hostname, domain string) bool { +func ContainsDomain(hostnames []*Hostname, domain string) bool { for _, v := range hostnames { if v.Domain == domain { return true @@ -18,9 +22,9 @@ func ContainsDomain(hostnames []Hostname, domain string) bool { return false } -func ContainsIp(hostnames []Hostname, ip string) bool { +func ContainsIp(hostnames []*Hostname, ip net.IP) bool { for _, v := range hostnames { - if v.Ip == ip { + if v.Ip.Equal(ip) { return true } } diff --git a/hostlist_test.go b/hostlist_test.go index b987b60..20695e2 100644 --- a/hostlist_test.go +++ b/hostlist_test.go @@ -2,13 +2,14 @@ package hostess_test import ( "github.com/cbednarski/hostess" + "net" "testing" ) func TestContainsDomainIp(t *testing.T) { - hosts := []hostess.Hostname{ - hostess.Hostname{domain, ip, false}, - hostess.Hostname{"google.com", "8.8.8.8", true}, + hosts := []*hostess.Hostname{ + &hostess.Hostname{domain, ip, false, false}, + &hostess.Hostname{"google.com", net.ParseIP("8.8.8.8"), true, false}, } if !hostess.ContainsDomain(hosts, domain) { @@ -24,17 +25,17 @@ func TestContainsDomainIp(t *testing.T) { t.Errorf("Expected to find %s", ip) } - const extra_ip = "1.2.3.4" + var extra_ip = net.ParseIP("1.2.3.4") if hostess.ContainsIp(hosts, extra_ip) { t.Errorf("Did not expect to find %s", extra_ip) } - hostname := hostess.Hostname{domain, ip, true} + hostname := &hostess.Hostname{domain, ip, true, false} if !hostess.ContainsHostname(hosts, hostname) { t.Errorf("Expected to find %s", hostname) } - extra_hostname := hostess.Hostname{"yahoo.com", "4.3.2.1", false} + extra_hostname := &hostess.Hostname{"yahoo.com", net.ParseIP("4.3.2.1"), false, false} if hostess.ContainsHostname(hosts, extra_hostname) { t.Errorf("Did not expect to find %s", extra_hostname) } diff --git a/hostname.go b/hostname.go index 1492312..da08e59 100644 --- a/hostname.go +++ b/hostname.go @@ -2,18 +2,18 @@ package hostess import ( "fmt" + "net" "regexp" "strings" ) var ipv4_pattern = regexp.MustCompile(`^(?:[0-9]{1,3}\.){3}[0-9]{1,3}$`) +var ipv6_pattern = regexp.MustCompile(`^[a-z0-9:]+$`) func LooksLikeIpv4(ip string) bool { return ipv4_pattern.MatchString(ip) } -var ipv6_pattern = regexp.MustCompile(`^[a-z0-9:]+$`) - func LooksLikeIpv6(ip string) bool { if !strings.Contains(ip, ":") { return false @@ -23,13 +23,21 @@ func LooksLikeIpv6(ip string) bool { type Hostname struct { Domain string - Ip string + Ip net.IP Enabled bool - // Ipv6 bool + Ipv6 bool +} + +func NewHostname(domain, ip string, enabled bool) (hostname *Hostname, err error) { + IP := net.ParseIP(ip) + if IP != nil { + hostname = &Hostname{domain, IP, enabled, LooksLikeIpv6(ip)} + } + return } func (h *Hostname) Format() string { - r := fmt.Sprintf("%s %s", h.Ip, h.Domain) + r := fmt.Sprintf("%s %s", h.Ip.String(), h.Domain) if !h.Enabled { r = "# " + r } @@ -37,7 +45,7 @@ func (h *Hostname) Format() string { } func (a *Hostname) Equals(b Hostname) bool { - if a.Domain == b.Domain && a.Ip == b.Ip { + if a.Domain == b.Domain && a.Ip.Equal(b.Ip) { return true } return false diff --git a/hostname_test.go b/hostname_test.go index 8558518..7f23889 100644 --- a/hostname_test.go +++ b/hostname_test.go @@ -15,8 +15,8 @@ func TestHostname(t *testing.T) { if h.Domain != domain { t.Errorf("Domain should be %s", domain) } - if h.Ip != ip { - t.Errorf("Domain should be %s", ip) + if !h.Ip.Equal(ip) { + t.Errorf("Ip should be %s", ip) } if h.Enabled != enabled { t.Errorf("Enabled should be %s", enabled) @@ -24,7 +24,7 @@ func TestHostname(t *testing.T) { } func TestFormatHostname(t *testing.T) { - hostname := hostess.Hostname{domain, ip, enabled} + hostname := &hostess.Hostname{domain, ip, enabled, false} const exp_enabled = "127.0.0.1 localhost" if hostname.Format() != exp_enabled {