You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
hostess/hostess/hostlist.go

437 lines
13 KiB
Go

package hostess
import (
"encoding/json"
"errors"
"fmt"
"net"
"sort"
"strings"
)
// EnvHostessFmt holds the string "HOSTESS_FMT".
// NOTE(review): presumably the name of an environment variable that selects
// an output format; it is not referenced in this file -- confirm with callers.
const EnvHostessFmt = `HOSTESS_FMT`
// ErrInvalidVersionArg signals that a function expecting an IP version of 4
// or 6 was passed some other value. IndexOfDomainV and FilterByDomainV panic
// with it; EnableV and DisableV return it.
var ErrInvalidVersionArg = errors.New("version argument must be 4 or 6")
// ErrHostnameNotFound is returned by Enable, EnableV, Disable and DisableV
// when no Hostname in the list matches the request.
var ErrHostnameNotFound = errors.New("hostname not found")
// Hostlist is a sortable set of Hostnames. When in a Hostlist, Hostnames must
// follow some rules:
//
// - Hostlist may contain IPv4 AND IPv6 ("IP version" or "IPv") Hostnames.
// - Names are only allowed to overlap if IP version is different.
// - Adding a Hostname for an existing name will replace the old one.
//
// These rules are enforced by Add; because Hostlist is a plain slice of
// pointers, entries appended by other means are not checked.
//
// The Hostlist uses a deterministic Sort order designed to make a hostfile
// output look a particular way. Generally you don't need to worry about this
// as Sort will be called automatically before Format. However, the Hostlist
// may or may not be sorted at any particular time during runtime.
//
// Hostlist implements sort.Interface through its Len, Less and Swap methods.
//
// See the docs and implementation in Sort and Add for more details.
type Hostlist []*Hostname
// NewHostlist returns a pointer to a fresh, empty (non-nil) Hostlist.
func NewHostlist() *Hostlist {
	list := Hostlist{}
	return &list
}
// Len reports how many Hostnames the list holds. It is one of the three
// methods required by sort.Interface.
func (h Hostlist) Len() int {
	count := len(h)
	return count
}
// MakeSurrogateIP takes an IPv4 address in 127.0.0.0/8, such as 127.0.0.1,
// and returns a copy with the first octet zeroed (0.0.0.1) so that loopback
// addresses sort ahead of every other IPv4 address. Any other address,
// including every non-IPv4-mapped IPv6 address, is returned unchanged.
//
// The input is never modified: a rewritten address is a newly allocated
// 16-byte net.IP.
//
// The check is made on the first IPv4 octet rather than on the textual form,
// so IPv6 addresses whose text merely begins with "127" (for example 127::1)
// are not mistaken for loopback addresses.
func MakeSurrogateIP(IP net.IP) net.IP {
	if ip4 := IP.To4(); ip4 != nil && ip4[0] == 127 {
		return net.IPv4(0, ip4[1], ip4[2], ip4[3])
	}
	return IP
}
// Less determines the sort order of two Hostnames, part of sort.Interface.
// It reports whether the entry at index A must be placed before the entry at
// index B. The rules are applied in this order:
//
// 1. IPv4 entries come before IPv6 entries.
// 2. Within one IP version, a "localhost" entry comes before any other.
// 3. Entries are ordered by IP address, byte by byte, with 127.* addresses
// treated as 0.* (see MakeSurrogateIP) so they lead the IPv4 group.
// 4. Entries with the same IP are ordered by domain, compared byte-wise.
//
// Less never reports true in both directions for a pair, which sort.Sort
// relies on; in particular two "localhost" entries fall through to the IP
// comparison instead of each claiming priority.
func (h Hostlist) Less(A, B int) bool {
	a, b := h[A], h[B]

	// Sort IPv4 before IPv6: when the versions differ, the IPv4 side wins.
	if a.IPv6 != b.IPv6 {
		return !a.IPv6
	}

	// Sort "localhost" at the top, but only when exactly one side is
	// "localhost"; otherwise there is nothing to decide here.
	aLocal := a.Domain == "localhost"
	bLocal := b.Domain == "localhost"
	if aLocal != bLocal {
		return aLocal
	}

	// Compare the IP addresses byte by byte. 127.* is pushed to the top by
	// comparing its zero-prefixed surrogate. Both sides are normalised to the
	// 16-byte form first so that a 4-byte and a 16-byte net.IP can be compared
	// without indexing past the end of the shorter one.
	ipA := MakeSurrogateIP(a.IP).To16()
	ipB := MakeSurrogateIP(b.IP).To16()
	for i := 0; i < len(ipA) && i < len(ipB); i++ {
		if ipA[i] != ipB[i] {
			return ipA[i] < ipB[i]
		}
	}

	// The IPs are the same (or not comparable), so sort by domain. Go compares
	// strings byte-wise and orders a proper prefix first, which is exactly the
	// "shorter wins, then first differing byte" rule.
	//
	// TODO: This works best if domains are lowercased. However, we do not
	// enforce lowercase because of UTF-8 domain names, which may be broken by
	// case folding. There is a way to do this correctly but it's complicated
	// so it is not done here.
	//
	// Identical domains yield false, as required: an entry is never Less than
	// itself. Add should not allow such duplicates, but this guards anyway.
	return a.Domain < b.Domain
}
// Swap exchanges the Hostnames stored at positions i and j. It is one of the
// three methods required by sort.Interface.
func (h Hostlist) Swap(i, j int) {
	first := h[i]
	h[i] = h[j]
	h[j] = first
}
// Sort orders the list in place using sort.Sort and the rules implemented by
// Less, applied in this precedence:
//
// 1. IPv4 comes before IPv6
// 2. Within an IP version, localhost comes before other hostnames
// 3. IPs are sorted in numerical order, with 127.* first
// 4. Hostnames sharing an IP are sorted in lexicographical order
func (h *Hostlist) Sort() {
	list := *h
	sort.Sort(list)
}
// Contains reports whether some entry in this Hostlist is Equal to b.
func (h *Hostlist) Contains(b *Hostname) bool {
	// IndexOf performs the same Equal scan and yields -1 when nothing matches.
	return h.IndexOf(b) != -1
}
// ContainsDomain reports whether any Hostname in this Hostlist has exactly
// the given domain, regardless of IP version.
func (h *Hostlist) ContainsDomain(domain string) bool {
	list := *h
	for i := range list {
		if list[i].Domain == domain {
			return true
		}
	}
	return false
}
// ContainsIP reports whether any Hostname in this Hostlist matches IP, as
// decided by Hostname.EqualIP.
func (h *Hostlist) ContainsIP(IP net.IP) bool {
	list := *h
	for i := range list {
		if list[i].EqualIP(IP) {
			return true
		}
	}
	return false
}
// Add inserts a copy of input (rebuilt through NewHostname) into this
// Hostlist, merging it with an existing entry where necessary:
//
//   - Duplicate (an Equal entry already exists): the list keeps the existing
//     entry, which becomes enabled if either it or the new one is enabled.
//   - Conflict (same domain and IP version, but not Equal): the new entry
//     replaces the existing one, i.e. the last write wins.
//   - Otherwise the new entry is appended.
//
// Duplicates and conflicts are both reported through a non-nil error so the
// caller is aware of them, even though the list has already been updated and
// nothing further needs to be done. An error from NewHostname is returned
// as-is, and in that case the list is left untouched.
func (h *Hostlist) Add(input *Hostname) error {
	newHostname, err := NewHostname(input.Domain, input.IP.String(), input.Enabled)
	if err != nil {
		return err
	}

	for i := range *h {
		existing := (*h)[i]
		switch {
		case existing.Equal(newHostname):
			// Enabled is sticky: an enabled entry added later switches the
			// stored one on, while a disabled one added later leaves it as is.
			if newHostname.Enabled {
				existing.Enabled = true
			}
			return fmt.Errorf("duplicate hostname entry for %s -> %s",
				newHostname.Domain, newHostname.IP)
		case existing.Domain == newHostname.Domain && existing.IPv6 == newHostname.IPv6:
			// Same name and IP version but a different entry: overwrite.
			(*h)[i] = newHostname
			return fmt.Errorf("conflicting hostname entries for %s -> %s and -> %s",
				newHostname.Domain, newHostname.IP, existing.IP)
		}
	}

	*h = append(*h, newHostname)
	return nil
}
// IndexOf returns the position of the first entry that is Equal to host, or
// -1 when the Hostlist holds no such entry.
func (h *Hostlist) IndexOf(host *Hostname) int {
	list := *h
	for i := 0; i < len(list); i++ {
		if list[i].Equal(host) {
			return i
		}
	}
	return -1
}
// IndexOfDomainV returns the position of the first entry whose domain equals
// domain and whose IP version equals version, or -1 when there is none.
//
// It panics with ErrInvalidVersionArg if version is neither 4 nor 6.
func (h *Hostlist) IndexOfDomainV(domain string, version int) int {
	switch version {
	case 4, 6:
		// valid IP version
	default:
		panic(ErrInvalidVersionArg)
	}

	wantV6 := version == 6
	for i, entry := range *h {
		if entry.Domain != domain {
			continue
		}
		if entry.IPv6 == wantV6 {
			return i
		}
	}
	return -1
}
// Remove deletes the Hostname stored at index and returns how many entries
// were removed (1 on success). An out-of-range index, such as the -1 produced
// by a failed lookup, is ignored and 0 is returned.
func (h *Hostlist) Remove(index int) int {
	list := *h
	if index < 0 || index >= len(list) {
		return 0
	}
	// Close the gap by shifting the tail down one position.
	*h = append(list[:index], list[index+1:]...)
	return 1
}
// RemoveDomain deletes the IPv4 entry and then the IPv6 entry for domain,
// where present, and returns the number of entries removed (0, 1 or 2).
func (h *Hostlist) RemoveDomain(domain string) int {
	removed := h.RemoveDomainV(domain, 4)
	removed += h.RemoveDomainV(domain, 6)
	return removed
}
// RemoveDomainV deletes the entry matching both domain and IP version and
// returns the number of entries removed (0 or 1).
//
// The lookup is done by IndexOfDomainV, so this panics with
// ErrInvalidVersionArg if version is neither 4 nor 6.
func (h *Hostlist) RemoveDomainV(domain string, version int) int {
	index := h.IndexOfDomainV(domain, version)
	return h.Remove(index)
}
// Enable marks every Hostname whose domain equals name as enabled. A domain
// may be present once per IP version, so both the IPv4 and the IPv6 entry are
// switched on when both exist; use EnableV to target a single IP version.
//
// It returns ErrHostnameNotFound if no entry matches, and nil otherwise.
func (h *Hostlist) Enable(name string) error {
	matched := false
	for _, hostname := range *h {
		if hostname.Domain == name {
			hostname.Enabled = true
			// Keep scanning: the other IP version may carry the same name.
			matched = true
		}
	}
	if !matched {
		return ErrHostnameNotFound
	}
	return nil
}
// EnableV marks the first Hostname matching both domain and IP version as
// enabled.
//
// It returns ErrInvalidVersionArg if version is neither 4 nor 6 (it does not
// panic), ErrHostnameNotFound if no entry matches, and nil on success.
func (h *Hostlist) EnableV(domain string, version int) error {
	if !(version == 4 || version == 6) {
		return ErrInvalidVersionArg
	}

	wantV6 := version == 6
	for _, entry := range *h {
		if entry.Domain != domain || entry.IPv6 != wantV6 {
			continue
		}
		entry.Enabled = true
		return nil
	}
	return ErrHostnameNotFound
}
// Disable marks every Hostname whose domain equals name as disabled. A domain
// may be present once per IP version, so both the IPv4 and the IPv6 entry are
// switched off when both exist; use DisableV to target a single IP version.
//
// It returns ErrHostnameNotFound if no entry matches, and nil otherwise.
func (h *Hostlist) Disable(name string) error {
	matched := false
	for _, hostname := range *h {
		if hostname.Domain == name {
			hostname.Enabled = false
			// Keep scanning: the other IP version may carry the same name.
			matched = true
		}
	}
	if !matched {
		return ErrHostnameNotFound
	}
	return nil
}
// DisableV marks the first Hostname matching both domain and IP version as
// disabled.
//
// It returns ErrInvalidVersionArg if version is neither 4 nor 6 (it does not
// panic), ErrHostnameNotFound if no entry matches, and nil on success.
func (h *Hostlist) DisableV(domain string, version int) error {
	if !(version == 4 || version == 6) {
		return ErrInvalidVersionArg
	}

	wantV6 := version == 6
	for _, entry := range *h {
		if entry.Domain != domain || entry.IPv6 != wantV6 {
			continue
		}
		entry.Enabled = false
		return nil
	}
	return ErrHostnameNotFound
}
// FilterByIP returns, in list order, the Hostnames whose IP is Equal to IP.
// The result is nil when nothing matches. The returned pointers refer to the
// same Hostname values that are stored in the list.
func (h *Hostlist) FilterByIP(IP net.IP) (hostnames []*Hostname) {
	list := *h
	for i := range list {
		if entry := list[i]; entry.IP.Equal(IP) {
			hostnames = append(hostnames, entry)
		}
	}
	return hostnames
}
// FilterByDomain returns, in list order, the Hostnames whose domain equals
// domain, regardless of IP version. The result is nil when nothing matches.
// The returned pointers refer to the same Hostname values that are stored in
// the list.
func (h *Hostlist) FilterByDomain(domain string) (hostnames []*Hostname) {
	list := *h
	for i := range list {
		if entry := list[i]; entry.Domain == domain {
			hostnames = append(hostnames, entry)
		}
	}
	return hostnames
}
// FilterByDomainV returns the Hostnames matching both domain and IP version.
// A list maintained through Add holds at most one such entry, but a slice is
// returned for consistency with the other filter functions. The result is nil
// when nothing matches.
//
// It panics with ErrInvalidVersionArg if version is neither 4 nor 6.
func (h *Hostlist) FilterByDomainV(domain string, version int) (hostnames []*Hostname) {
	switch version {
	case 4, 6:
		// valid IP version
	default:
		panic(ErrInvalidVersionArg)
	}

	wantV6 := version == 6
	for _, entry := range *h {
		if entry.Domain != domain || entry.IPv6 != wantV6 {
			continue
		}
		hostnames = append(hostnames, entry)
	}
	return hostnames
}
// GetUniqueIPs returns every distinct IP in the Hostlist, in sorted list
// order. It calls Sort() first, so the receiver is reordered as a side
// effect. The result is an empty, non-nil slice when the list is empty.
func (h *Hostlist) GetUniqueIPs() []net.IP {
	h.Sort()

	// The map answers "already emitted?" keyed by the IP's text form, while
	// the slice records first-seen order, which a map cannot preserve.
	visited := map[string]bool{}
	unique := []net.IP{}
	for i := range *h {
		ip := (*h)[i].IP
		key := ip.String()
		if visited[key] {
			continue
		}
		visited[key] = true
		unique = append(unique, ip)
	}
	return unique
}
// Format renders the Hostlist as text suitable for use as an /etc/hosts
// file. The list is sorted first (see Sort), then one line is written per
// unique IP:
//
// 1. Lines appear in sorted IP order, so 127.* is at the top of the list
// (so boot resolvers don't break)
// 2. Enabled hostnames for an IP are written as "<ip> <name> <name>..."
// 3. Disabled hostnames for the same IP follow on their own commented line,
// "# <ip> <name> <name>..."
// 4. Names on a line keep the sorted order, so "localhost" appears first
// when present
//
// The result is empty but non-nil for an empty list.
func (h *Hostlist) Format() []byte {
	h.Sort()
	out := []byte{}

	// writeLine appends one line of output followed by a newline.
	writeLine := func(line string) {
		out = append(out, line...)
		out = append(out, '\n')
	}

	for _, ip := range h.GetUniqueIPs() {
		// Split this IP's hostnames into an enabled and a disabled bucket;
		// each non-empty bucket becomes its own line.
		var enabled, disabled []string
		for _, entry := range h.FilterByIP(ip) {
			if entry.Enabled {
				enabled = append(enabled, entry.Domain)
				continue
			}
			disabled = append(disabled, entry.Domain)
		}

		if len(enabled) != 0 {
			writeLine(fmt.Sprintf("%s %s", ip.String(), strings.Join(enabled, " ")))
		}
		if len(disabled) != 0 {
			writeLine(fmt.Sprintf("# %s %s", ip.String(), strings.Join(disabled, " ")))
		}
	}
	return out
}
// Dump serialises every entry in the Hostlist as indented JSON, returning the
// encoded bytes together with any error from json.MarshalIndent.
func (h *Hostlist) Dump() ([]byte, error) {
	const (
		prefix = ""
		indent = " "
	)
	encoded, err := json.MarshalIndent(h, prefix, indent)
	return encoded, err
}
// Apply decodes jsonbytes as a JSON array of Hostnames and merges every entry
// into this Hostlist through Add.
//
// It returns the decoding error if jsonbytes is not valid input, and an error
// if the array contains a null element. Both checks happen before anything is
// merged, so on a non-nil return the list is unchanged.
//
// Errors from Add are deliberately not propagated; see the comment in the
// merge loop.
func (h *Hostlist) Apply(jsonbytes []byte) error {
	var hostnames Hostlist
	if err := json.Unmarshal(jsonbytes, &hostnames); err != nil {
		return err
	}

	// A JSON null decodes to a nil *Hostname, which Add would dereference.
	// Reject such payloads up front instead of panicking half-way through.
	for i, hostname := range hostnames {
		if hostname == nil {
			return fmt.Errorf("hostname entry %d is null", i)
		}
	}

	for _, hostname := range hostnames {
		// Add reports duplicates and conflicts as errors even though it has
		// already merged the entry, so a non-nil result here is informational
		// and must not abort the import. Note that this also skips entries
		// NewHostname rejects; Add gives no way to tell the cases apart.
		_ = h.Add(hostname)
	}
	return nil
}