From 6b0ffefe7f7ccad3bf653f453401b49a3d05048e Mon Sep 17 00:00:00 2001 From: Emir Pasic Date: Fri, 15 Apr 2022 00:07:42 +0200 Subject: [PATCH] Fix iterator in binary heap --- trees/binaryheap/binaryheap.go | 10 +++-- trees/binaryheap/binaryheap_test.go | 66 ++++++++++++++--------------- trees/binaryheap/iterator.go | 36 +++++++++++++++- 3 files changed, 74 insertions(+), 38 deletions(-) diff --git a/trees/binaryheap/binaryheap.go b/trees/binaryheap/binaryheap.go index b3412c5..e658f25 100644 --- a/trees/binaryheap/binaryheap.go +++ b/trees/binaryheap/binaryheap.go @@ -97,15 +97,19 @@ func (heap *Heap) Clear() { // Values returns all elements in the heap. func (heap *Heap) Values() []interface{} { - return heap.list.Values() + values := make([]interface{}, heap.list.Size(), heap.list.Size()) + for it := heap.Iterator(); it.Next(); { + values[it.Index()] = it.Value() + } + return values } // String returns a string representation of container func (heap *Heap) String() string { str := "BinaryHeap\n" values := []string{} - for _, value := range heap.list.Values() { - values = append(values, fmt.Sprintf("%v", value)) + for it := heap.Iterator(); it.Next(); { + values = append(values, fmt.Sprintf("%v", it.Value())) } str += strings.Join(values, ", ") return str diff --git a/trees/binaryheap/binaryheap_test.go b/trees/binaryheap/binaryheap_test.go index 26d99ea..bb5c42b 100644 --- a/trees/binaryheap/binaryheap_test.go +++ b/trees/binaryheap/binaryheap_test.go @@ -18,11 +18,11 @@ func TestBinaryHeapPush(t *testing.T) { t.Errorf("Got %v expected %v", actualValue, true) } - heap.Push(3) // [3] - heap.Push(2) // [2,3] - heap.Push(1) // [1,3,2](2 swapped with 1, hence last) + heap.Push(3) + heap.Push(2) + heap.Push(1) - if actualValue := heap.Values(); actualValue[0].(int) != 1 || actualValue[1].(int) != 3 || actualValue[2].(int) != 2 { + if actualValue := heap.Values(); actualValue[0].(int) != 1 || actualValue[1].(int) != 2 || actualValue[2].(int) != 3 { t.Errorf("Got %v expected %v", actualValue, "[1,2,3]") } if actualValue := heap.Empty(); actualValue != false { @@ -56,10 +56,10 @@ func TestBinaryHeapPop(t *testing.T) { t.Errorf("Got %v expected %v", actualValue, true) } - heap.Push(3) // [3] - heap.Push(2) // [2,3] - heap.Push(1) // [1,3,2](2 swapped with 1, hence last) - heap.Pop() // [3,2] + heap.Push(3) + heap.Push(2) + heap.Push(1) + heap.Pop() if actualValue, ok := heap.Peek(); actualValue != 2 || !ok { t.Errorf("Got %v expected %v", actualValue, 2) @@ -110,9 +110,9 @@ func TestBinaryHeapIteratorOnEmpty(t *testing.T) { func TestBinaryHeapIteratorNext(t *testing.T) { heap := NewWithIntComparator() - heap.Push(3) // [3] - heap.Push(2) // [2,3] - heap.Push(1) // [1,3,2](2 swapped with 1, hence last) + heap.Push(3) + heap.Push(2) + heap.Push(1) it := heap.Iterator() count := 0 @@ -126,11 +126,11 @@ func TestBinaryHeapIteratorNext(t *testing.T) { t.Errorf("Got %v expected %v", actualValue, expectedValue) } case 1: - if actualValue, expectedValue := value, 3; actualValue != expectedValue { + if actualValue, expectedValue := value, 2; actualValue != expectedValue { t.Errorf("Got %v expected %v", actualValue, expectedValue) } case 2: - if actualValue, expectedValue := value, 2; actualValue != expectedValue { + if actualValue, expectedValue := value, 3; actualValue != expectedValue { t.Errorf("Got %v expected %v", actualValue, expectedValue) } default: @@ -147,9 +147,9 @@ func TestBinaryHeapIteratorNext(t *testing.T) { func TestBinaryHeapIteratorPrev(t *testing.T) { heap := NewWithIntComparator() - heap.Push(3) // [3] - heap.Push(2) // [2,3] - heap.Push(1) // [1,3,2](2 swapped with 1, hence last) + heap.Push(3) + heap.Push(2) + heap.Push(1) it := heap.Iterator() for it.Next() { @@ -165,11 +165,11 @@ func TestBinaryHeapIteratorPrev(t *testing.T) { t.Errorf("Got %v expected %v", actualValue, expectedValue) } case 1: - if actualValue, expectedValue := value, 3; actualValue != expectedValue { + if actualValue, expectedValue := value, 2; actualValue != expectedValue { t.Errorf("Got %v expected %v", actualValue, expectedValue) } case 2: - if actualValue, expectedValue := value, 2; actualValue != expectedValue { + if actualValue, expectedValue := value, 3; actualValue != expectedValue { t.Errorf("Got %v expected %v", actualValue, expectedValue) } default: @@ -213,17 +213,17 @@ func TestBinaryHeapIteratorEnd(t *testing.T) { t.Errorf("Got %v expected %v", index, 0) } - heap.Push(3) // [3] - heap.Push(2) // [2,3] - heap.Push(1) // [1,3,2](2 swapped with 1, hence last) + heap.Push(3) + heap.Push(2) + heap.Push(1) it.End() if index := it.Index(); index != heap.Size() { t.Errorf("Got %v expected %v", index, heap.Size()) } it.Prev() - if index, value := it.Index(), it.Value(); index != heap.Size()-1 || value != 2 { - t.Errorf("Got %v,%v expected %v,%v", index, value, heap.Size()-1, 2) + if index, value := it.Index(), it.Value(); index != heap.Size()-1 || value != 3 { + t.Errorf("Got %v,%v expected %v,%v", index, value, heap.Size()-1, 3) } } @@ -233,9 +233,9 @@ func TestBinaryHeapIteratorFirst(t *testing.T) { if actualValue, expectedValue := it.First(), false; actualValue != expectedValue { t.Errorf("Got %v expected %v", actualValue, expectedValue) } - heap.Push(3) // [3] - heap.Push(2) // [2,3] - heap.Push(1) // [1,3,2](2 swapped with 1, hence last) + heap.Push(3) + heap.Push(2) + heap.Push(1) if actualValue, expectedValue := it.First(), true; actualValue != expectedValue { t.Errorf("Got %v expected %v", actualValue, expectedValue) } @@ -252,12 +252,12 @@ func TestBinaryHeapIteratorLast(t *testing.T) { } tree.Push(2) tree.Push(3) - tree.Push(1) // [1,3,2](2 swapped with 1, hence last) + tree.Push(1) if actualValue, expectedValue := it.Last(), true; actualValue != expectedValue { t.Errorf("Got %v expected %v", actualValue, expectedValue) } - if index, value := it.Index(), it.Value(); index != 2 || value != 2 { - t.Errorf("Got %v,%v expected %v,%v", index, value, 2, 2) + if index, value := it.Index(), it.Value(); index != 2 || value != 3 { + t.Errorf("Got %v,%v expected %v,%v", index, value, 2, 3) } } @@ -370,13 +370,13 @@ func TestBinaryHeapIteratorPrevTo(t *testing.T) { func TestBinaryHeapSerialization(t *testing.T) { heap := NewWithStringComparator() - heap.Push("c") // ["c"] - heap.Push("b") // ["b","c"] - heap.Push("a") // ["a","c","b"]("b" swapped with "a", hence last) + heap.Push("c") + heap.Push("b") + heap.Push("a") var err error assert := func() { - if actualValue := heap.Values(); actualValue[0].(string) != "a" || actualValue[1].(string) != "c" || actualValue[2].(string) != "b" { + if actualValue := heap.Values(); actualValue[0].(string) != "a" || actualValue[1].(string) != "b" || actualValue[2].(string) != "c" { t.Errorf("Got %v expected %v", actualValue, "[1,3,2]") } if actualValue := heap.Size(); actualValue != 3 { diff --git a/trees/binaryheap/iterator.go b/trees/binaryheap/iterator.go index 8a01b05..f217963 100644 --- a/trees/binaryheap/iterator.go +++ b/trees/binaryheap/iterator.go @@ -4,7 +4,9 @@ package binaryheap -import "github.com/emirpasic/gods/containers" +import ( + "github.com/emirpasic/gods/containers" +) // Assert Iterator implementation var _ containers.ReverseIteratorWithIndex = (*Iterator)(nil) @@ -44,7 +46,19 @@ func (iterator *Iterator) Prev() bool { // Value returns the current element's value. // Does not modify the state of the iterator. func (iterator *Iterator) Value() interface{} { - value, _ := iterator.heap.list.Get(iterator.index) + start, end := evaluateRange(iterator.index) + if end > iterator.heap.Size() { + end = iterator.heap.Size() + } + tmpHeap := NewWith(iterator.heap.Comparator) + for n := start; n < end; n++ { + value, _ := iterator.heap.list.Get(n) + tmpHeap.Push(value) + } + for n := 0; n < iterator.index-start; n++ { + tmpHeap.Pop() + } + value, _ := tmpHeap.Pop() return value } @@ -109,3 +123,21 @@ func (iterator *Iterator) PrevTo(f func(index int, value interface{}) bool) bool } return false } + +// numOfBits counts the number of bits of an int +func numOfBits(n int) uint { + var count uint + for n != 0 { + count++ + n >>= 1 + } + return count +} + +// evaluateRange evaluates the index range [start,end) of same level nodes in the heap as the index +func evaluateRange(index int) (start int, end int) { + bits := numOfBits(index+1) - 1 + start = 1<