From 55ac6c00beaa3e0227e34aecad63412f113bc876 Mon Sep 17 00:00:00 2001 From: huangsimin Date: Thu, 14 Mar 2019 14:32:10 +0800 Subject: [PATCH] TODO: finish Remove --- avlkey/avlkey.go | 258 ++++++++++++++++++++---------------------- avlkey/avlkey_test.go | 166 +++++++++++++++++---------- avlkey/iterator.go | 19 +++- 3 files changed, 241 insertions(+), 202 deletions(-) diff --git a/avlkey/avlkey.go b/avlkey/avlkey.go index 5a3de50..c5a31bc 100644 --- a/avlkey/avlkey.go +++ b/avlkey/avlkey.go @@ -10,7 +10,6 @@ type Node struct { children [2]*Node parent *Node height int - child int key, value interface{} } @@ -30,20 +29,20 @@ func (n *Node) String() string { if n.parent != nil { p = spew.Sprint(n.parent.value) } - return spew.Sprint(n.value) + "(" + p + "-" + spew.Sprint(n.child) + "|" + spew.Sprint(n.height) + ")" + return spew.Sprint(n.value) + "(" + p + "|" + spew.Sprint(n.height) + ")" } -type AVL struct { +type Tree struct { root *Node size int comparator utils.Comparator } -func New(comparator utils.Comparator) *AVL { - return &AVL{comparator: comparator} +func New(comparator utils.Comparator) *Tree { + return &Tree{comparator: comparator} } -func (avl *AVL) String() string { +func (avl *Tree) String() string { if avl.size == 0 { return "" } @@ -53,81 +52,80 @@ func (avl *AVL) String() string { return str } -func (avl *AVL) Iterator() *Iterator { +func (avl *Tree) Iterator() *Iterator { return initIterator(avl) } -func (avl *AVL) Size() int { +func (avl *Tree) Size() int { return avl.size } -// func (avl *AVL) Remove(key interface{}) *Node { +func (avl *Tree) Remove(key interface{}) *Node { -// if n, ok := avl.GetNode(key); ok { + if n, ok := avl.GetNode(key); ok { -// avl.size-- -// if avl.size == 0 { -// avl.root = nil -// return n -// } + avl.size-- + if avl.size == 0 { + avl.root = nil + return n + } -// left := getHeight(n.children[0]) -// right := getHeight(n.children[1]) + left := getHeight(n.children[0]) + right := getHeight(n.children[1]) -// if left == -1 && right == -1 { -// p := n.parent -// p.children[n.child] = nil -// avl.fixRemoveHeight(p) -// return n -// } + if left == -1 && right == -1 { + p := n.parent + p.children[n.child] = nil + avl.fixRemoveHeight(p) + return n + } -// var cur *Node -// if left > right { -// cur = n.children[0] -// for cur.children[1] != nil { -// cur = cur.children[1] -// } + var cur *Node + if left > right { + cur = n.children[0] + for cur.children[1] != nil { + cur = cur.children[1] + } -// cleft := cur.children[0] -// cur.parent.children[cur.child] = cleft -// if cleft != nil { -// cleft.child = cur.child -// cleft.parent = cur.parent -// } + cleft := cur.children[0] + cur.parent.children[cur.child] = cleft + if cleft != nil { + cleft.parent = cur.parent + } -// } else { -// cur = n.children[1] -// for cur.children[0] != nil { -// cur = cur.children[0] -// } + } else { + cur = n.children[1] + for cur.children[0] != nil { + cur = cur.children[0] + } -// cright := cur.children[1] -// cur.parent.children[cur.child] = cright -// if cright != nil { -// cright.child = cur.child -// cright.parent = cur.parent -// } -// } + cright := cur.children[1] + cur.parent.children[cur.child] = cright -// cparent := cur.parent -// // avl.replace(n, cur) 修改为interface -// temp := n.value -// n.value = cur.value -// cur.value = temp -// // 考虑到刚好替换的节点是 被替换节点的孩子节点的时候, 从自身修复高度 -// if cparent == n { -// avl.fixRemoveHeight(n) -// } else { -// avl.fixRemoveHeight(cparent) -// } + if cright != nil { + cright.parent = cur.parent + } + } -// return cur -// } + cparent := cur.parent + // avl.replace(n, cur) 修改为interface + temp := n.value + n.value = cur.value + cur.value = temp + // 考虑到刚好替换的节点是 被替换节点的孩子节点的时候, 从自身修复高度 + if cparent == n { + avl.fixRemoveHeight(n) + } else { + avl.fixRemoveHeight(cparent) + } -// return nil -// } + return cur + } -func (avl *AVL) Get(v interface{}) (interface{}, bool) { + return nil +} + +func (avl *Tree) Get(v interface{}) (interface{}, bool) { n, ok := avl.GetNode(v) if ok { return n.value, true @@ -135,7 +133,7 @@ func (avl *AVL) Get(v interface{}) (interface{}, bool) { return n, false } -func (avl *AVL) GetAround(key interface{}) (result [3]interface{}) { +func (avl *Tree) GetAround(key interface{}) (result [3]interface{}) { an := avl.GetAroundNode(key) for i, n := range an { if n.value != nil { @@ -145,7 +143,7 @@ func (avl *AVL) GetAround(key interface{}) (result [3]interface{}) { return } -func (avl *AVL) GetAroundNode(key interface{}) (result [3]*Node) { +func (avl *Tree) GetAroundNode(key interface{}) (result [3]*Node) { n := avl.root for { @@ -198,7 +196,7 @@ func (avl *AVL) GetAroundNode(key interface{}) (result [3]*Node) { } } -func (avl *AVL) GetNode(key interface{}) (*Node, bool) { +func (avl *Tree) GetNode(key interface{}) (*Node, bool) { n := avl.root for n != nil { @@ -217,11 +215,11 @@ func (avl *AVL) GetNode(key interface{}) (*Node, bool) { return nil, false } -func (avl *AVL) Put(key, value interface{}) { +func (avl *Tree) Put(key, value interface{}) { avl.size++ - + node := &Node{key: key, value: value} if avl.size == 1 { - avl.root = &Node{key: key, value: value} + avl.root = node return } @@ -232,11 +230,8 @@ func (avl *AVL) Put(key, value interface{}) { for { if cur == nil { - node := &Node{key: key, value: value} parent.children[child] = node node.parent = parent - node.child = child - if node.parent.height == 0 { avl.fixPutHeight(node.parent) } @@ -246,28 +241,21 @@ func (avl *AVL) Put(key, value interface{}) { parent = cur c := avl.comparator(key, cur.key) child = (c + 2) / 2 - if c == 0 { - // node := &Node{key: key, value: value} - cur.key = key - cur.value = value - return - } - cur = cur.children[child] } } -func (avl *AVL) debugString() string { +func (avl *Tree) debugString() string { if avl.size == 0 { return "" } - str := "AVL" + "\n" + str := "AVLTree\n" outputfordebug(avl.root, "", true, &str) return str } -func (avl *AVL) TraversalBreadth() (result []interface{}) { +func (avl *Tree) TraversalBreadth() (result []interface{}) { var traverasl func(cur *Node) traverasl = func(cur *Node) { if cur == nil { @@ -281,7 +269,7 @@ func (avl *AVL) TraversalBreadth() (result []interface{}) { return } -func (avl *AVL) TraversalDepth(leftright int) (result []interface{}) { +func (avl *Tree) TraversalDepth(leftright int) (result []interface{}) { if leftright < 0 { var traverasl func(cur *Node) @@ -310,7 +298,7 @@ func (avl *AVL) TraversalDepth(leftright int) (result []interface{}) { return } -func (avl *AVL) lrrotate(cur *Node) { +func (avl *Tree) lrrotate(cur *Node) { const l = 1 const r = 0 @@ -323,14 +311,14 @@ func (avl *AVL) lrrotate(cur *Node) { if mov.children[l] != nil { movparent.children[r] = mov.children[l] movparent.children[r].parent = movparent - movparent.children[r].child = l + //movparent.children[r].child = l } else { movparent.children[r] = nil } if mov.children[r] != nil { mov.children[l] = mov.children[r] - mov.children[l].child = l + //mov.children[l].child = l } else { mov.children[l] = nil } @@ -343,7 +331,6 @@ func (avl *AVL) lrrotate(cur *Node) { } cur.children[r] = mov - mov.child = r mov.parent = cur mov.height = getMaxChildrenHeight(mov) + 1 @@ -351,7 +338,7 @@ func (avl *AVL) lrrotate(cur *Node) { cur.height = getMaxChildrenHeight(cur) + 1 } -func (avl *AVL) rlrotate(cur *Node) { +func (avl *Tree) rlrotate(cur *Node) { const l = 0 const r = 1 @@ -364,14 +351,12 @@ func (avl *AVL) rlrotate(cur *Node) { if mov.children[l] != nil { movparent.children[r] = mov.children[l] movparent.children[r].parent = movparent - movparent.children[r].child = r } else { movparent.children[r] = nil } if mov.children[r] != nil { mov.children[l] = mov.children[r] - mov.children[l].child = l } else { mov.children[l] = nil } @@ -384,7 +369,6 @@ func (avl *AVL) rlrotate(cur *Node) { } cur.children[r] = mov - mov.child = r mov.parent = cur mov.height = getMaxChildrenHeight(mov) + 1 @@ -392,7 +376,7 @@ func (avl *AVL) rlrotate(cur *Node) { cur.height = getMaxChildrenHeight(cur) + 1 } -func (avl *AVL) rrotate(cur *Node) { +func (avl *Tree) rrotate(cur *Node) { const l = 0 const r = 1 @@ -408,13 +392,13 @@ func (avl *AVL) rrotate(cur *Node) { // cur.children[l] = nil // } - // mov.children[l] 不可能为nil + // 不可能为nil mov.children[l].parent = cur cur.children[l] = mov.children[l] + // 解决mov节点孩子转移的问题 if mov.children[r] != nil { mov.children[l] = mov.children[r] - mov.children[l].child = l } else { mov.children[l] = nil } @@ -426,14 +410,14 @@ func (avl *AVL) rrotate(cur *Node) { mov.children[r] = nil } + // 连接转移后的节点 由于mov只是与cur交换值,parent不变 cur.children[r] = mov - mov.child = r mov.height = getMaxChildrenHeight(mov) + 1 cur.height = getMaxChildrenHeight(cur) + 1 } -func (avl *AVL) lrotate(cur *Node) { +func (avl *Tree) lrotate(cur *Node) { const l = 1 const r = 0 @@ -449,13 +433,12 @@ func (avl *AVL) lrotate(cur *Node) { // cur.children[l] = nil // } - // 不可能为 + // 不可能为nil mov.children[l].parent = cur cur.children[l] = mov.children[l] if mov.children[r] != nil { mov.children[l] = mov.children[r] - mov.children[l].child = l } else { mov.children[l] = nil } @@ -468,7 +451,6 @@ func (avl *AVL) lrotate(cur *Node) { } cur.children[r] = mov - mov.child = r mov.height = getMaxChildrenHeight(mov) + 1 cur.height = getMaxChildrenHeight(cur) + 1 @@ -502,54 +484,54 @@ func getHeight(cur *Node) int { return cur.height } -// func (avl *AVL) fixRemoveHeight(cur *Node) { +func (avl *Tree) fixRemoveHeight(cur *Node) { -// for { + for { -// lefth, rigthh, lrmax := getMaxAndChildrenHeight(cur) + lefth, rigthh, lrmax := getMaxAndChildrenHeight(cur) -// // 判断当前节点是否有变化, 如果没变化的时候, 不需要往上修复 -// isBreak := false -// if cur.height == lrmax+1 { -// isBreak = true -// } else { -// cur.height = lrmax + 1 -// } + // 判断当前节点是否有变化, 如果没变化的时候, 不需要往上修复 + isBreak := false + if cur.height == lrmax+1 { + isBreak = true + } else { + cur.height = lrmax + 1 + } -// // 计算高度的差值 绝对值大于2的时候需要旋转 -// diff := lefth - rigthh -// if diff < -1 { -// r := cur.children[1] // 根据左旋转的右边节点的子节点 左右高度选择旋转的方式 -// if getHeight(r.children[0]) > getHeight(r.children[1]) { -// cur = avl.lrrotate(cur) -// } else { -// cur = avl.lrotate(cur) -// } -// } else if diff > 1 { -// l := cur.children[0] -// if getHeight(l.children[1]) > getHeight(l.children[0]) { -// cur = avl.rlrotate(cur) -// } else { -// cur = avl.rrotate(cur) -// } -// } else { + // 计算高度的差值 绝对值大于2的时候需要旋转 + diff := lefth - rigthh + if diff < -1 { + r := cur.children[1] // 根据左旋转的右边节点的子节点 左右高度选择旋转的方式 + if getHeight(r.children[0]) > getHeight(r.children[1]) { + avl.lrrotate(cur) + } else { + avl.lrotate(cur) + } + } else if diff > 1 { + l := cur.children[0] + if getHeight(l.children[1]) > getHeight(l.children[0]) { + avl.rlrotate(cur) + } else { + avl.rrotate(cur) + } + } else { -// if isBreak { -// return -// } + if isBreak { + return + } -// } + } -// if cur.parent == nil { -// return -// } + if cur.parent == nil { + return + } -// cur = cur.parent -// } + cur = cur.parent + } -// } +} -func (avl *AVL) fixPutHeight(cur *Node) { +func (avl *Tree) fixPutHeight(cur *Node) { for { @@ -646,7 +628,7 @@ func outputfordebug(node *Node, prefix string, isTail bool, str *string) { } else { parentv = spew.Sprint(node.parent.value) } - suffix += parentv + "-" + spew.Sprint(node.child) + "|" + spew.Sprint(node.height) + ")" + suffix += parentv + "|" + spew.Sprint(node.height) + ")" *str += spew.Sprint(node.value) + suffix + "\n" if node.children[0] != nil { diff --git a/avlkey/avlkey_test.go b/avlkey/avlkey_test.go index 2142d02..d26f426 100644 --- a/avlkey/avlkey_test.go +++ b/avlkey/avlkey_test.go @@ -1,11 +1,17 @@ package avl import ( + "bytes" + "encoding/gob" + "io/ioutil" + "log" + "os" "testing" "github.com/Pallinder/go-randomdata" "github.com/davecgh/go-spew/spew" "github.com/emirpasic/gods/trees/avltree" + "github.com/emirpasic/gods/trees/redblacktree" "github.com/emirpasic/gods/utils" ) @@ -35,45 +41,45 @@ import ( // } // } -// func TestIterator(t *testing.T) { -// avl := New(utils.IntComparator) -// for _, v := range []int{1, 2, 7, 4, 5, 6, 7, 14, 15, 20, 30, 21, 3} { -// // t.Error(v) -// avl.Put(v) +func TestIterator(t *testing.T) { + avl := New(utils.IntComparator) + for _, v := range []int{1, 2, 7, 4, 5, 6, 7, 14, 15, 20, 30, 21, 3} { + // t.Error(v) + avl.Put(v, v) -// } -// t.Error(avl.TraversalDepth(1)) -// t.Error(avl.debugString()) -// iter := avl.Iterator() + } + t.Error(avl.TraversalDepth(1)) + t.Error(avl.debugString()) + iter := avl.Iterator() -// for iter.Prev() { -// t.Error(iter.Value()) -// } -// t.Error("prev == false", iter.Value(), iter.Prev(), iter.Value()) + for iter.Prev() { + t.Error(iter.Value()) + } + t.Error("prev == false", iter.Value(), iter.Prev(), iter.Value()) -// for iter.Next() { -// t.Error(iter.Value()) -// } -// t.Error("next == false", iter.Value(), iter.Next(), iter.Value()) + for iter.Next() { + t.Error(iter.Value()) + } + t.Error("next == false", iter.Value(), iter.Next(), iter.Value()) -// for iter.Prev() { -// t.Error(iter.Value()) -// } -// t.Error("prev == false", iter.Value()) + for iter.Prev() { + t.Error(iter.Value()) + } + t.Error("prev == false", iter.Value()) -// for i := 0; iter.Next(); i++ { -// t.Error(iter.Value()) -// if i >= 7 { -// break -// } -// } -// t.Error("next == false", iter.Value()) + for i := 0; iter.Next(); i++ { + t.Error(iter.Value()) + if i >= 7 { + break + } + } + t.Error("next == false", iter.Value()) -// for iter.Prev() { -// t.Error(iter.Value()) -// } -// t.Error("prev == false", iter.Value()) -// } + for iter.Prev() { + t.Error(iter.Value()) + } + t.Error("prev == false", iter.Value()) +} // func TestGetAround(t *testing.T) { // avl := New(utils.IntComparator) @@ -128,7 +134,7 @@ func TestPutStable(t *testing.T) { func TestPutComparatorRandom(t *testing.T) { - for n := 0; n < 400000; n++ { + for n := 0; n < 300000; n++ { avl := New(utils.IntComparator) godsavl := avltree.NewWithIntComparator() @@ -256,7 +262,7 @@ func TestPutComparatorRandom(t *testing.T) { // } // } -const CompartorSize = 300000 +const CompartorSize = 500000 const NumberMax = 60000000 // func BenchmarkIterator(b *testing.B) { @@ -454,52 +460,96 @@ const NumberMax = 60000000 // // } // } +func TestSave(t *testing.T) { + + f, err := os.OpenFile("./l.log", os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0666) + if err != nil { + log.Println(err) + } + + //fmt.Println(userBytes) + + var l []int + m := make(map[int]int) + for i := 0; len(l) < CompartorSize; i++ { + v := randomdata.Number(0, NumberMax) + if _, ok := m[v]; !ok { + m[v] = v + l = append(l, v) + } + } + + var result bytes.Buffer + encoder := gob.NewEncoder(&result) + encoder.Encode(l) + lbytes := result.Bytes() + f.Write(lbytes) + +} + func BenchmarkPut(b *testing.B) { avl := New(utils.IntComparator) - for i := 0; i < 100000; i++ { - avl.Put(randomdata.Number(0, NumberMax), i) + data, err := ioutil.ReadFile("./l.log") + if err != nil { + b.Error(err) } + var l []int + + decoder := gob.NewDecoder(bytes.NewReader(data)) + decoder.Decode(&l) + b.ResetTimer() b.StartTimer() - b.N = CompartorSize - for i := 0; i < b.N; i++ { - avl.Put(randomdata.Number(0, NumberMax), i) + b.N = len(l) + for _, v := range l { + avl.Put(v, v) } } -// func BenchmarkGodsRBPut(b *testing.B) { -// avl := redblacktree.NewWithIntComparator() +func BenchmarkGodsRBPut(b *testing.B) { + tree := redblacktree.NewWithIntComparator() -// for i := 0; i < 100000; i++ { -// avl.Put(randomdata.Number(0, NumberMax), i) -// } + data, err := ioutil.ReadFile("./l.log") + if err != nil { + b.Error(err) + } -// b.ResetTimer() -// b.StartTimer() + var l []int -// b.N = CompartorSize -// for i := 0; i < b.N; i++ { -// avl.Put(randomdata.Number(0, NumberMax), i) -// } + decoder := gob.NewDecoder(bytes.NewReader(data)) + decoder.Decode(&l) -// } + b.ResetTimer() + b.StartTimer() + + b.N = len(l) + for _, v := range l { + tree.Put(v, v) + } +} func BenchmarkGodsPut(b *testing.B) { - avl := avltree.NewWithIntComparator() + tree := avltree.NewWithIntComparator() - for i := 0; i < 100000; i++ { - avl.Put(randomdata.Number(0, NumberMax), i) + data, err := ioutil.ReadFile("./l.log") + if err != nil { + b.Error(err) } + var l []int + + decoder := gob.NewDecoder(bytes.NewReader(data)) + decoder.Decode(&l) + b.ResetTimer() b.StartTimer() - b.N = CompartorSize - for i := 0; i < b.N; i++ { - avl.Put(randomdata.Number(0, NumberMax), i) + b.N = len(l) + for _, v := range l { + tree.Put(v, v) } } diff --git a/avlkey/iterator.go b/avlkey/iterator.go index 65c9a31..3ca4567 100644 --- a/avlkey/iterator.go +++ b/avlkey/iterator.go @@ -5,7 +5,7 @@ import ( ) type Iterator struct { - op *AVL + op *Tree dir int up *Node @@ -14,7 +14,7 @@ type Iterator struct { // curnext *Node } -func initIterator(avltree *AVL) *Iterator { +func initIterator(avltree *Tree) *Iterator { iter := &Iterator{op: avltree, tstack: lastack.New()} iter.up = avltree.root return iter @@ -100,9 +100,16 @@ func (iter *Iterator) Next() (result bool) { return false } +func getRelationship(cur *Node) int { + if cur.parent.children[1] == cur { + return 1 + } + return 0 +} + func (iter *Iterator) getNextUp(cur *Node) *Node { - for cur != nil { - if cur.child == 1 { // next 在 降序 小值. 如果child在右边, parent 比 child 小, parent才有效, 符合降序 + for cur.parent != nil { + if getRelationship(cur) == 1 { // next 在 降序 小值. 如果child在右边, parent 比 child 小, parent才有效, 符合降序 return cur.parent } cur = cur.parent @@ -123,8 +130,8 @@ func (iter *Iterator) curPushNextStack(cur *Node) { } func (iter *Iterator) getPrevUp(cur *Node) *Node { - for cur != nil { - if cur.child == 0 { // Prev 在 降序 大值. 如果child在左边, parent 比 child 大, parent才有效 , 符合降序 + for cur.parent != nil { + if getRelationship(cur) == 0 { // Prev 在 降序 大值. 如果child在左边, parent 比 child 大, parent才有效 , 符合降序 return cur.parent } cur = cur.parent