From b26bfdde9bfbe63f929dfa43f11f97b3ae92319e Mon Sep 17 00:00:00 2001
From: huangsimin <huangsimin@youmi.net>
Date: Mon, 18 Mar 2019 19:28:33 +0800
Subject: [PATCH] fixRemoveHeight

---
 avl/avl.go           | 12 ++++-----
 avlindex/avlindex.go | 58 ++++++++++++++++++++++++++++++++++++++++++--
 2 files changed, 61 insertions(+), 9 deletions(-)

diff --git a/avl/avl.go b/avl/avl.go
index 8c7db66..d9cdcb0 100644
--- a/avl/avl.go
+++ b/avl/avl.go
@@ -29,7 +29,6 @@ type Tree struct {
 	root       *Node
 	size       int
 	comparator utils.Comparator
-	count      int
 }
 
 func New(comparator utils.Comparator) *Tree {
@@ -559,6 +558,11 @@ func getHeight(cur *Node) int {
 	return cur.height
 }
 
+func abs(n int) int {
+	y := n >> 31
+	return (n ^ y) - y
+}
+
 func (avl *Tree) fixRemoveHeight(cur *Node) {
 
 	for {
@@ -576,7 +580,6 @@ func (avl *Tree) fixRemoveHeight(cur *Node) {
 		// 计算高度的差值 绝对值大于2的时候需要旋转
 		diff := lefth - rigthh
 		if diff < -1 {
-			avl.count++
 			r := cur.children[1] // 根据左旋转的右边节点的子节点 左右高度选择旋转的方式
 			if getHeight(r.children[0]) > getHeight(r.children[1]) {
 				avl.lrrotate(cur)
@@ -584,7 +587,6 @@ func (avl *Tree) fixRemoveHeight(cur *Node) {
 				avl.lrotate(cur)
 			}
 		} else if diff > 1 {
-			avl.count++
 			l := cur.children[0]
 			if getHeight(l.children[1]) > getHeight(l.children[0]) {
 				avl.rlrotate(cur)
@@ -592,11 +594,9 @@ func (avl *Tree) fixRemoveHeight(cur *Node) {
 				avl.rrotate(cur)
 			}
 		} else {
-
 			if isBreak {
 				return
 			}
-
 		}
 
 		if cur.parent == nil {
@@ -620,7 +620,6 @@ func (avl *Tree) fixPutHeight(cur *Node) {
 		// 计算高度的差值 绝对值大于2的时候需要旋转
 		diff := lefth - rigthh
 		if diff < -1 {
-			avl.count++
 			r := cur.children[1] // 根据左旋转的右边节点的子节点 左右高度选择旋转的方式
 			if getHeight(r.children[0]) > getHeight(r.children[1]) {
 				avl.lrrotate(cur)
@@ -628,7 +627,6 @@ func (avl *Tree) fixPutHeight(cur *Node) {
 				avl.lrotate(cur)
 			}
 		} else if diff > 1 {
-			avl.count++
 			l := cur.children[0]
 			if getHeight(l.children[1]) > getHeight(l.children[0]) {
 				avl.rlrotate(cur)
diff --git a/avlindex/avlindex.go b/avlindex/avlindex.go
index dc12665..78079b7 100644
--- a/avlindex/avlindex.go
+++ b/avlindex/avlindex.go
@@ -28,7 +28,6 @@ func (n *Node) String() string {
 type Tree struct {
 	root       *Node
 	comparator utils.Comparator
-	count      int
 }
 
 func New(comparator utils.Comparator) *Tree {
@@ -51,6 +50,62 @@ func (avl *Tree) Size() int {
 }
 
 func (avl *Tree) Remove(key interface{}) *Node {
+
+	if n, ok := avl.GetNode(key); ok {
+		if avl.root == n {
+			avl.root = nil
+			return n
+		}
+
+		ls, rs := getChildrenSize(n)
+		if ls == 0 && rs == 0 {
+			p := n.parent
+			p.children[getRelationship(n)] = nil
+			avl.fixRemoveHeight(p)
+			return n
+		}
+
+		var cur *Node
+		if ls > ls {
+			cur = n.children[0]
+			for cur.children[1] != nil {
+				cur = cur.children[1]
+			}
+
+			cleft := cur.children[0]
+			cur.parent.children[getRelationship(cur)] = cleft
+			if cleft != nil {
+				cleft.parent = cur.parent
+			}
+
+		} else {
+			cur = n.children[1]
+			for cur.children[0] != nil {
+				cur = cur.children[0]
+			}
+
+			cright := cur.children[1]
+			cur.parent.children[getRelationship(cur)] = cright
+
+			if cright != nil {
+				cright.parent = cur.parent
+			}
+		}
+
+		cparent := cur.parent
+		// 修改为interface 交换
+		n.value, cur.value = cur.value, n.value
+
+		// 考虑到刚好替换的节点是 被替换节点的孩子节点的时候, 从自身修复高度
+		if cparent == n {
+			avl.fixRemoveHeight(n)
+		} else {
+			avl.fixRemoveHeight(cparent)
+		}
+
+		return cur
+	}
+
 	return nil
 }
 
@@ -434,7 +489,6 @@ func (avl *Tree) fixRemoveHeight(cur *Node) {
 // }
 
 func (avl *Tree) fixPutHeight(cur *Node, lefts, rigths int) {
-	avl.count++
 	if lefts > rigths {
 		l := cur.children[0]
 		llsize, lrsize := getChildrenSize(l)