diff --git a/avlkey/avlkey.go b/avlkey/avlkey.go index 418e223..5a3de50 100644 --- a/avlkey/avlkey.go +++ b/avlkey/avlkey.go @@ -219,9 +219,9 @@ func (avl *AVL) GetNode(key interface{}) (*Node, bool) { func (avl *AVL) Put(key, value interface{}) { avl.size++ - node := &Node{key: key, value: value} + if avl.size == 1 { - avl.root = node + avl.root = &Node{key: key, value: value} return } @@ -232,6 +232,7 @@ func (avl *AVL) Put(key, value interface{}) { for { if cur == nil { + node := &Node{key: key, value: value} parent.children[child] = node node.parent = parent node.child = child @@ -243,14 +244,16 @@ func (avl *AVL) Put(key, value interface{}) { } parent = cur - c := avl.comparator(node.key, cur.key) - if c > -1 { // right - child = 1 - cur = cur.children[child] - } else { - child = 0 - cur = cur.children[child] + c := avl.comparator(key, cur.key) + child = (c + 2) / 2 + if c == 0 { + // node := &Node{key: key, value: value} + cur.key = key + cur.value = value + return } + + cur = cur.children[child] } } @@ -320,7 +323,7 @@ func (avl *AVL) lrrotate(cur *Node) { if mov.children[l] != nil { movparent.children[r] = mov.children[l] movparent.children[r].parent = movparent - movparent.children[r].child = 1 + movparent.children[r].child = l } else { movparent.children[r] = nil } @@ -361,7 +364,7 @@ func (avl *AVL) rlrotate(cur *Node) { if mov.children[l] != nil { movparent.children[r] = mov.children[l] movparent.children[r].parent = movparent - movparent.children[r].child = 1 + movparent.children[r].child = r } else { movparent.children[r] = nil } @@ -398,11 +401,16 @@ func (avl *AVL) rrotate(cur *Node) { mov.key, mov.value, cur.key, cur.value = cur.key, cur.value, mov.key, mov.value //交换值达到, 相对位移 - if mov.children[l] != nil { - mov.children[l].parent, cur.children[l] = cur, mov.children[l] - } else { - cur.children[l] = nil - } + // if mov.children[l] != nil { + // mov.children[l].parent = cur + // cur.children[l] = mov.children[l] + // } else { + // cur.children[l] = nil + // } + + // mov.children[l] 不可能为nil + mov.children[l].parent = cur + cur.children[l] = mov.children[l] if mov.children[r] != nil { mov.children[l] = mov.children[r] @@ -423,23 +431,27 @@ func (avl *AVL) rrotate(cur *Node) { mov.height = getMaxChildrenHeight(mov) + 1 cur.height = getMaxChildrenHeight(cur) + 1 - } func (avl *AVL) lrotate(cur *Node) { const l = 1 const r = 0 - // 1 right 0 left + mov := cur.children[l] mov.key, mov.value, cur.key, cur.value = cur.key, cur.value, mov.key, mov.value //交换值达到, 相对位移 - if mov.children[l] != nil { - mov.children[l].parent, cur.children[l] = cur, mov.children[l] - } else { - cur.children[l] = nil - } + // if mov.children[l] != nil { + // mov.children[l].parent = cur + // cur.children[l] = mov.children[l] + // } else { + // cur.children[l] = nil + // } + + // 不可能为 + mov.children[l].parent = cur + cur.children[l] = mov.children[l] if mov.children[r] != nil { mov.children[l] = mov.children[r] @@ -553,15 +565,12 @@ func (avl *AVL) fixPutHeight(cur *Node) { } else { avl.lrotate(cur) } - } else if diff > 1 { - l := cur.children[0] if getHeight(l.children[1]) > getHeight(l.children[0]) { avl.rlrotate(cur) } else { avl.rrotate(cur) - } } else { diff --git a/avlkey/avlkey_test.go b/avlkey/avlkey_test.go index 04e0081..2142d02 100644 --- a/avlkey/avlkey_test.go +++ b/avlkey/avlkey_test.go @@ -128,7 +128,7 @@ func TestPutStable(t *testing.T) { func TestPutComparatorRandom(t *testing.T) { - for n := 0; n < 1000000; n++ { + for n := 0; n < 400000; n++ { avl := New(utils.IntComparator) godsavl := avltree.NewWithIntComparator() @@ -165,7 +165,9 @@ func TestPutComparatorRandom(t *testing.T) { // avl := New(utils.IntComparator) // for i := 0; i < 15; i++ { // avl.Put(randomdata.Number(0, 1000)) -// } +// } const l = 1 +// const r = 0 + // t.Error(avl.String()) // t.Error(avl.Get(500)) // }