diff --git a/README.md b/README.md index 7cb7dd2..920847b 100644 --- a/README.md +++ b/README.md @@ -57,6 +57,7 @@ type Container interface { Size() int Clear() Values() []interface{} + String() string } ``` @@ -110,6 +111,7 @@ type List interface { // Size() int // Clear() // Values() []interface{} + // String() string } ``` @@ -228,6 +230,8 @@ func main() { A set is a data structure that can store elements and has no repeated values. It is a computer implementation of the mathematical concept of a finite set. Unlike most other collection types, rather than retrieving a specific element from a set, one typically tests an element for membership in a set. This structure is often used to ensure that no duplicates are present in a container. +Sets additionally allow set operations such as [intersection](https://en.wikipedia.org/wiki/Intersection_(set_theory)), [union](https://en.wikipedia.org/wiki/Union_(set_theory)), [difference](https://proofwiki.org/wiki/Definition:Set_Difference), etc. + Implements [Container](#containers) interface. ```go @@ -235,12 +239,16 @@ type Set interface { Add(elements ...interface{}) Remove(elements ...interface{}) Contains(elements ...interface{}) bool - + // Intersection(another *Set) *Set + // Union(another *Set) *Set + // Difference(another *Set) *Set + containers.Container // Empty() bool // Size() int // Clear() // Values() []interface{} + // String() string } ``` @@ -343,6 +351,7 @@ type Stack interface { // Size() int // Clear() // Values() []interface{} + // String() string } ``` @@ -418,6 +427,7 @@ type Map interface { // Size() int // Clear() // Values() []interface{} + // String() string } ``` @@ -591,6 +601,7 @@ type Tree interface { // Size() int // Clear() // Values() []interface{} + // String() string } ``` @@ -1348,7 +1359,7 @@ func main() { ### Serialization -All data structures can be serialized (marshalled) and deserialized (unmarshalled). Currently only JSON support is available. 
+All data structures can be serialized (marshalled) and deserialized (unmarshalled). Currently, only JSON support is available. #### JSONSerializer @@ -1481,7 +1492,7 @@ Container specific operations: ```go -// Returns sorted container''s elements with respect to the passed comparator. -// Does not effect the ordering of elements within the container. +// Returns sorted container's elements with respect to the passed comparator. +// Does not affect the ordering of elements within the container. func GetSortedValues(container Container, comparator utils.Comparator) []interface{} ``` diff --git a/containers/enumerable.go b/containers/enumerable.go index ac48b54..7066005 100644 --- a/containers/enumerable.go +++ b/containers/enumerable.go @@ -11,11 +11,9 @@ type EnumerableWithIndex interface { // Map invokes the given function once for each element and returns a // container containing the values returned by the given function. - // TODO would appreciate help on how to enforce this in containers (don't want to type assert when chaining) // Map(func(index int, value interface{}) interface{}) Container // Select returns a new container containing all elements for which the given function returns a true value. - // TODO need help on how to enforce this in containers (don't want to type assert when chaining) // Select(func(index int, value interface{}) bool) Container // Any passes each element of the container to the given function and @@ -39,11 +37,9 @@ type EnumerableWithKey interface { // Map invokes the given function once for each element and returns a container // containing the values returned by the given function as key/value pairs. - // TODO need help on how to enforce this in containers (don't want to type assert when chaining) // Map(func(key interface{}, value interface{}) (interface{}, interface{})) Container // Select returns a new container containing all elements for which the given function returns a true value. 
- // TODO need help on how to enforce this in containers (don't want to type assert when chaining) // Select(func(key interface{}, value interface{}) bool) Container // Any passes each element of the container to the given function and diff --git a/sets/hashset/hashset.go b/sets/hashset/hashset.go index 815d049..558e628 100644 --- a/sets/hashset/hashset.go +++ b/sets/hashset/hashset.go @@ -97,3 +97,58 @@ func (set *Set) String() string { str += strings.Join(items, ", ") return str } + +// Intersection returns the intersection between two sets. +// The new set consists of all elements that are both in "set" and "another". +// Ref: https://en.wikipedia.org/wiki/Intersection_(set_theory) +func (set *Set) Intersection(another *Set) *Set { + result := New() + + // Iterate over smaller set (optimization) + if set.Size() <= another.Size() { + for item := range set.items { + if _, contains := another.items[item]; contains { + result.Add(item) + } + } + } else { + for item := range another.items { + if _, contains := set.items[item]; contains { + result.Add(item) + } + } + } + + return result +} + +// Union returns the union of two sets. +// The new set consists of all elements that are in "set" or "another" (possibly both). +// Ref: https://en.wikipedia.org/wiki/Union_(set_theory) +func (set *Set) Union(another *Set) *Set { + result := New() + + for item := range set.items { + result.Add(item) + } + for item := range another.items { + result.Add(item) + } + + return result +} + +// Difference returns the difference between two sets. +// The new set consists of all elements that are in "set" but not in "another". 
+// Ref: https://proofwiki.org/wiki/Definition:Set_Difference +func (set *Set) Difference(another *Set) *Set { + result := New() + + for item := range set.items { + if _, contains := another.items[item]; !contains { + result.Add(item) + } + } + + return result +} diff --git a/sets/hashset/hashset_test.go b/sets/hashset/hashset_test.go index 4351338..ccb50b3 100644 --- a/sets/hashset/hashset_test.go +++ b/sets/hashset/hashset_test.go @@ -111,6 +111,72 @@ func TestSetSerialization(t *testing.T) { } } +func TestSetIntersection(t *testing.T) { + set := New() + another := New() + + intersection := set.Intersection(another) + if actualValue, expectedValue := intersection.Size(), 0; actualValue != expectedValue { + t.Errorf("Got %v expected %v", actualValue, expectedValue) + } + + set.Add("a", "b", "c", "d") + another.Add("c", "d", "e", "f") + + intersection = set.Intersection(another) + + if actualValue, expectedValue := intersection.Size(), 2; actualValue != expectedValue { + t.Errorf("Got %v expected %v", actualValue, expectedValue) + } + if actualValue := intersection.Contains("c", "d"); actualValue != true { + t.Errorf("Got %v expected %v", actualValue, true) + } +} + +func TestSetUnion(t *testing.T) { + set := New() + another := New() + + union := set.Union(another) + if actualValue, expectedValue := union.Size(), 0; actualValue != expectedValue { + t.Errorf("Got %v expected %v", actualValue, expectedValue) + } + + set.Add("a", "b", "c", "d") + another.Add("c", "d", "e", "f") + + union = set.Union(another) + + if actualValue, expectedValue := union.Size(), 6; actualValue != expectedValue { + t.Errorf("Got %v expected %v", actualValue, expectedValue) + } + if actualValue := union.Contains("a", "b", "c", "d", "e", "f"); actualValue != true { + t.Errorf("Got %v expected %v", actualValue, true) + } +} + +func TestSetDifference(t *testing.T) { + set := New() + another := New() + + difference := set.Difference(another) + if actualValue, expectedValue := difference.Size(), 
0; actualValue != expectedValue { + t.Errorf("Got %v expected %v", actualValue, expectedValue) + } + + set.Add("a", "b", "c", "d") + another.Add("c", "d", "e", "f") + + difference = set.Difference(another) + + if actualValue, expectedValue := difference.Size(), 2; actualValue != expectedValue { + t.Errorf("Got %v expected %v", actualValue, expectedValue) + } + if actualValue := difference.Contains("a", "b"); actualValue != true { + t.Errorf("Got %v expected %v", actualValue, true) + } +} + func benchmarkContains(b *testing.B, set *Set, size int) { for i := 0; i < b.N; i++ { for n := 0; n < size; n++ { diff --git a/sets/linkedhashset/linkedhashset.go b/sets/linkedhashset/linkedhashset.go index e589a12..e028591 100644 --- a/sets/linkedhashset/linkedhashset.go +++ b/sets/linkedhashset/linkedhashset.go @@ -116,3 +116,58 @@ func (set *Set) String() string { str += strings.Join(items, ", ") return str } + +// Intersection returns the intersection between two sets. +// The new set consists of all elements that are both in "set" and "another". +// Ref: https://en.wikipedia.org/wiki/Intersection_(set_theory) +func (set *Set) Intersection(another *Set) *Set { + result := New() + + // Iterate over smaller set (optimization) + if set.Size() <= another.Size() { + for item := range set.table { + if _, contains := another.table[item]; contains { + result.Add(item) + } + } + } else { + for item := range another.table { + if _, contains := set.table[item]; contains { + result.Add(item) + } + } + } + + return result +} + +// Union returns the union of two sets. +// The new set consists of all elements that are in "set" or "another" (possibly both). +// Ref: https://en.wikipedia.org/wiki/Union_(set_theory) +func (set *Set) Union(another *Set) *Set { + result := New() + + for item := range set.table { + result.Add(item) + } + for item := range another.table { + result.Add(item) + } + + return result +} + +// Difference returns the difference between two sets. 
+// The new set consists of all elements that are in "set" but not in "another". +// Ref: https://proofwiki.org/wiki/Definition:Set_Difference +func (set *Set) Difference(another *Set) *Set { + result := New() + + for item := range set.table { + if _, contains := another.table[item]; !contains { + result.Add(item) + } + } + + return result +} diff --git a/sets/linkedhashset/linkedhashset_test.go b/sets/linkedhashset/linkedhashset_test.go index 59db9ad..7e3c236 100644 --- a/sets/linkedhashset/linkedhashset_test.go +++ b/sets/linkedhashset/linkedhashset_test.go @@ -465,6 +465,72 @@ func TestSetSerialization(t *testing.T) { } } +func TestSetIntersection(t *testing.T) { + set := New() + another := New() + + intersection := set.Intersection(another) + if actualValue, expectedValue := intersection.Size(), 0; actualValue != expectedValue { + t.Errorf("Got %v expected %v", actualValue, expectedValue) + } + + set.Add("a", "b", "c", "d") + another.Add("c", "d", "e", "f") + + intersection = set.Intersection(another) + + if actualValue, expectedValue := intersection.Size(), 2; actualValue != expectedValue { + t.Errorf("Got %v expected %v", actualValue, expectedValue) + } + if actualValue := intersection.Contains("c", "d"); actualValue != true { + t.Errorf("Got %v expected %v", actualValue, true) + } +} + +func TestSetUnion(t *testing.T) { + set := New() + another := New() + + union := set.Union(another) + if actualValue, expectedValue := union.Size(), 0; actualValue != expectedValue { + t.Errorf("Got %v expected %v", actualValue, expectedValue) + } + + set.Add("a", "b", "c", "d") + another.Add("c", "d", "e", "f") + + union = set.Union(another) + + if actualValue, expectedValue := union.Size(), 6; actualValue != expectedValue { + t.Errorf("Got %v expected %v", actualValue, expectedValue) + } + if actualValue := union.Contains("a", "b", "c", "d", "e", "f"); actualValue != true { + t.Errorf("Got %v expected %v", actualValue, true) + } +} + +func TestSetDifference(t *testing.T) { 
+ set := New() + another := New() + + difference := set.Difference(another) + if actualValue, expectedValue := difference.Size(), 0; actualValue != expectedValue { + t.Errorf("Got %v expected %v", actualValue, expectedValue) + } + + set.Add("a", "b", "c", "d") + another.Add("c", "d", "e", "f") + + difference = set.Difference(another) + + if actualValue, expectedValue := difference.Size(), 2; actualValue != expectedValue { + t.Errorf("Got %v expected %v", actualValue, expectedValue) + } + if actualValue := difference.Contains("a", "b"); actualValue != true { + t.Errorf("Got %v expected %v", actualValue, true) + } +} + func benchmarkContains(b *testing.B, set *Set, size int) { for i := 0; i < b.N; i++ { for n := 0; n < size; n++ { diff --git a/sets/sets.go b/sets/sets.go index 2573297..d96801c 100644 --- a/sets/sets.go +++ b/sets/sets.go @@ -16,6 +16,9 @@ type Set interface { Add(elements ...interface{}) Remove(elements ...interface{}) Contains(elements ...interface{}) bool + // Intersection(another *Set) *Set + // Union(another *Set) *Set + // Difference(another *Set) *Set containers.Container // Empty() bool diff --git a/sets/treeset/treeset.go b/sets/treeset/treeset.go index 7efbf2d..7e7d1d6 100644 --- a/sets/treeset/treeset.go +++ b/sets/treeset/treeset.go @@ -14,6 +14,7 @@ import ( "github.com/emirpasic/gods/sets" rbt "github.com/emirpasic/gods/trees/redblacktree" "github.com/emirpasic/gods/utils" + "reflect" "strings" ) @@ -111,3 +112,79 @@ func (set *Set) String() string { str += strings.Join(items, ", ") return str } + +// Intersection returns the intersection between two sets. +// The new set consists of all elements that are both in "set" and "another". +// The two sets should have the same comparators; otherwise the result is an empty set. 
+// Ref: https://en.wikipedia.org/wiki/Intersection_(set_theory) +func (set *Set) Intersection(another *Set) *Set { + result := NewWith(set.tree.Comparator) + + setComparator := reflect.ValueOf(set.tree.Comparator) + anotherComparator := reflect.ValueOf(another.tree.Comparator) + if setComparator.Pointer() != anotherComparator.Pointer() { + return result + } + + // Iterate over smaller set (optimization) + if set.Size() <= another.Size() { + for it := set.Iterator(); it.Next(); { + if another.Contains(it.Value()) { + result.Add(it.Value()) + } + } + } else { + for it := another.Iterator(); it.Next(); { + if set.Contains(it.Value()) { + result.Add(it.Value()) + } + } + } + + return result +} + +// Union returns the union of two sets. +// The new set consists of all elements that are in "set" or "another" (possibly both). +// The two sets should have the same comparators; otherwise the result is an empty set. +// Ref: https://en.wikipedia.org/wiki/Union_(set_theory) +func (set *Set) Union(another *Set) *Set { + result := NewWith(set.tree.Comparator) + + setComparator := reflect.ValueOf(set.tree.Comparator) + anotherComparator := reflect.ValueOf(another.tree.Comparator) + if setComparator.Pointer() != anotherComparator.Pointer() { + return result + } + + for it := set.Iterator(); it.Next(); { + result.Add(it.Value()) + } + for it := another.Iterator(); it.Next(); { + result.Add(it.Value()) + } + + return result +} + +// Difference returns the difference between two sets. +// The two sets should have the same comparators; otherwise the result is an empty set. +// The new set consists of all elements that are in "set" but not in "another". 
+// Ref: https://proofwiki.org/wiki/Definition:Set_Difference +func (set *Set) Difference(another *Set) *Set { + result := NewWith(set.tree.Comparator) + + setComparator := reflect.ValueOf(set.tree.Comparator) + anotherComparator := reflect.ValueOf(another.tree.Comparator) + if setComparator.Pointer() != anotherComparator.Pointer() { + return result + } + + for it := set.Iterator(); it.Next(); { + if !another.Contains(it.Value()) { + result.Add(it.Value()) + } + } + + return result +} diff --git a/sets/treeset/treeset_test.go b/sets/treeset/treeset_test.go index 20a6f6a..2839d4b 100644 --- a/sets/treeset/treeset_test.go +++ b/sets/treeset/treeset_test.go @@ -474,6 +474,108 @@ func TestSetSerialization(t *testing.T) { } } +func TestSetIntersection(t *testing.T) { + { + // Sets built with different comparators are incompatible, so the intersection must be empty. + set := NewWithStringComparator() + another := NewWithIntComparator() + set.Add("a", "b", "c", "d") + another.Add(1, 2, 3, 4) + intersection := set.Intersection(another) + if actualValue, expectedValue := intersection.Size(), 0; actualValue != expectedValue { + t.Errorf("Got %v expected %v", actualValue, expectedValue) + } + } + + set := NewWithStringComparator() + another := NewWithStringComparator() + + intersection := set.Intersection(another) + if actualValue, expectedValue := intersection.Size(), 0; actualValue != expectedValue { + t.Errorf("Got %v expected %v", actualValue, expectedValue) + } + + set.Add("a", "b", "c", "d") + another.Add("c", "d", "e", "f") + + intersection = set.Intersection(another) + + if actualValue, expectedValue := intersection.Size(), 2; actualValue != expectedValue { + t.Errorf("Got %v expected %v", actualValue, expectedValue) + } + if actualValue := intersection.Contains("c", "d"); actualValue != true { + t.Errorf("Got %v expected %v", actualValue, true) + } +} + +func TestSetUnion(t *testing.T) { + { + // Sets built with different comparators are incompatible, so the union must be empty. + set := NewWithStringComparator() + another := NewWithIntComparator() + set.Add("a", "b", "c", "d") + another.Add(1, 2, 3, 4) + union := set.Union(another) + if 
actualValue, expectedValue := union.Size(), 0; actualValue != expectedValue { + t.Errorf("Got %v expected %v", actualValue, expectedValue) + } + } + + set := NewWithStringComparator() + another := NewWithStringComparator() + + union := set.Union(another) + if actualValue, expectedValue := union.Size(), 0; actualValue != expectedValue { + t.Errorf("Got %v expected %v", actualValue, expectedValue) + } + + set.Add("a", "b", "c", "d") + another.Add("c", "d", "e", "f") + + union = set.Union(another) + + if actualValue, expectedValue := union.Size(), 6; actualValue != expectedValue { + t.Errorf("Got %v expected %v", actualValue, expectedValue) + } + if actualValue := union.Contains("a", "b", "c", "d", "e", "f"); actualValue != true { + t.Errorf("Got %v expected %v", actualValue, true) + } +} + +func TestSetDifference(t *testing.T) { + { + // Sets built with different comparators are incompatible, so the difference must be empty. + set := NewWithStringComparator() + another := NewWithIntComparator() + set.Add("a", "b", "c", "d") + another.Add(1, 2, 3, 4) + difference := set.Difference(another) + if actualValue, expectedValue := difference.Size(), 0; actualValue != expectedValue { + t.Errorf("Got %v expected %v", actualValue, expectedValue) + } + } + + set := NewWithStringComparator() + another := NewWithStringComparator() + + difference := set.Difference(another) + if actualValue, expectedValue := difference.Size(), 0; actualValue != expectedValue { + t.Errorf("Got %v expected %v", actualValue, expectedValue) + } + + set.Add("a", "b", "c", "d") + another.Add("c", "d", "e", "f") + + difference = set.Difference(another) + + if actualValue, expectedValue := difference.Size(), 2; actualValue != expectedValue { + t.Errorf("Got %v expected %v", actualValue, expectedValue) + } + if actualValue := difference.Contains("a", "b"); actualValue != true { + t.Errorf("Got %v expected %v", actualValue, true) + } +} + func benchmarkContains(b *testing.B, set *Set, size int) { for i := 0; i < b.N; i++ { for n := 0; n < size; n++ {