From e438e7b77b53520c448304dd0263c4a0b69c9346 Mon Sep 17 00:00:00 2001 From: Emir Pasic Date: Wed, 13 Apr 2022 00:36:42 +0200 Subject: [PATCH] Set operations: intersection, union, difference --- README.md | 7 +- containers/enumerable.go | 4 - sets/hashset/hashset.go | 55 +++++++++++++ sets/hashset/hashset_test.go | 66 ++++++++++++++++ sets/linkedhashset/linkedhashset.go | 55 +++++++++++++ sets/linkedhashset/linkedhashset_test.go | 66 ++++++++++++++++ sets/sets.go | 3 + sets/treeset/treeset.go | 77 ++++++++++++++++++ sets/treeset/treeset_test.go | 99 ++++++++++++++++++++++++ 9 files changed, 427 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 6d1cf06..920847b 100644 --- a/README.md +++ b/README.md @@ -230,6 +230,8 @@ func main() { A set is a data structure that can store elements and has no repeated values. It is a computer implementation of the mathematical concept of a finite set. Unlike most other collection types, rather than retrieving a specific element from a set, one typically tests an element for membership in a set. This structure is often used to ensure that no duplicates are present in a container. +Sets additionally allow set operations such as [intersection](https://en.wikipedia.org/wiki/Intersection_(set_theory)), [union](https://en.wikipedia.org/wiki/Union_(set_theory)), [difference](https://proofwiki.org/wiki/Definition:Set_Difference), etc. + Implements [Container](#containers) interface. 
```go @@ -237,7 +239,10 @@ type Set interface { Add(elements ...interface{}) Remove(elements ...interface{}) Contains(elements ...interface{}) bool - + // Intersection(another *Set) *Set + // Union(another *Set) *Set + // Difference(another *Set) *Set + containers.Container // Empty() bool // Size() int diff --git a/containers/enumerable.go b/containers/enumerable.go index ac48b54..7066005 100644 --- a/containers/enumerable.go +++ b/containers/enumerable.go @@ -11,11 +11,9 @@ type EnumerableWithIndex interface { // Map invokes the given function once for each element and returns a // container containing the values returned by the given function. - // TODO would appreciate help on how to enforce this in containers (don't want to type assert when chaining) // Map(func(index int, value interface{}) interface{}) Container // Select returns a new container containing all elements for which the given function returns a true value. - // TODO need help on how to enforce this in containers (don't want to type assert when chaining) // Select(func(index int, value interface{}) bool) Container // Any passes each element of the container to the given function and @@ -39,11 +37,9 @@ type EnumerableWithKey interface { // Map invokes the given function once for each element and returns a container // containing the values returned by the given function as key/value pairs. - // TODO need help on how to enforce this in containers (don't want to type assert when chaining) // Map(func(key interface{}, value interface{}) (interface{}, interface{})) Container // Select returns a new container containing all elements for which the given function returns a true value. 
- // TODO need help on how to enforce this in containers (don't want to type assert when chaining) // Select(func(key interface{}, value interface{}) bool) Container // Any passes each element of the container to the given function and diff --git a/sets/hashset/hashset.go b/sets/hashset/hashset.go index 815d049..99808f9 100644 --- a/sets/hashset/hashset.go +++ b/sets/hashset/hashset.go @@ -97,3 +97,58 @@ func (set *Set) String() string { str += strings.Join(items, ", ") return str } + +// Intersection returns the intersection between two sets. +// The new set consists of all elements that are both in "set" and "another". +// Ref: https://en.wikipedia.org/wiki/Intersection_(set_theory) +func (set *Set) Intersection(another *Set) *Set { + result := New() + + // Iterate over smaller set (optimization) + if set.Size() <= another.Size() { + for item := range set.items { + if _, contains := another.items[item]; contains { + result.Add(item) + } + } + } else { + for item := range another.items { + if _, contains := set.items[item]; contains { + result.Add(item) + } + } + } + + return result +} + +// Union returns the union of two sets. +// The new set consists of all elements that are in "set" or "another" (possibly both). +// Ref: https://en.wikipedia.org/wiki/Union_(set_theory) +func (set *Set) Union(another *Set) *Set { + result := New() + + for item := range set.items { + result.Add(item) + } + for item := range another.items { + result.Add(item) + } + + return result +} + +// Difference returns the difference between two sets. +// The new set consists of all elements that are in "set" but not in "another". 
+// Ref: https://proofwiki.org/wiki/Definition:Set_Difference +func (set *Set) Difference(another *Set) *Set { + result := New() + + for item := range set.items { + if _, contains := another.items[item]; !contains { + result.Add(item) + } + } + + return result +} diff --git a/sets/hashset/hashset_test.go b/sets/hashset/hashset_test.go index 4351338..ccb50b3 100644 --- a/sets/hashset/hashset_test.go +++ b/sets/hashset/hashset_test.go @@ -111,6 +111,72 @@ func TestSetSerialization(t *testing.T) { } } +func TestSetIntersection(t *testing.T) { + set := New() + another := New() + + intersection := set.Intersection(another) + if actualValue, expectedValue := intersection.Size(), 0; actualValue != expectedValue { + t.Errorf("Got %v expected %v", actualValue, expectedValue) + } + + set.Add("a", "b", "c", "d") + another.Add("c", "d", "e", "f") + + intersection = set.Intersection(another) + + if actualValue, expectedValue := intersection.Size(), 2; actualValue != expectedValue { + t.Errorf("Got %v expected %v", actualValue, expectedValue) + } + if actualValue := intersection.Contains("c", "d"); actualValue != true { + t.Errorf("Got %v expected %v", actualValue, true) + } +} + +func TestSetUnion(t *testing.T) { + set := New() + another := New() + + union := set.Union(another) + if actualValue, expectedValue := union.Size(), 0; actualValue != expectedValue { + t.Errorf("Got %v expected %v", actualValue, expectedValue) + } + + set.Add("a", "b", "c", "d") + another.Add("c", "d", "e", "f") + + union = set.Union(another) + + if actualValue, expectedValue := union.Size(), 6; actualValue != expectedValue { + t.Errorf("Got %v expected %v", actualValue, expectedValue) + } + if actualValue := union.Contains("a", "b", "c", "d", "e", "f"); actualValue != true { + t.Errorf("Got %v expected %v", actualValue, true) + } +} + +func TestSetDifference(t *testing.T) { + set := New() + another := New() + + difference := set.Difference(another) + if actualValue, expectedValue := 
difference.Size(), 0; actualValue != expectedValue { + t.Errorf("Got %v expected %v", actualValue, expectedValue) + } + + set.Add("a", "b", "c", "d") + another.Add("c", "d", "e", "f") + + difference = set.Difference(another) + + if actualValue, expectedValue := difference.Size(), 2; actualValue != expectedValue { + t.Errorf("Got %v expected %v", actualValue, expectedValue) + } + if actualValue := difference.Contains("a", "b"); actualValue != true { + t.Errorf("Got %v expected %v", actualValue, true) + } +} + func benchmarkContains(b *testing.B, set *Set, size int) { for i := 0; i < b.N; i++ { for n := 0; n < size; n++ { diff --git a/sets/linkedhashset/linkedhashset.go b/sets/linkedhashset/linkedhashset.go index e589a12..46ebdfe 100644 --- a/sets/linkedhashset/linkedhashset.go +++ b/sets/linkedhashset/linkedhashset.go @@ -116,3 +116,58 @@ func (set *Set) String() string { str += strings.Join(items, ", ") return str } + +// Intersection returns the intersection between two sets. +// The new set consists of all elements that are both in "set" and "another". +// Ref: https://en.wikipedia.org/wiki/Intersection_(set_theory) +func (set *Set) Intersection(another *Set) *Set { + result := New() + + // Iterate over smaller set (optimization) + if set.Size() <= another.Size() { + for item := range set.table { + if _, contains := another.table[item]; contains { + result.Add(item) + } + } + } else { + for item := range another.table { + if _, contains := set.table[item]; contains { + result.Add(item) + } + } + } + + return result +} + +// Union returns the union of two sets. +// The new set consists of all elements that are in "set" or "another" (possibly both). 
+// Ref: https://en.wikipedia.org/wiki/Union_(set_theory) +func (set *Set) Union(another *Set) *Set { + result := New() + + for item := range set.table { + result.Add(item) + } + for item := range another.table { + result.Add(item) + } + + return result +} + +// Difference returns the difference between two sets. +// The new set consists of all elements that are in "set" but not in "another". +// Ref: https://proofwiki.org/wiki/Definition:Set_Difference +func (set *Set) Difference(another *Set) *Set { + result := New() + + for item := range set.table { + if _, contains := another.table[item]; !contains { + result.Add(item) + } + } + + return result +} diff --git a/sets/linkedhashset/linkedhashset_test.go b/sets/linkedhashset/linkedhashset_test.go index 59db9ad..7e3c236 100644 --- a/sets/linkedhashset/linkedhashset_test.go +++ b/sets/linkedhashset/linkedhashset_test.go @@ -465,6 +465,72 @@ func TestSetSerialization(t *testing.T) { } } +func TestSetIntersection(t *testing.T) { + set := New() + another := New() + + intersection := set.Intersection(another) + if actualValue, expectedValue := intersection.Size(), 0; actualValue != expectedValue { + t.Errorf("Got %v expected %v", actualValue, expectedValue) + } + + set.Add("a", "b", "c", "d") + another.Add("c", "d", "e", "f") + + intersection = set.Intersection(another) + + if actualValue, expectedValue := intersection.Size(), 2; actualValue != expectedValue { + t.Errorf("Got %v expected %v", actualValue, expectedValue) + } + if actualValue := intersection.Contains("c", "d"); actualValue != true { + t.Errorf("Got %v expected %v", actualValue, true) + } +} + +func TestSetUnion(t *testing.T) { + set := New() + another := New() + + union := set.Union(another) + if actualValue, expectedValue := union.Size(), 0; actualValue != expectedValue { + t.Errorf("Got %v expected %v", actualValue, expectedValue) + } + + set.Add("a", "b", "c", "d") + another.Add("c", "d", "e", "f") + + union = set.Union(another) + + if 
actualValue, expectedValue := union.Size(), 6; actualValue != expectedValue { + t.Errorf("Got %v expected %v", actualValue, expectedValue) + } + if actualValue := union.Contains("a", "b", "c", "d", "e", "f"); actualValue != true { + t.Errorf("Got %v expected %v", actualValue, true) + } +} + +func TestSetDifference(t *testing.T) { + set := New() + another := New() + + difference := set.Difference(another) + if actualValue, expectedValue := difference.Size(), 0; actualValue != expectedValue { + t.Errorf("Got %v expected %v", actualValue, expectedValue) + } + + set.Add("a", "b", "c", "d") + another.Add("c", "d", "e", "f") + + difference = set.Difference(another) + + if actualValue, expectedValue := difference.Size(), 2; actualValue != expectedValue { + t.Errorf("Got %v expected %v", actualValue, expectedValue) + } + if actualValue := difference.Contains("a", "b"); actualValue != true { + t.Errorf("Got %v expected %v", actualValue, true) + } +} + func benchmarkContains(b *testing.B, set *Set, size int) { for i := 0; i < b.N; i++ { for n := 0; n < size; n++ { diff --git a/sets/sets.go b/sets/sets.go index 2573297..d96801c 100644 --- a/sets/sets.go +++ b/sets/sets.go @@ -16,6 +16,9 @@ type Set interface { Add(elements ...interface{}) Remove(elements ...interface{}) Contains(elements ...interface{}) bool + // Intersection(another *Set) *Set + // Union(another *Set) *Set + // Difference(another *Set) *Set containers.Container // Empty() bool diff --git a/sets/treeset/treeset.go b/sets/treeset/treeset.go index 7efbf2d..7e7d1d6 100644 --- a/sets/treeset/treeset.go +++ b/sets/treeset/treeset.go @@ -14,6 +14,7 @@ import ( "github.com/emirpasic/gods/sets" rbt "github.com/emirpasic/gods/trees/redblacktree" "github.com/emirpasic/gods/utils" + "reflect" "strings" ) @@ -111,3 +112,79 @@ func (set *Set) String() string { str += strings.Join(items, ", ") return str } + +// Intersection returns the intersection between two sets. 
+// The new set consists of all elements that are both in "set" and "another". +// The two sets should have the same comparators, otherwise the result is an empty set. +// Ref: https://en.wikipedia.org/wiki/Intersection_(set_theory) +func (set *Set) Intersection(another *Set) *Set { + result := NewWith(set.tree.Comparator) + + setComparator := reflect.ValueOf(set.tree.Comparator) + anotherComparator := reflect.ValueOf(another.tree.Comparator) + if setComparator.Pointer() != anotherComparator.Pointer() { + return result + } + + // Iterate over smaller set (optimization) + if set.Size() <= another.Size() { + for it := set.Iterator(); it.Next(); { + if another.Contains(it.Value()) { + result.Add(it.Value()) + } + } + } else { + for it := another.Iterator(); it.Next(); { + if set.Contains(it.Value()) { + result.Add(it.Value()) + } + } + } + + return result +} + +// Union returns the union of two sets. +// The new set consists of all elements that are in "set" or "another" (possibly both). +// The two sets should have the same comparators, otherwise the result is an empty set. +// Ref: https://en.wikipedia.org/wiki/Union_(set_theory) +func (set *Set) Union(another *Set) *Set { + result := NewWith(set.tree.Comparator) + + setComparator := reflect.ValueOf(set.tree.Comparator) + anotherComparator := reflect.ValueOf(another.tree.Comparator) + if setComparator.Pointer() != anotherComparator.Pointer() { + return result + } + + for it := set.Iterator(); it.Next(); { + result.Add(it.Value()) + } + for it := another.Iterator(); it.Next(); { + result.Add(it.Value()) + } + + return result +} + +// Difference returns the difference between two sets. +// The new set consists of all elements that are in "set" but not in "another". +// The two sets should have the same comparators, otherwise the result is an empty set. 
+// Ref: https://proofwiki.org/wiki/Definition:Set_Difference +func (set *Set) Difference(another *Set) *Set { + result := NewWith(set.tree.Comparator) + + setComparator := reflect.ValueOf(set.tree.Comparator) + anotherComparator := reflect.ValueOf(another.tree.Comparator) + if setComparator.Pointer() != anotherComparator.Pointer() { + return result + } + + for it := set.Iterator(); it.Next(); { + if !another.Contains(it.Value()) { + result.Add(it.Value()) + } + } + + return result +} diff --git a/sets/treeset/treeset_test.go b/sets/treeset/treeset_test.go index 20a6f6a..2839d4b 100644 --- a/sets/treeset/treeset_test.go +++ b/sets/treeset/treeset_test.go @@ -474,6 +474,105 @@ func TestSetSerialization(t *testing.T) { } } +func TestSetIntersection(t *testing.T) { + { + set := NewWithStringComparator() + another := NewWithIntComparator() + set.Add("a", "b", "c", "d") + another.Add(1, 2, 3, 4) + intersection := set.Intersection(another) + if actualValue, expectedValue := intersection.Size(), 0; actualValue != expectedValue { + t.Errorf("Got %v expected %v", actualValue, expectedValue) + } + } + + set := NewWithStringComparator() + another := NewWithStringComparator() + + intersection := set.Intersection(another) + if actualValue, expectedValue := intersection.Size(), 0; actualValue != expectedValue { + t.Errorf("Got %v expected %v", actualValue, expectedValue) + } + + set.Add("a", "b", "c", "d") + another.Add("c", "d", "e", "f") + + intersection = set.Intersection(another) + + if actualValue, expectedValue := intersection.Size(), 2; actualValue != expectedValue { + t.Errorf("Got %v expected %v", actualValue, expectedValue) + } + if actualValue := intersection.Contains("c", "d"); actualValue != true { + t.Errorf("Got %v expected %v", actualValue, true) + } +} + +func TestSetUnion(t *testing.T) { + { + set := NewWithStringComparator() + another := NewWithIntComparator() + set.Add("a", "b", "c", "d") + another.Add(1, 2, 3, 4) + union := set.Union(another) + if 
actualValue, expectedValue := union.Size(), 0; actualValue != expectedValue { + t.Errorf("Got %v expected %v", actualValue, expectedValue) + } + } + + set := NewWithStringComparator() + another := NewWithStringComparator() + + union := set.Union(another) + if actualValue, expectedValue := union.Size(), 0; actualValue != expectedValue { + t.Errorf("Got %v expected %v", actualValue, expectedValue) + } + + set.Add("a", "b", "c", "d") + another.Add("c", "d", "e", "f") + + union = set.Union(another) + + if actualValue, expectedValue := union.Size(), 6; actualValue != expectedValue { + t.Errorf("Got %v expected %v", actualValue, expectedValue) + } + if actualValue := union.Contains("a", "b", "c", "d", "e", "f"); actualValue != true { + t.Errorf("Got %v expected %v", actualValue, true) + } +} + +func TestSetDifference(t *testing.T) { + { + set := NewWithStringComparator() + another := NewWithIntComparator() + set.Add("a", "b", "c", "d") + another.Add(1, 2, 3, 4) + difference := set.Difference(another) + if actualValue, expectedValue := difference.Size(), 0; actualValue != expectedValue { + t.Errorf("Got %v expected %v", actualValue, expectedValue) + } + } + + set := NewWithStringComparator() + another := NewWithStringComparator() + + difference := set.Difference(another) + if actualValue, expectedValue := difference.Size(), 0; actualValue != expectedValue { + t.Errorf("Got %v expected %v", actualValue, expectedValue) + } + + set.Add("a", "b", "c", "d") + another.Add("c", "d", "e", "f") + + difference = set.Difference(another) + + if actualValue, expectedValue := difference.Size(), 2; actualValue != expectedValue { + t.Errorf("Got %v expected %v", actualValue, expectedValue) + } + if actualValue := difference.Contains("a", "b"); actualValue != true { + t.Errorf("Got %v expected %v", actualValue, true) + } +} + func benchmarkContains(b *testing.B, set *Set, size int) { for i := 0; i < b.N; i++ { for n := 0; n < size; n++ {