diff --git a/sets/treeset/treeset.go b/sets/treeset/treeset.go index 42d0eee..25acbcf 100644 --- a/sets/treeset/treeset.go +++ b/sets/treeset/treeset.go @@ -33,6 +33,7 @@ import ( func assertInterfaceImplementation() { var _ sets.Set = (*Set)(nil) + var _ containers.Enumerable = (*Set)(nil) var _ containers.IteratorWithIndex = (*Iterator)(nil) } @@ -71,7 +72,7 @@ func (set *Set) Remove(items ...interface{}) { } } -// Check wether items (one or more) are present in the set. +// Check weather items (one or more) are present in the set. // All items have to be present in the set for the method to return true. // Returns true if no arguments are passed at all, i.e. set is always superset of empty set. func (set *Set) Contains(items ...interface{}) bool { @@ -125,6 +126,63 @@ func (iterator *Iterator) Index() int { return iterator.index } +func (set *Set) Each(f func(index interface{}, value interface{})) { + iterator := set.Iterator() + for iterator.Next() { + f(iterator.Index(), iterator.Value()) + } +} + +func (set *Set) Map(f func(index interface{}, value interface{}) interface{}) containers.Container { + newSet := &Set{tree: rbt.NewWith(set.tree.Comparator)} + iterator := set.Iterator() + for iterator.Next() { + newSet.Add(f(iterator.Index(), iterator.Value())) + } + return newSet +} + +func (set *Set) Select(f func(index interface{}, value interface{}) bool) containers.Container { + newSet := &Set{tree: rbt.NewWith(set.tree.Comparator)} + iterator := set.Iterator() + for iterator.Next() { + if f(iterator.Index(), iterator.Value()) { + newSet.Add(iterator.Value()) + } + } + return newSet +} + +func (set *Set) Any(f func(index interface{}, value interface{}) bool) bool { + iterator := set.Iterator() + for iterator.Next() { + if f(iterator.Index(), iterator.Value()) { + return true + } + } + return false +} + +func (set *Set) All(f func(index interface{}, value interface{}) bool) bool { + iterator := set.Iterator() + for iterator.Next() { + if !f(iterator.Index(), iterator.Value()) { + return false + } + } + return true +} + +func (set *Set) Find(f func(index interface{}, value interface{}) bool) (index interface{}, value interface{}) { + iterator := set.Iterator() + for iterator.Next() { + if f(iterator.Index(), iterator.Value()) { + return iterator.Index(), iterator.Value() + } + } + return nil, nil +} + func (set *Set) String() string { str := "TreeSet\n" items := []string{} diff --git a/sets/treeset/treeset_test.go b/sets/treeset/treeset_test.go index 780a9db..78f1e3a 100644 --- a/sets/treeset/treeset_test.go +++ b/sets/treeset/treeset_test.go @@ -84,11 +84,100 @@ func TestTreeSet(t *testing.T) { } } -func TestTreeSetIterator(t *testing.T) { +func TestTreeSetEnumerableAndIterator(t *testing.T) { set := NewWithStringComparator() - set.Add("c") - set.Add("a") - set.Add("b") + set.Add("c", "a", "b") + + // Each + set.Each(func(index interface{}, value interface{}) { + switch index { + case 0: + if actualValue, expectedValue := value, "a"; actualValue != expectedValue { + t.Errorf("Got %v expected %v", actualValue, expectedValue) + } + case 1: + if actualValue, expectedValue := value, "b"; actualValue != expectedValue { + t.Errorf("Got %v expected %v", actualValue, expectedValue) + } + case 2: + if actualValue, expectedValue := value, "c"; actualValue != expectedValue { + t.Errorf("Got %v expected %v", actualValue, expectedValue) + } + default: + t.Errorf("Too many") + } + }) + + // Map + mappedSet := set.Map(func(index interface{}, value interface{}) interface{} { + return "mapped: " + value.(string) + }).(*Set) + if actualValue, expectedValue := mappedSet.Contains("mapped: a", "mapped: b", "mapped: c"), true; actualValue != expectedValue { + t.Errorf("Got %v expected %v", actualValue, expectedValue) + } + if actualValue, expectedValue := mappedSet.Contains("mapped: a", "mapped: b", "mapped: x"), false; actualValue != expectedValue { + t.Errorf("Got %v expected %v", actualValue, expectedValue) + } + if mappedSet.Size() != 3 { + t.Errorf("Got %v expected %v", mappedSet.Size(), 3) + } + + // Select + selectedSet := set.Select(func(index interface{}, value interface{}) bool { + return value.(string) >= "a" && value.(string) <= "b" + }).(*Set) + if actualValue, expectedValue := selectedSet.Contains("a", "b"), true; actualValue != expectedValue { + fmt.Println("A: ", mappedSet.Contains("b")) + t.Errorf("Got %v (%v) expected %v (%v)", actualValue, selectedSet.Values(), expectedValue, "[a b]") + } + if actualValue, expectedValue := selectedSet.Contains("a", "b", "c"), false; actualValue != expectedValue { + t.Errorf("Got %v (%v) expected %v (%v)", actualValue, selectedSet.Values(), expectedValue, "[a b]") + } + if selectedSet.Size() != 2 { + t.Errorf("Got %v expected %v", selectedSet.Size(), 3) + } + + // Any + any := set.Any(func(index interface{}, value interface{}) bool { + return value.(string) == "c" + }) + if any != true { + t.Errorf("Got %v expected %v", any, true) + } + any = set.Any(func(index interface{}, value interface{}) bool { + return value.(string) == "x" + }) + if any != false { + t.Errorf("Got %v expected %v", any, false) + } + + // All + all := set.All(func(index interface{}, value interface{}) bool { + return value.(string) >= "a" && value.(string) <= "c" + }) + if all != true { + t.Errorf("Got %v expected %v", all, true) + } + all = set.All(func(index interface{}, value interface{}) bool { + return value.(string) >= "a" && value.(string) <= "b" + }) + if all != false { + t.Errorf("Got %v expected %v", all, false) + } + + // Find + foundIndex, foundValue := set.Find(func(index interface{}, value interface{}) bool { + return value.(string) == "c" + }) + if foundValue != "c" || foundIndex != 2 { + t.Errorf("Got %v at %v expected %v at %v", foundValue, foundIndex, "c", 2) + } + foundIndex, foundValue = set.Find(func(index interface{}, value interface{}) bool { + return value.(string) == "x" + }) + if foundValue != nil || foundIndex != nil { + t.Errorf("Got %v at %v expected %v at %v", foundValue, foundIndex, nil, nil) + } // Iterator it := set.Iterator()