diff --git a/avltree/avltree.go b/avltree/avltree.go index b822946..c2e1ff7 100644 --- a/avltree/avltree.go +++ b/avltree/avltree.go @@ -36,35 +36,14 @@ func New(comparator utils.Comparator) *AVLTree { return &AVLTree{comparator: comparator} } -func (avl *AVLTree) replace(old *Node, newN *Node) { - - if old.parent == nil { - setChild(newN, 0, old.children[0]) - setChild(newN, 1, old.children[1]) - - newN.parent = nil - newN.child = -1 - newN.height = old.height - - avl.root = newN - } else { - - setChild(newN, 0, old.children[0]) - setChild(newN, 1, old.children[1]) - - newN.parent = old.parent - newN.child = old.child - newN.height = old.height - old.parent.children[old.child] = newN +func (avl *AVLTree) String() string { + if avl.size == 0 { + return "" } -} + str := "AVLTree" + "\n" + output(avl.root, "", true, &str) -func setChild(p *Node, child int, node *Node) { - p.children[child] = node - if node != nil { - node.child = child - node.parent = p - } + return str } func (avl *AVLTree) Remove(v interface{}) *Node { @@ -195,14 +174,41 @@ func (avl *AVLTree) Put(v interface{}) { } -func (avl *AVLTree) String() string { - if avl.size == 0 { - return "" - } - str := "AVLTree" + "\n" - output(avl.root, "", true, &str) +func (avl *AVLTree) replace(old *Node, newN *Node) { - return str + if old.parent == nil { + setChild(newN, 0, old.children[0]) + setChild(newN, 1, old.children[1]) + + newN.parent = nil + newN.child = -1 + newN.height = old.height + + avl.root = newN + } else { + + setChild(newN, 0, old.children[0]) + setChild(newN, 1, old.children[1]) + + newN.parent = old.parent + newN.child = old.child + newN.height = old.height + old.parent.children[old.child] = newN + } +} + +func setChild(p *Node, child int, node *Node) { + p.children[child] = node + if node != nil { + node.child = child + node.parent = p + } +} + +func setChildNotNil(p *Node, child int, node *Node) { + p.children[child] = node + node.child = child + node.parent = p } func (avl *AVLTree) debugString() string { @@ -265,7 +271,7 @@ func (avl *AVLTree) lrrotate(cur *Node) *Node { avl.root = rl rl.parent = nil } else { - setChild(cur.parent, cur.child, rl) + setChildNotNil(cur.parent, cur.child, rl) } rll := rl.children[0] @@ -274,8 +280,8 @@ func (avl *AVLTree) lrrotate(cur *Node) *Node { setChild(cur, 1, rll) setChild(r, 0, rlr) - setChild(rl, 0, cur) - setChild(rl, 1, r) + setChildNotNil(rl, 0, cur) + setChildNotNil(rl, 1, r) cur.height = getMaxChildrenHeight(cur) + 1 r.height = getMaxChildrenHeight(r) + 1 @@ -292,7 +298,7 @@ func (avl *AVLTree) rlrotate(cur *Node) *Node { avl.root = lr lr.parent = nil } else { - setChild(cur.parent, cur.child, lr) + setChildNotNil(cur.parent, cur.child, lr) } lrr := lr.children[1] @@ -300,8 +306,8 @@ func (avl *AVLTree) rlrotate(cur *Node) *Node { setChild(cur, 0, lrr) setChild(l, 1, lrl) - setChild(lr, 1, cur) - setChild(lr, 0, l) + setChildNotNil(lr, 1, cur) + setChildNotNil(lr, 0, l) cur.height = getMaxChildrenHeight(cur) + 1 l.height = getMaxChildrenHeight(l) + 1 @@ -317,16 +323,17 @@ func (avl *AVLTree) rrotate(cur *Node) *Node { setChild(cur, 0, l.children[1]) l.parent = cur.parent - if cur.parent != nil { - cur.parent.children[cur.child] = l - } else { + if cur.parent == nil { avl.root = l + } else { + cur.parent.children[cur.child] = l } l.child = cur.child - l.children[1] = cur - cur.child = 1 - cur.parent = l + setChildNotNil(l, 1, cur) + // l.children[1] = cur + // cur.child = 1 + // cur.parent = l cur.height = getMaxChildrenHeight(cur) + 1 l.height = getMaxChildrenHeight(l) + 1 @@ -343,18 +350,18 @@ func (avl *AVLTree) lrotate(cur *Node) *Node { // 设置 需要旋转的节点到当前节点的 链条 r.parent = cur.parent - if cur.parent != nil { - cur.parent.children[cur.child] = r - } else { + if cur.parent == nil { avl.root = r + } else { + cur.parent.children[cur.child] = r } r.child = cur.child // 当前节点旋转到 左边的 链条 - - r.children[0] = cur - cur.child = 0 - cur.parent = r + setChildNotNil(r, 0, cur) + // r.children[0] = cur + // cur.child = 0 + // cur.parent = r // 修复改动过的节点高度 先从低开始到高 cur.height = getMaxChildrenHeight(cur) + 1 @@ -368,10 +375,10 @@ func getMaxAndChildrenHeight(cur *Node) (h1, h2, maxh int) { h2 = getHeight(cur.children[1]) if h1 > h2 { maxh = h1 - return + } else { + maxh = h2 } - maxh = h2 return }