diff --git a/sets/hashset/hashset.go b/sets/hashset/hashset.go index 815d049..e49696e 100644 --- a/sets/hashset/hashset.go +++ b/sets/hashset/hashset.go @@ -97,3 +97,14 @@ func (set *Set) String() string { str += strings.Join(items, ", ") return str } + +// Intersection returns the intersection between two sets +func (set *Set) Intersection(another *Set) *Set { + result := New() + for item, _ := range another.items { + if set.Contains(item) { + result.Add(item) + } + } + return result +} diff --git a/sets/hashset/hashset_test.go b/sets/hashset/hashset_test.go index 4351338..4f2dc7d 100644 --- a/sets/hashset/hashset_test.go +++ b/sets/hashset/hashset_test.go @@ -111,6 +111,18 @@ func TestSetSerialization(t *testing.T) { } } +func TestSetIntersection(t *testing.T) { + set := New("a", "b", "c", "d") + anotherSet := New("c", "d", "f", "g") + intersection := set.Intersection(anotherSet) + if actualValue, expectedValue := intersection.Size(), 2; actualValue != expectedValue { + t.Errorf("Got %v expected %v", actualValue, expectedValue) + } + if actualValue := set.Contains("c", "d"); actualValue != true { + t.Errorf("Got %v expected %v", actualValue, true) + } +} + func benchmarkContains(b *testing.B, set *Set, size int) { for i := 0; i < b.N; i++ { for n := 0; n < size; n++ { diff --git a/sets/linkedhashset/linkedhashset.go b/sets/linkedhashset/linkedhashset.go index e589a12..df4deb3 100644 --- a/sets/linkedhashset/linkedhashset.go +++ b/sets/linkedhashset/linkedhashset.go @@ -116,3 +116,14 @@ func (set *Set) String() string { str += strings.Join(items, ", ") return str } + +// Intersection returns the intersection between two sets +func (set *Set) Intersection(another *Set) *Set { + result := New() + for item, _ := range another.table { + if set.Contains(item) { + result.Add(item) + } + } + return result +} diff --git a/sets/linkedhashset/linkedhashset_test.go b/sets/linkedhashset/linkedhashset_test.go index 59db9ad..16b0729 100644 --- a/sets/linkedhashset/linkedhashset_test.go +++ b/sets/linkedhashset/linkedhashset_test.go @@ -465,6 +465,18 @@ func TestSetSerialization(t *testing.T) { } } +func TestSetIntersection(t *testing.T) { + set := New("a", "b", "c", "d") + anotherSet := New("c", "d", "f", "g") + intersection := set.Intersection(anotherSet) + if actualValue, expectedValue := intersection.Size(), 2; actualValue != expectedValue { + t.Errorf("Got %v expected %v", actualValue, expectedValue) + } + if actualValue := set.Contains("c", "d"); actualValue != true { + t.Errorf("Got %v expected %v", actualValue, true) + } +} + func benchmarkContains(b *testing.B, set *Set, size int) { for i := 0; i < b.N; i++ { for n := 0; n < size; n++ { diff --git a/sets/sets.go b/sets/sets.go index 2573297..a81ec14 100644 --- a/sets/sets.go +++ b/sets/sets.go @@ -16,6 +16,7 @@ type Set interface { Add(elements ...interface{}) Remove(elements ...interface{}) Contains(elements ...interface{}) bool + Intersection(another *Set) *Set containers.Container // Empty() bool diff --git a/sets/treeset/treeset.go b/sets/treeset/treeset.go index 7efbf2d..fd23610 100644 --- a/sets/treeset/treeset.go +++ b/sets/treeset/treeset.go @@ -111,3 +111,14 @@ func (set *Set) String() string { str += strings.Join(items, ", ") return str } + +// Intersection returns the intersection between two sets +func (set *Set) Intersection(another *Set) *Set { + result := NewWith(set.tree.Comparator) + for _, item := range another.Values() { + if set.Contains(item) { + result.Add(item) + } + } + return result +} diff --git a/sets/treeset/treeset_test.go b/sets/treeset/treeset_test.go index 20a6f6a..2428d3c 100644 --- a/sets/treeset/treeset_test.go +++ b/sets/treeset/treeset_test.go @@ -474,6 +474,18 @@ func TestSetSerialization(t *testing.T) { } } +func TestSetIntersection(t *testing.T) { + set := NewWithStringComparator("a", "b", "c", "d") + anotherSet := NewWithStringComparator("c", "d", "f", "g") + intersection := set.Intersection(anotherSet) + if actualValue, expectedValue := intersection.Size(), 2; actualValue != expectedValue { + t.Errorf("Got %v expected %v", actualValue, expectedValue) + } + if actualValue := set.Contains("c", "d"); actualValue != true { + t.Errorf("Got %v expected %v", actualValue, true) + } +} + func benchmarkContains(b *testing.B, set *Set, size int) { for i := 0; i < b.N; i++ { for n := 0; n < size; n++ {