- expose the root of the red-black tree to allow custom tree traversal

This commit is contained in:
Emir Pasic 2016-03-23 05:40:01 +01:00
parent 4257bbbae3
commit 8dab13c925

View File

@ -49,18 +49,18 @@ const (
) )
type Tree struct { type Tree struct {
root *node Root *Node
size int size int
comparator utils.Comparator comparator utils.Comparator
} }
type node struct { type Node struct {
key interface{} Key interface{}
value interface{} Value interface{}
color color color color
left *node Left *Node
right *node Right *Node
parent *node Parent *Node
} }
// Instantiates a red-black tree with the custom comparator. // Instantiates a red-black tree with the custom comparator.
@ -81,35 +81,35 @@ func NewWithStringComparator() *Tree {
// Inserts node into the tree. // Inserts node into the tree.
// Key should adhere to the comparator's type assertion, otherwise method panics. // Key should adhere to the comparator's type assertion, otherwise method panics.
func (tree *Tree) Put(key interface{}, value interface{}) { func (tree *Tree) Put(key interface{}, value interface{}) {
insertedNode := &node{key: key, value: value, color: red} insertedNode := &Node{Key: key, Value: value, color: red}
if tree.root == nil { if tree.Root == nil {
tree.root = insertedNode tree.Root = insertedNode
} else { } else {
node := tree.root node := tree.Root
loop := true loop := true
for loop { for loop {
compare := tree.comparator(key, node.key) compare := tree.comparator(key, node.Key)
switch { switch {
case compare == 0: case compare == 0:
node.value = value node.Value = value
return return
case compare < 0: case compare < 0:
if node.left == nil { if node.Left == nil {
node.left = insertedNode node.Left = insertedNode
loop = false loop = false
} else { } else {
node = node.left node = node.Left
} }
case compare > 0: case compare > 0:
if node.right == nil { if node.Right == nil {
node.right = insertedNode node.Right = insertedNode
loop = false loop = false
} else { } else {
node = node.right node = node.Right
} }
} }
} }
insertedNode.parent = node insertedNode.Parent = node
} }
tree.insertCase1(insertedNode) tree.insertCase1(insertedNode)
tree.size += 1 tree.size += 1
@ -121,7 +121,7 @@ func (tree *Tree) Put(key interface{}, value interface{}) {
func (tree *Tree) Get(key interface{}) (value interface{}, found bool) { func (tree *Tree) Get(key interface{}) (value interface{}, found bool) {
node := tree.lookup(key) node := tree.lookup(key)
if node != nil { if node != nil {
return node.value, true return node.Value, true
} }
return nil, false return nil, false
} }
@ -129,29 +129,29 @@ func (tree *Tree) Get(key interface{}) (value interface{}, found bool) {
// Remove the node from the tree by key. // Remove the node from the tree by key.
// Key should adhere to the comparator's type assertion, otherwise method panics. // Key should adhere to the comparator's type assertion, otherwise method panics.
func (tree *Tree) Remove(key interface{}) { func (tree *Tree) Remove(key interface{}) {
var child *node var child *Node
node := tree.lookup(key) node := tree.lookup(key)
if node == nil { if node == nil {
return return
} }
if node.left != nil && node.right != nil { if node.Left != nil && node.Right != nil {
pred := node.left.maximumNode() pred := node.Left.maximumNode()
node.key = pred.key node.Key = pred.Key
node.value = pred.value node.Value = pred.Value
node = pred node = pred
} }
if node.left == nil || node.right == nil { if node.Left == nil || node.Right == nil {
if node.right == nil { if node.Right == nil {
child = node.left child = node.Left
} else { } else {
child = node.right child = node.Right
} }
if node.color == black { if node.color == black {
node.color = nodeColor(child) node.color = nodeColor(child)
tree.deleteCase1(node) tree.deleteCase1(node)
} }
tree.replaceNode(node, child) tree.replaceNode(node, child)
if node.parent == nil && child != nil { if node.Parent == nil && child != nil {
child.color = black child.color = black
} }
} }
@ -172,7 +172,7 @@ func (tree *Tree) Size() int {
func (tree *Tree) Keys() []interface{} { func (tree *Tree) Keys() []interface{} {
keys := make([]interface{}, tree.size) keys := make([]interface{}, tree.size)
for i, node := range tree.inOrder() { for i, node := range tree.inOrder() {
keys[i] = node.key keys[i] = node.Key
} }
return keys return keys
} }
@ -181,48 +181,48 @@ func (tree *Tree) Keys() []interface{} {
func (tree *Tree) Values() []interface{} { func (tree *Tree) Values() []interface{} {
values := make([]interface{}, tree.size) values := make([]interface{}, tree.size)
for i, node := range tree.inOrder() { for i, node := range tree.inOrder() {
values[i] = node.value values[i] = node.Value
} }
return values return values
} }
// Removes all nodes from the tree. // Removes all nodes from the tree.
func (tree *Tree) Clear() { func (tree *Tree) Clear() {
tree.root = nil tree.Root = nil
tree.size = 0 tree.size = 0
} }
func (tree *Tree) String() string { func (tree *Tree) String() string {
str := "RedBlackTree\n" str := "RedBlackTree\n"
if !tree.Empty() { if !tree.Empty() {
output(tree.root, "", true, &str) output(tree.Root, "", true, &str)
} }
return str return str
} }
func (node *node) String() string { func (node *Node) String() string {
return fmt.Sprintf("%v", node.key) return fmt.Sprintf("%v", node.Key)
} }
// Returns all nodes in order // Returns all nodes in order
func (tree *Tree) inOrder() []*node { func (tree *Tree) inOrder() []*Node {
nodes := make([]*node, tree.size) nodes := make([]*Node, tree.size)
if tree.size > 0 { if tree.size > 0 {
current := tree.root current := tree.Root
stack := linkedliststack.New() stack := linkedliststack.New()
done := false done := false
count := 0 count := 0
for !done { for !done {
if current != nil { if current != nil {
stack.Push(current) stack.Push(current)
current = current.left current = current.Left
} else { } else {
if !stack.Empty() { if !stack.Empty() {
currentPop, _ := stack.Pop() currentPop, _ := stack.Pop()
current = currentPop.(*node) current = currentPop.(*Node)
nodes[count] = current nodes[count] = current
count += 1 count += 1
current = current.right current = current.Right
} else { } else {
done = true done = true
} }
@ -232,15 +232,15 @@ func (tree *Tree) inOrder() []*node {
return nodes return nodes
} }
func output(node *node, prefix string, isTail bool, str *string) { func output(node *Node, prefix string, isTail bool, str *string) {
if node.right != nil { if node.Right != nil {
newPrefix := prefix newPrefix := prefix
if isTail { if isTail {
newPrefix += "│ " newPrefix += "│ "
} else { } else {
newPrefix += " " newPrefix += " "
} }
output(node.right, newPrefix, false, str) output(node.Right, newPrefix, false, str)
} }
*str += prefix *str += prefix
if isTail { if isTail {
@ -249,114 +249,114 @@ func output(node *node, prefix string, isTail bool, str *string) {
*str += "┌── " *str += "┌── "
} }
*str += node.String() + "\n" *str += node.String() + "\n"
if node.left != nil { if node.Left != nil {
newPrefix := prefix newPrefix := prefix
if isTail { if isTail {
newPrefix += " " newPrefix += " "
} else { } else {
newPrefix += "│ " newPrefix += "│ "
} }
output(node.left, newPrefix, true, str) output(node.Left, newPrefix, true, str)
} }
} }
func (tree *Tree) lookup(key interface{}) *node { func (tree *Tree) lookup(key interface{}) *Node {
node := tree.root node := tree.Root
for node != nil { for node != nil {
compare := tree.comparator(key, node.key) compare := tree.comparator(key, node.Key)
switch { switch {
case compare == 0: case compare == 0:
return node return node
case compare < 0: case compare < 0:
node = node.left node = node.Left
case compare > 0: case compare > 0:
node = node.right node = node.Right
} }
} }
return nil return nil
} }
func (node *node) grandparent() *node { func (node *Node) grandparent() *Node {
if node != nil && node.parent != nil { if node != nil && node.Parent != nil {
return node.parent.parent return node.Parent.Parent
} }
return nil return nil
} }
func (node *node) uncle() *node { func (node *Node) uncle() *Node {
if node == nil || node.parent == nil || node.parent.parent == nil { if node == nil || node.Parent == nil || node.Parent.Parent == nil {
return nil return nil
} }
return node.parent.sibling() return node.Parent.sibling()
} }
func (node *node) sibling() *node { func (node *Node) sibling() *Node {
if node == nil || node.parent == nil { if node == nil || node.Parent == nil {
return nil return nil
} }
if node == node.parent.left { if node == node.Parent.Left {
return node.parent.right return node.Parent.Right
} else { } else {
return node.parent.left return node.Parent.Left
} }
} }
func (tree *Tree) rotateLeft(node *node) { func (tree *Tree) rotateLeft(node *Node) {
right := node.right right := node.Right
tree.replaceNode(node, right) tree.replaceNode(node, right)
node.right = right.left node.Right = right.Left
if right.left != nil { if right.Left != nil {
right.left.parent = node right.Left.Parent = node
} }
right.left = node right.Left = node
node.parent = right node.Parent = right
} }
func (tree *Tree) rotateRight(node *node) { func (tree *Tree) rotateRight(node *Node) {
left := node.left left := node.Left
tree.replaceNode(node, left) tree.replaceNode(node, left)
node.left = left.right node.Left = left.Right
if left.right != nil { if left.Right != nil {
left.right.parent = node left.Right.Parent = node
} }
left.right = node left.Right = node
node.parent = left node.Parent = left
} }
func (tree *Tree) replaceNode(old *node, new *node) { func (tree *Tree) replaceNode(old *Node, new *Node) {
if old.parent == nil { if old.Parent == nil {
tree.root = new tree.Root = new
} else { } else {
if old == old.parent.left { if old == old.Parent.Left {
old.parent.left = new old.Parent.Left = new
} else { } else {
old.parent.right = new old.Parent.Right = new
} }
} }
if new != nil { if new != nil {
new.parent = old.parent new.Parent = old.Parent
} }
} }
func (tree *Tree) insertCase1(node *node) { func (tree *Tree) insertCase1(node *Node) {
if node.parent == nil { if node.Parent == nil {
node.color = black node.color = black
} else { } else {
tree.insertCase2(node) tree.insertCase2(node)
} }
} }
func (tree *Tree) insertCase2(node *node) { func (tree *Tree) insertCase2(node *Node) {
if nodeColor(node.parent) == black { if nodeColor(node.Parent) == black {
return return
} }
tree.insertCase3(node) tree.insertCase3(node)
} }
func (tree *Tree) insertCase3(node *node) { func (tree *Tree) insertCase3(node *Node) {
uncle := node.uncle() uncle := node.uncle()
if nodeColor(uncle) == red { if nodeColor(uncle) == red {
node.parent.color = black node.Parent.color = black
uncle.color = black uncle.color = black
node.grandparent().color = red node.grandparent().color = red
tree.insertCase1(node.grandparent()) tree.insertCase1(node.grandparent())
@ -365,121 +365,121 @@ func (tree *Tree) insertCase3(node *node) {
} }
} }
func (tree *Tree) insertCase4(node *node) { func (tree *Tree) insertCase4(node *Node) {
grandparent := node.grandparent() grandparent := node.grandparent()
if node == node.parent.right && node.parent == grandparent.left { if node == node.Parent.Right && node.Parent == grandparent.Left {
tree.rotateLeft(node.parent) tree.rotateLeft(node.Parent)
node = node.left node = node.Left
} else if node == node.parent.left && node.parent == grandparent.right { } else if node == node.Parent.Left && node.Parent == grandparent.Right {
tree.rotateRight(node.parent) tree.rotateRight(node.Parent)
node = node.right node = node.Right
} }
tree.insertCase5(node) tree.insertCase5(node)
} }
func (tree *Tree) insertCase5(node *node) { func (tree *Tree) insertCase5(node *Node) {
node.parent.color = black node.Parent.color = black
grandparent := node.grandparent() grandparent := node.grandparent()
grandparent.color = red grandparent.color = red
if node == node.parent.left && node.parent == grandparent.left { if node == node.Parent.Left && node.Parent == grandparent.Left {
tree.rotateRight(grandparent) tree.rotateRight(grandparent)
} else if node == node.parent.right && node.parent == grandparent.right { } else if node == node.Parent.Right && node.Parent == grandparent.Right {
tree.rotateLeft(grandparent) tree.rotateLeft(grandparent)
} }
} }
func (node *node) maximumNode() *node { func (node *Node) maximumNode() *Node {
if node == nil { if node == nil {
return nil return nil
} }
for node.right != nil { for node.Right != nil {
node = node.right node = node.Right
} }
return node return node
} }
func (tree *Tree) deleteCase1(node *node) { func (tree *Tree) deleteCase1(node *Node) {
if node.parent == nil { if node.Parent == nil {
return return
} else { } else {
tree.deleteCase2(node) tree.deleteCase2(node)
} }
} }
func (tree *Tree) deleteCase2(node *node) { func (tree *Tree) deleteCase2(node *Node) {
sibling := node.sibling() sibling := node.sibling()
if nodeColor(sibling) == red { if nodeColor(sibling) == red {
node.parent.color = red node.Parent.color = red
sibling.color = black sibling.color = black
if node == node.parent.left { if node == node.Parent.Left {
tree.rotateLeft(node.parent) tree.rotateLeft(node.Parent)
} else { } else {
tree.rotateRight(node.parent) tree.rotateRight(node.Parent)
} }
} }
tree.deleteCase3(node) tree.deleteCase3(node)
} }
func (tree *Tree) deleteCase3(node *node) { func (tree *Tree) deleteCase3(node *Node) {
sibling := node.sibling() sibling := node.sibling()
if nodeColor(node.parent) == black && if nodeColor(node.Parent) == black &&
nodeColor(sibling) == black && nodeColor(sibling) == black &&
nodeColor(sibling.left) == black && nodeColor(sibling.Left) == black &&
nodeColor(sibling.right) == black { nodeColor(sibling.Right) == black {
sibling.color = red sibling.color = red
tree.deleteCase1(node.parent) tree.deleteCase1(node.Parent)
} else { } else {
tree.deleteCase4(node) tree.deleteCase4(node)
} }
} }
func (tree *Tree) deleteCase4(node *node) { func (tree *Tree) deleteCase4(node *Node) {
sibling := node.sibling() sibling := node.sibling()
if nodeColor(node.parent) == red && if nodeColor(node.Parent) == red &&
nodeColor(sibling) == black && nodeColor(sibling) == black &&
nodeColor(sibling.left) == black && nodeColor(sibling.Left) == black &&
nodeColor(sibling.right) == black { nodeColor(sibling.Right) == black {
sibling.color = red sibling.color = red
node.parent.color = black node.Parent.color = black
} else { } else {
tree.deleteCase5(node) tree.deleteCase5(node)
} }
} }
func (tree *Tree) deleteCase5(node *node) { func (tree *Tree) deleteCase5(node *Node) {
sibling := node.sibling() sibling := node.sibling()
if node == node.parent.left && if node == node.Parent.Left &&
nodeColor(sibling) == black && nodeColor(sibling) == black &&
nodeColor(sibling.left) == red && nodeColor(sibling.Left) == red &&
nodeColor(sibling.right) == black { nodeColor(sibling.Right) == black {
sibling.color = red sibling.color = red
sibling.left.color = black sibling.Left.color = black
tree.rotateRight(sibling) tree.rotateRight(sibling)
} else if node == node.parent.right && } else if node == node.Parent.Right &&
nodeColor(sibling) == black && nodeColor(sibling) == black &&
nodeColor(sibling.right) == red && nodeColor(sibling.Right) == red &&
nodeColor(sibling.left) == black { nodeColor(sibling.Left) == black {
sibling.color = red sibling.color = red
sibling.right.color = black sibling.Right.color = black
tree.rotateLeft(sibling) tree.rotateLeft(sibling)
} }
tree.deleteCase6(node) tree.deleteCase6(node)
} }
func (tree *Tree) deleteCase6(node *node) { func (tree *Tree) deleteCase6(node *Node) {
sibling := node.sibling() sibling := node.sibling()
sibling.color = nodeColor(node.parent) sibling.color = nodeColor(node.Parent)
node.parent.color = black node.Parent.color = black
if node == node.parent.left && nodeColor(sibling.right) == red { if node == node.Parent.Left && nodeColor(sibling.Right) == red {
sibling.right.color = black sibling.Right.color = black
tree.rotateLeft(node.parent) tree.rotateLeft(node.Parent)
} else if nodeColor(sibling.left) == red { } else if nodeColor(sibling.Left) == red {
sibling.left.color = black sibling.Left.color = black
tree.rotateRight(node.parent) tree.rotateRight(node.Parent)
} }
} }
func nodeColor(node *node) color { func nodeColor(node *Node) color {
if node == nil { if node == nil {
return black return black
} }