成功 旋转3

This commit is contained in:
eson 2019-04-07 04:32:43 +08:00
parent eb954ffb47
commit c9c8308ebf
9 changed files with 313 additions and 207 deletions

View File

@ -22,13 +22,7 @@ func (pq *PriorityQueue) Size() int {
} }
func (pq *PriorityQueue) Push(value interface{}) { func (pq *PriorityQueue) Push(value interface{}) {
n := pq.queue.Put(value)
if pq.head == nil {
pq.head = n
return
} else if pq.queue.Compare(n.value, pq.head.value) == 1 {
pq.head = n
}
} }
func (pq *PriorityQueue) Top() (result interface{}, ok bool) { func (pq *PriorityQueue) Top() (result interface{}, ok bool) {
@ -39,17 +33,7 @@ func (pq *PriorityQueue) Top() (result interface{}, ok bool) {
} }
func (pq *PriorityQueue) Pop() (result interface{}, ok bool) { func (pq *PriorityQueue) Pop() (result interface{}, ok bool) {
if pq.head != nil {
prev := getPrev(pq.head, 1)
result = pq.head.value
pq.queue.removeNode(pq.head)
if prev != nil {
pq.head = prev
} else {
pq.head = nil
}
return result, true
}
return nil, false return nil, false
} }

View File

@ -27,10 +27,14 @@ func (n *tNode) String() string {
type vbTree struct { type vbTree struct {
root *tNode root *tNode
Compare compare.Compare Compare compare.Compare
top *tNode
iter *vbtIterator
} }
func newVBT(Compare compare.Compare) *vbTree { func newVBT(Compare compare.Compare) *vbTree {
return &vbTree{Compare: Compare} return &vbTree{Compare: Compare, iter: NewIteratorWithCap(nil, 16)}
} }
func (tree *vbTree) String() string { func (tree *vbTree) String() string {
@ -42,7 +46,7 @@ func (tree *vbTree) String() string {
return str return str
} }
func (tree *vbTree) Iterator() *vbtIterator { func (tree *vbTree) vbtIterator() *vbtIterator {
return initIterator(tree) return initIterator(tree)
} }
@ -263,7 +267,9 @@ func (tree *vbTree) GetRange(k1, k2 interface{}) (result []interface{}) {
result = make([]interface{}, 0, 16) result = make([]interface{}, 0, 16)
iter := NewIterator(min) // iter := NewIterator(min)
tree.iter.SetNode(min)
iter := tree.iter
for iter.Next() { for iter.Next() {
result = append(result, iter.Value()) result = append(result, iter.Value())
if iter.cur == max { if iter.cur == max {
@ -288,7 +294,9 @@ func (tree *vbTree) GetRange(k1, k2 interface{}) (result []interface{}) {
result = make([]interface{}, 0, 16) result = make([]interface{}, 0, 16)
iter := NewIterator(max) // iter := NewIterator(max)
tree.iter.SetNode(max)
iter := tree.iter
for iter.Prev() { for iter.Prev() {
result = append(result, iter.Value()) result = append(result, iter.Value())
if iter.cur == min { if iter.cur == min {
@ -338,7 +346,9 @@ func (tree *vbTree) getArountNode(key interface{}) (result [3]*tNode) {
n = n.children[1] n = n.children[1]
lastc = c lastc = c
case 0: case 0:
iter := NewIterator(n) // iter := NewIterator(n)
tree.iter.SetNode(n)
iter := tree.iter
iter.Prev() iter.Prev()
for iter.Prev() { for iter.Prev() {
if tree.Compare(iter.cur.value, n.value) == 0 { if tree.Compare(iter.cur.value, n.value) == 0 {
@ -359,21 +369,21 @@ func (tree *vbTree) getArountNode(key interface{}) (result [3]*tNode) {
if result[1] != nil { if result[1] != nil {
result[0] = getPrev(result[1], 1) result[0] = tree.iter.GetPrev(result[1], 1)
result[2] = getNext(result[1], 1) result[2] = tree.iter.GetNext(result[1], 1)
} else { } else {
result[0] = last result[0] = last
result[2] = getNext(last, 1) result[2] = tree.iter.GetNext(last, 1)
} }
case -1: case -1:
if result[1] != nil { if result[1] != nil {
result[0] = getPrev(result[1], 1) result[0] = tree.iter.GetPrev(result[1], 1)
result[2] = getNext(result[1], 1) result[2] = tree.iter.GetNext(result[1], 1)
} else { } else {
result[2] = last result[2] = last
result[0] = getPrev(last, 1) result[0] = tree.iter.GetPrev(last, 1)
} }
case 0: case 0:
@ -381,8 +391,8 @@ func (tree *vbTree) getArountNode(key interface{}) (result [3]*tNode) {
if result[1] == nil { if result[1] == nil {
return return
} }
result[0] = getPrev(result[1], 1) result[0] = tree.iter.GetPrev(result[1], 1)
result[2] = getNext(result[1], 1) result[2] = tree.iter.GetNext(result[1], 1)
} }
return return
} }
@ -396,7 +406,10 @@ func (tree *vbTree) GetNode(value interface{}) (*tNode, bool) {
case 1: case 1:
n = n.children[1] n = n.children[1]
case 0: case 0:
iter := NewIterator(n) // iter := NewIterator(n)
tree.iter.SetNode(n)
iter := tree.iter
iter.Prev() iter.Prev()
for iter.Prev() { for iter.Prev() {
if tree.Compare(iter.cur.value, n.value) == 0 { if tree.Compare(iter.cur.value, n.value) == 0 {
@ -413,46 +426,56 @@ func (tree *vbTree) GetNode(value interface{}) (*tNode, bool) {
return nil, false return nil, false
} }
func (tree *vbTree) Put(key interface{}) *tNode { func (tree *vbTree) Put(key interface{}) {
node := &tNode{value: key, size: 1} tNode := &tNode{value: key, size: 1}
if tree.root == nil { if tree.root == nil {
tree.root = node tree.root = tNode
return node return
} }
cur := tree.root for cur := tree.root; ; {
parent := cur.parent
child := -1
for {
if cur == nil {
parent.children[child] = node
node.parent = parent
fixed := parent.parent
fsize := getSize(fixed)
if fsize == 3 {
lefts, rigths := getChildrenSize(fixed)
tree.fix3Size(fixed, lefts, rigths)
}
return node
}
if cur.size > 8 { if cur.size > 8 {
ls, rs := cur.children[0].size, cur.children[1].size
factor := cur.size / 10 // or factor = 1 factor := cur.size / 10 // or factor = 1
if rs >= ls*2+factor || ls >= rs*2+factor { if cur.children[1].size >= cur.children[0].size*2+factor || cur.children[0].size >= cur.children[1].size*2+factor {
tree.fixSize(cur, ls, rs) tree.fixSize(cur)
} }
} }
cur.size++ cur.size++
parent = cur
c := tree.Compare(key, cur.value) c := tree.Compare(key, cur.value)
child = (c + 2) / 2 if c < 0 {
cur = cur.children[child] if cur.children[0] == nil {
cur.children[0] = tNode
tNode.parent = cur
if cur.parent != nil && cur.parent.size == 3 {
if cur.parent.children[0] == nil {
tree.lrrotate3(cur.parent)
} else {
tree.rrotate3(cur.parent)
}
}
return
}
cur = cur.children[0]
} else {
if cur.children[1] == nil {
cur.children[1] = tNode
tNode.parent = cur
if cur.parent != nil && cur.parent.size == 3 {
if cur.parent.children[1] == nil {
tree.rlrotate3(cur.parent)
} else {
tree.lrotate3(cur.parent)
}
}
return
}
cur = cur.children[1]
}
} }
} }
@ -601,27 +624,38 @@ func (tree *vbTree) Traversal(every func(v interface{}) bool, traversalMethod ..
} }
} }
func (tree *vbTree) lrrotate3(cur *tNode) { func (tree *vbTree) lrrotate3(cur *tNode) *tNode {
const l = 1 const l = 1
const r = 0 const r = 0
movparent := cur.children[l] ln := cur.children[l]
mov := movparent.children[r] lrn := ln.children[r]
ln.children[r] = nil
mov.value, cur.value = cur.value, mov.value //交换值达到, 相对位移 if cur.parent == nil {
tree.root = lrn
} else {
if cur.parent.children[1] == cur {
cur.parent.children[1] = lrn
} else {
cur.parent.children[0] = lrn
}
}
lrn.parent = cur.parent
cur.children[r] = mov lrn.children[l] = cur.children[l]
mov.parent = cur lrn.children[l].parent = lrn
cur.children[l] = movparent lrn.children[r] = cur
movparent.children[r] = nil lrn.children[r].parent = lrn
cur.children[r] = mov cur.children[l] = nil
mov.parent = cur
// cur.size = 3 lrn.size = 3
// cur.children[r].size = 1 lrn.children[l].size = 1
cur.children[l].size = 1 lrn.children[r].size = 1
return lrn
} }
func (tree *vbTree) lrrotate(cur *tNode) { func (tree *vbTree) lrrotate(cur *tNode) {
@ -664,27 +698,38 @@ func (tree *vbTree) lrrotate(cur *tNode) {
cur.size = getChildrenSumSize(cur) + 1 cur.size = getChildrenSumSize(cur) + 1
} }
func (tree *vbTree) rlrotate3(cur *tNode) { func (tree *vbTree) rlrotate3(cur *tNode) *tNode {
const l = 0 const l = 0
const r = 1 const r = 1
movparent := cur.children[l] ln := cur.children[l]
mov := movparent.children[r] lrn := ln.children[r]
ln.children[r] = nil
mov.value, cur.value = cur.value, mov.value //交换值达到, 相对位移 if cur.parent == nil {
tree.root = lrn
} else {
if cur.parent.children[1] == cur {
cur.parent.children[1] = lrn
} else {
cur.parent.children[0] = lrn
}
}
lrn.parent = cur.parent
cur.children[r] = mov lrn.children[l] = cur.children[l]
mov.parent = cur lrn.children[l].parent = lrn
cur.children[l] = movparent lrn.children[r] = cur
movparent.children[r] = nil lrn.children[r].parent = lrn
cur.children[r] = mov cur.children[l] = nil
mov.parent = cur
// cur.size = 3 lrn.size = 3
// cur.children[r].size = 1 lrn.children[l].size = 1
cur.children[l].size = 1 lrn.children[r].size = 1
return lrn
} }
func (tree *vbTree) rlrotate(cur *tNode) { func (tree *vbTree) rlrotate(cur *tNode) {
@ -725,23 +770,60 @@ func (tree *vbTree) rlrotate(cur *tNode) {
cur.size = getChildrenSumSize(cur) + 1 cur.size = getChildrenSumSize(cur) + 1
} }
func (tree *vbTree) rrotate3(cur *tNode) { func (tree *vbTree) replaceNotRoot(old, new *tNode) {
new.children[0] = old.children[0]
new.children[1] = old.children[1]
if old.parent.children[1] == old {
old.parent.children[1] = new
} else {
old.parent.children[0] = new
}
}
func (tree *vbTree) replace(old, new *tNode) {
new.children[0] = old.children[0]
new.children[1] = old.children[1]
if old.parent == nil {
tree.root = new
} else {
if old.parent.children[1] == old {
old.parent.children[1] = new
} else {
old.parent.children[0] = new
}
}
}
func (tree *vbTree) rrotate3(cur *tNode) *tNode {
const l = 0 const l = 0
const r = 1 const r = 1
// 1 right 0 left // 1 right 0 left
mov := cur.children[l] mov := cur.children[l]
mov.value, cur.value = cur.value, mov.value //交换值达到, 相对位移 if cur.parent == nil {
tree.root = mov
} else {
if cur.parent.children[1] == cur {
cur.parent.children[1] = mov
} else {
cur.parent.children[0] = mov
}
}
mov.parent = cur.parent
cur.children[r] = mov mov.children[r] = cur
mov.size = 1 mov.children[r].parent = mov
cur.children[l] = mov.children[l] cur.children[l] = nil
cur.children[l].parent = cur
mov.children[l] = nil mov.size = 3
cur.size = 1
mov.size = 1 return mov
} }
func (tree *vbTree) rrotate(cur *tNode) { func (tree *vbTree) rrotate(cur *tNode) {
@ -778,23 +860,32 @@ func (tree *vbTree) rrotate(cur *tNode) {
cur.size = getChildrenSumSize(cur) + 1 cur.size = getChildrenSumSize(cur) + 1
} }
func (tree *vbTree) lrotate3(cur *tNode) { func (tree *vbTree) lrotate3(cur *tNode) *tNode {
const l = 1 const l = 1
const r = 0 const r = 0
// 1 right 0 left // 1 right 0 left
mov := cur.children[l] mov := cur.children[l]
mov.value, cur.value = cur.value, mov.value //交换值达到, 相对位移 if cur.parent == nil {
tree.root = mov
} else {
if cur.parent.children[1] == cur {
cur.parent.children[1] = mov
} else {
cur.parent.children[0] = mov
}
}
mov.parent = cur.parent
cur.children[r] = mov mov.children[r] = cur
mov.size = 1 mov.children[r].parent = mov
cur.children[l] = mov.children[l] cur.children[l] = nil
cur.children[l].parent = cur
mov.children[l] = nil mov.size = 3
cur.size = 1
mov.size = 1 return mov
} }
func (tree *vbTree) lrotate(cur *tNode) { func (tree *vbTree) lrotate(cur *tNode) {
@ -850,48 +941,25 @@ func (tree *vbTree) fixSizeWithRemove(cur *tNode) {
for cur != nil { for cur != nil {
cur.size-- cur.size--
if cur.size > 8 { if cur.size > 8 {
ls, rs := getChildrenSize(cur)
factor := cur.size / 10 // or factor = 1 factor := cur.size / 10 // or factor = 1
if rs >= ls*2+factor || ls >= rs*2+factor { if cur.children[1].size >= cur.children[0].size*2+factor || cur.children[0].size >= cur.children[1].size*2+factor {
tree.fixSize(cur, ls, rs) tree.fixSize(cur)
} }
} }
cur = cur.parent cur = cur.parent
} }
} }
func (tree *vbTree) fix3Size(cur *tNode, lefts, rigths int) { func (tree *vbTree) fixSize(cur *tNode) {
if lefts > rigths { if cur.children[0].size > cur.children[1].size {
l := cur.children[0] llsize, lrsize := getChildrenSize(cur.children[0])
llsize, lrsize := getChildrenSize(l)
if lrsize > llsize {
tree.rlrotate3(cur)
} else {
tree.rrotate3(cur)
}
} else {
r := cur.children[1]
rlsize, rrsize := getChildrenSize(r)
if rlsize > rrsize {
tree.lrrotate3(cur)
} else {
tree.lrotate3(cur)
}
}
}
func (tree *vbTree) fixSize(cur *tNode, lefts, rigths int) {
if lefts > rigths {
l := cur.children[0]
llsize, lrsize := getChildrenSize(l)
if lrsize > llsize { if lrsize > llsize {
tree.rlrotate(cur) tree.rlrotate(cur)
} else { } else {
tree.rrotate(cur) tree.rrotate(cur)
} }
} else { } else {
r := cur.children[1] rlsize, rrsize := getChildrenSize(cur.children[1])
rlsize, rrsize := getChildrenSize(r)
if rlsize > rrsize { if rlsize > rrsize {
tree.lrrotate(cur) tree.lrrotate(cur)
} else { } else {
@ -900,16 +968,16 @@ func (tree *vbTree) fixSize(cur *tNode, lefts, rigths int) {
} }
} }
func output(node *tNode, prefix string, isTail bool, str *string) { func output(tNode *tNode, prefix string, isTail bool, str *string) {
if node.children[1] != nil { if tNode.children[1] != nil {
newPrefix := prefix newPrefix := prefix
if isTail { if isTail {
newPrefix += "│ " newPrefix += "│ "
} else { } else {
newPrefix += " " newPrefix += " "
} }
output(node.children[1], newPrefix, false, str) output(tNode.children[1], newPrefix, false, str)
} }
*str += prefix *str += prefix
if isTail { if isTail {
@ -918,30 +986,30 @@ func output(node *tNode, prefix string, isTail bool, str *string) {
*str += "┌── " *str += "┌── "
} }
*str += spew.Sprint(node.value) + "\n" *str += spew.Sprint(tNode.value) + "\n"
if node.children[0] != nil { if tNode.children[0] != nil {
newPrefix := prefix newPrefix := prefix
if isTail { if isTail {
newPrefix += " " newPrefix += " "
} else { } else {
newPrefix += "│ " newPrefix += "│ "
} }
output(node.children[0], newPrefix, true, str) output(tNode.children[0], newPrefix, true, str)
} }
} }
func outputfordebug(node *tNode, prefix string, isTail bool, str *string) { func outputfordebug(tNode *tNode, prefix string, isTail bool, str *string) {
if node.children[1] != nil { if tNode.children[1] != nil {
newPrefix := prefix newPrefix := prefix
if isTail { if isTail {
newPrefix += "│ " newPrefix += "│ "
} else { } else {
newPrefix += " " newPrefix += " "
} }
outputfordebug(node.children[1], newPrefix, false, str) outputfordebug(tNode.children[1], newPrefix, false, str)
} }
*str += prefix *str += prefix
if isTail { if isTail {
@ -952,22 +1020,22 @@ func outputfordebug(node *tNode, prefix string, isTail bool, str *string) {
suffix := "(" suffix := "("
parentv := "" parentv := ""
if node.parent == nil { if tNode.parent == nil {
parentv = "nil" parentv = "nil"
} else { } else {
parentv = spew.Sprint(node.parent.value) parentv = spew.Sprint(tNode.parent.value)
} }
suffix += parentv + "|" + spew.Sprint(node.size) + ")" suffix += parentv + "|" + spew.Sprint(tNode.size) + ")"
*str += spew.Sprint(node.value) + suffix + "\n" *str += spew.Sprint(tNode.value) + suffix + "\n"
if node.children[0] != nil { if tNode.children[0] != nil {
newPrefix := prefix newPrefix := prefix
if isTail { if isTail {
newPrefix += " " newPrefix += " "
} else { } else {
newPrefix += "│ " newPrefix += "│ "
} }
outputfordebug(node.children[0], newPrefix, true, str) outputfordebug(tNode.children[0], newPrefix, true, str)
} }
} }

View File

@ -24,6 +24,18 @@ func NewIterator(n *tNode) *vbtIterator {
return iter return iter
} }
func NewIteratorWithCap(n *tNode, cap int) *vbtIterator {
iter := &vbtIterator{tstack: lastack.NewWithCap(cap)}
iter.up = n
return iter
}
func (iter *vbtIterator) SetNode(n *tNode) {
iter.up = n
iter.dir = 0
iter.tstack.Clear()
}
func (iter *vbtIterator) Value() interface{} { func (iter *vbtIterator) Value() interface{} {
return iter.cur.value return iter.cur.value
} }
@ -46,9 +58,10 @@ func (iter *vbtIterator) Right() bool {
return false return false
} }
func getNext(cur *tNode, idx int) *tNode { func (iter *vbtIterator) GetNext(cur *tNode, idx int) *tNode {
iter := NewIterator(cur) // iter := NewIterator(cur)
iter.SetNode(cur)
iter.curPushNextStack(iter.up) iter.curPushNextStack(iter.up)
iter.up = iter.getNextUp(iter.up) iter.up = iter.getNextUp(iter.up)
@ -103,9 +116,11 @@ func (iter *vbtIterator) Next() (result bool) {
return false return false
} }
func getPrev(cur *tNode, idx int) *tNode {
iter := NewIterator(cur) func (iter *vbtIterator) GetPrev(cur *tNode, idx int) *tNode {
// iter := NewIterator(cur)
iter.SetNode(cur)
iter.curPushPrevStack(iter.up) iter.curPushPrevStack(iter.up)
iter.up = iter.getPrevUp(iter.up) iter.up = iter.getPrevUp(iter.up)

View File

@ -191,10 +191,11 @@ func TestGetAround(t *testing.T) {
func TestPutStable(t *testing.T) { func TestPutStable(t *testing.T) {
tree := newVBT(compare.Int) tree := newVBT(compare.Int)
for i := 0; i < 20; i++ { for i := 0; i < 40; i++ {
v := randomdata.Number(0, 100) v := randomdata.Number(0, 100)
tree.Put(v) tree.Put(v)
t.Error(i, tree.debugString(), v) t.Error(i, v)
t.Error(tree.debugString())
} }
} }
@ -486,7 +487,7 @@ func BenchmarkIterator(b *testing.B) {
b.ResetTimer() b.ResetTimer()
b.StartTimer() b.StartTimer()
iter := tree.Iterator() iter := tree.vbtIterator()
b.N = 0 b.N = 0
for iter.Next() { for iter.Next() {
b.N++ b.N++

View File

@ -24,6 +24,18 @@ func NewIterator(n *Node) *Iterator {
return iter return iter
} }
func NewIteratorWithCap(n *Node, cap int) *Iterator {
iter := &Iterator{tstack: lastack.NewWithCap(cap)}
iter.up = n
return iter
}
func (iter *Iterator) SetNode(n *Node) {
iter.up = n
iter.dir = 0
iter.tstack.Clear()
}
func (iter *Iterator) Value() interface{} { func (iter *Iterator) Value() interface{} {
return iter.cur.value return iter.cur.value
} }
@ -46,9 +58,10 @@ func (iter *Iterator) Right() bool {
return false return false
} }
func GetNext(cur *Node, idx int) *Node { func (iter *Iterator) GetNext(cur *Node, idx int) *Node {
iter := NewIterator(cur) // iter := NewIterator(cur)
iter.SetNode(cur)
iter.curPushNextStack(iter.up) iter.curPushNextStack(iter.up)
iter.up = iter.getNextUp(iter.up) iter.up = iter.getNextUp(iter.up)
@ -103,9 +116,10 @@ func (iter *Iterator) Next() (result bool) {
return false return false
} }
func GetPrev(cur *Node, idx int) *Node { func (iter *Iterator) GetPrev(cur *Node, idx int) *Node {
iter := NewIterator(cur) // iter := NewIterator(cur)
iter.SetNode(cur)
iter.curPushPrevStack(iter.up) iter.curPushPrevStack(iter.up)
iter.up = iter.getPrevUp(iter.up) iter.up = iter.getPrevUp(iter.up)

View File

@ -27,10 +27,12 @@ func (n *Node) String() string {
type Tree struct { type Tree struct {
root *Node root *Node
Compare compare.Compare Compare compare.Compare
iter *Iterator
} }
func New(Compare compare.Compare) *Tree { func New(Compare compare.Compare) *Tree {
return &Tree{Compare: Compare} return &Tree{Compare: Compare, iter: NewIteratorWithCap(nil, 16)}
} }
func (tree *Tree) String() string { func (tree *Tree) String() string {
@ -263,7 +265,9 @@ func (tree *Tree) GetRange(k1, k2 interface{}) (result []interface{}) {
result = make([]interface{}, 0, 16) result = make([]interface{}, 0, 16)
iter := NewIterator(min) // iter := NewIterator(min)
tree.iter.SetNode(min)
iter := tree.iter
for iter.Next() { for iter.Next() {
result = append(result, iter.Value()) result = append(result, iter.Value())
if iter.cur == max { if iter.cur == max {
@ -288,7 +292,9 @@ func (tree *Tree) GetRange(k1, k2 interface{}) (result []interface{}) {
result = make([]interface{}, 0, 16) result = make([]interface{}, 0, 16)
iter := NewIterator(max) // iter := NewIterator(max)
tree.iter.SetNode(max)
iter := tree.iter
for iter.Prev() { for iter.Prev() {
result = append(result, iter.Value()) result = append(result, iter.Value())
if iter.cur == min { if iter.cur == min {
@ -338,7 +344,9 @@ func (tree *Tree) getArountNode(key interface{}) (result [3]*Node) {
n = n.children[1] n = n.children[1]
lastc = c lastc = c
case 0: case 0:
iter := NewIterator(n) // iter := NewIterator(n)
tree.iter.SetNode(n)
iter := tree.iter
iter.Prev() iter.Prev()
for iter.Prev() { for iter.Prev() {
if tree.Compare(iter.cur.value, n.value) == 0 { if tree.Compare(iter.cur.value, n.value) == 0 {
@ -359,21 +367,21 @@ func (tree *Tree) getArountNode(key interface{}) (result [3]*Node) {
if result[1] != nil { if result[1] != nil {
result[0] = GetPrev(result[1], 1) result[0] = tree.iter.GetPrev(result[1], 1)
result[2] = GetNext(result[1], 1) result[2] = tree.iter.GetNext(result[1], 1)
} else { } else {
result[0] = last result[0] = last
result[2] = GetNext(last, 1) result[2] = tree.iter.GetNext(last, 1)
} }
case -1: case -1:
if result[1] != nil { if result[1] != nil {
result[0] = GetPrev(result[1], 1) result[0] = tree.iter.GetPrev(result[1], 1)
result[2] = GetNext(result[1], 1) result[2] = tree.iter.GetNext(result[1], 1)
} else { } else {
result[2] = last result[2] = last
result[0] = GetPrev(last, 1) result[0] = tree.iter.GetPrev(last, 1)
} }
case 0: case 0:
@ -381,8 +389,8 @@ func (tree *Tree) getArountNode(key interface{}) (result [3]*Node) {
if result[1] == nil { if result[1] == nil {
return return
} }
result[0] = GetPrev(result[1], 1) result[0] = tree.iter.GetPrev(result[1], 1)
result[2] = GetNext(result[1], 1) result[2] = tree.iter.GetNext(result[1], 1)
} }
return return
} }
@ -396,7 +404,10 @@ func (tree *Tree) GetNode(value interface{}) (*Node, bool) {
case 1: case 1:
n = n.children[1] n = n.children[1]
case 0: case 0:
iter := NewIterator(n) // iter := NewIterator(n)
tree.iter.SetNode(n)
iter := tree.iter
iter.Prev() iter.Prev()
for iter.Prev() { for iter.Prev() {
if tree.Compare(iter.cur.value, n.value) == 0 { if tree.Compare(iter.cur.value, n.value) == 0 {
@ -615,16 +626,14 @@ func (tree *Tree) lrrotate3(cur *Node) {
const l = 1 const l = 1
const r = 0 const r = 0
movparent := cur.children[l] mov := cur.children[l].children[r]
mov := movparent.children[r]
mov.value, cur.value = cur.value, mov.value //交换值达到, 相对位移 mov.value, cur.value = cur.value, mov.value //交换值达到, 相对位移
cur.children[r] = mov cur.children[r] = mov
mov.parent = cur mov.parent = cur
cur.children[l] = movparent cur.children[l].children[r] = nil
movparent.children[r] = nil
cur.children[r] = mov cur.children[r] = mov
mov.parent = cur mov.parent = cur
@ -678,16 +687,14 @@ func (tree *Tree) rlrotate3(cur *Node) {
const l = 0 const l = 0
const r = 1 const r = 1
movparent := cur.children[l] mov := cur.children[l].children[r]
mov := movparent.children[r]
mov.value, cur.value = cur.value, mov.value //交换值达到, 相对位移 mov.value, cur.value = cur.value, mov.value //交换值达到, 相对位移
cur.children[r] = mov cur.children[r] = mov
mov.parent = cur mov.parent = cur
cur.children[l] = movparent cur.children[l].children[r] = nil
movparent.children[r] = nil
cur.children[r] = mov cur.children[r] = mov
mov.parent = cur mov.parent = cur
@ -744,7 +751,6 @@ func (tree *Tree) rrotate3(cur *Node) {
mov.value, cur.value = cur.value, mov.value //交换值达到, 相对位移 mov.value, cur.value = cur.value, mov.value //交换值达到, 相对位移
cur.children[r] = mov cur.children[r] = mov
mov.size = 1
cur.children[l] = mov.children[l] cur.children[l] = mov.children[l]
cur.children[l].parent = cur cur.children[l].parent = cur
@ -797,7 +803,6 @@ func (tree *Tree) lrotate3(cur *Node) {
mov.value, cur.value = cur.value, mov.value //交换值达到, 相对位移 mov.value, cur.value = cur.value, mov.value //交换值达到, 相对位移
cur.children[r] = mov cur.children[r] = mov
mov.size = 1
cur.children[l] = mov.children[l] cur.children[l] = mov.children[l]
cur.children[l].parent = cur cur.children[l].parent = cur

View File

@ -17,7 +17,7 @@ import (
"github.com/emirpasic/gods/trees/redblacktree" "github.com/emirpasic/gods/trees/redblacktree"
) )
const CompareSize = 5000 const CompareSize = 1000000
const NumberMax = 50000000 const NumberMax = 50000000
func TestSave(t *testing.T) { func TestSave(t *testing.T) {
@ -32,13 +32,13 @@ func TestSave(t *testing.T) {
// l = append(l, v) // l = append(l, v)
// } // }
//m := make(map[int]int) m := make(map[int]int)
for i := 0; len(l) < CompareSize; i++ { for i := 0; len(l) < CompareSize; i++ {
v := randomdata.Number(0, NumberMax) v := randomdata.Number(0, NumberMax)
// if _, ok := m[v]; !ok { if _, ok := m[v]; !ok {
// m[v] = v m[v] = v
l = append(l, v) l = append(l, v)
// } }
} }
var result bytes.Buffer var result bytes.Buffer

View File

@ -24,6 +24,18 @@ func NewIterator(n *Node) *Iterator {
return iter return iter
} }
func NewIteratorWithCap(n *Node, cap int) *Iterator {
iter := &Iterator{tstack: lastack.NewWithCap(cap)}
iter.up = n
return iter
}
func (iter *Iterator) SetNode(n *Node) {
iter.up = n
iter.dir = 0
iter.tstack.Clear()
}
func (iter *Iterator) Value() interface{} { func (iter *Iterator) Value() interface{} {
return iter.cur.value return iter.cur.value
} }
@ -46,9 +58,10 @@ func (iter *Iterator) Right() bool {
return false return false
} }
func GetNext(cur *Node, idx int) *Node { func (iter *Iterator) GetNext(cur *Node, idx int) *Node {
iter := NewIterator(cur) // iter := NewIterator(cur)
iter.SetNode(cur)
iter.curPushNextStack(iter.up) iter.curPushNextStack(iter.up)
iter.up = iter.getNextUp(iter.up) iter.up = iter.getNextUp(iter.up)
@ -103,9 +116,10 @@ func (iter *Iterator) Next() (result bool) {
return false return false
} }
func GetPrev(cur *Node, idx int) *Node { func (iter *Iterator) GetPrev(cur *Node, idx int) *Node {
iter := NewIterator(cur) // iter := NewIterator(cur)
iter.SetNode(cur)
iter.curPushPrevStack(iter.up) iter.curPushPrevStack(iter.up)
iter.up = iter.getPrevUp(iter.up) iter.up = iter.getPrevUp(iter.up)

View File

@ -28,10 +28,12 @@ func (n *Node) String() string {
type Tree struct { type Tree struct {
root *Node root *Node
Compare compare.Compare Compare compare.Compare
iter *Iterator
} }
func New(Compare compare.Compare) *Tree { func New(Compare compare.Compare) *Tree {
return &Tree{Compare: Compare} return &Tree{Compare: Compare, iter: NewIteratorWithCap(nil, 16)}
} }
func (tree *Tree) String() string { func (tree *Tree) String() string {
@ -264,7 +266,8 @@ func (tree *Tree) GetRange(k1, k2 interface{}) (result []interface{}) {
result = make([]interface{}, 0, 16) result = make([]interface{}, 0, 16)
iter := NewIterator(min) tree.iter.SetNode(min)
iter := tree.iter
for iter.Next() { for iter.Next() {
result = append(result, iter.Value()) result = append(result, iter.Value())
if iter.cur == max { if iter.cur == max {
@ -289,7 +292,8 @@ func (tree *Tree) GetRange(k1, k2 interface{}) (result []interface{}) {
result = make([]interface{}, 0, 16) result = make([]interface{}, 0, 16)
iter := NewIterator(max) tree.iter.SetNode(max)
iter := tree.iter
for iter.Prev() { for iter.Prev() {
result = append(result, iter.Value()) result = append(result, iter.Value())
if iter.cur == min { if iter.cur == min {
@ -339,7 +343,8 @@ func (tree *Tree) getArountNode(key interface{}) (result [3]*Node) {
n = n.children[1] n = n.children[1]
lastc = c lastc = c
case 0: case 0:
iter := NewIterator(n) tree.iter.SetNode(n)
iter := tree.iter
iter.Prev() iter.Prev()
for iter.Prev() { for iter.Prev() {
if tree.Compare(iter.cur.value, n.value) == 0 { if tree.Compare(iter.cur.value, n.value) == 0 {
@ -360,21 +365,21 @@ func (tree *Tree) getArountNode(key interface{}) (result [3]*Node) {
if result[1] != nil { if result[1] != nil {
result[0] = GetPrev(result[1], 1) result[0] = tree.iter.GetPrev(result[1], 1)
result[2] = GetNext(result[1], 1) result[2] = tree.iter.GetNext(result[1], 1)
} else { } else {
result[0] = last result[0] = last
result[2] = GetNext(last, 1) result[2] = tree.iter.GetNext(last, 1)
} }
case -1: case -1:
if result[1] != nil { if result[1] != nil {
result[0] = GetPrev(result[1], 1) result[0] = tree.iter.GetPrev(result[1], 1)
result[2] = GetNext(result[1], 1) result[2] = tree.iter.GetNext(result[1], 1)
} else { } else {
result[2] = last result[2] = last
result[0] = GetPrev(last, 1) result[0] = tree.iter.GetPrev(last, 1)
} }
case 0: case 0:
@ -382,8 +387,8 @@ func (tree *Tree) getArountNode(key interface{}) (result [3]*Node) {
if result[1] == nil { if result[1] == nil {
return return
} }
result[0] = GetPrev(result[1], 1) result[0] = tree.iter.GetPrev(result[1], 1)
result[2] = GetNext(result[1], 1) result[2] = tree.iter.GetNext(result[1], 1)
} }
return return
} }
@ -397,7 +402,9 @@ func (tree *Tree) GetNode(value interface{}) (*Node, bool) {
case 1: case 1:
n = n.children[1] n = n.children[1]
case 0: case 0:
iter := NewIterator(n)
tree.iter.SetNode(n)
iter := tree.iter
iter.Prev() iter.Prev()
for iter.Prev() { for iter.Prev() {
if tree.Compare(iter.cur.value, n.value) == 0 { if tree.Compare(iter.cur.value, n.value) == 0 {
@ -746,7 +753,6 @@ func (tree *Tree) rrotate3(cur *Node) {
mov.key, mov.value, cur.key, cur.value = cur.key, cur.value, mov.key, mov.value //交换值达到, 相对位移 mov.key, mov.value, cur.key, cur.value = cur.key, cur.value, mov.key, mov.value //交换值达到, 相对位移
cur.children[r] = mov cur.children[r] = mov
mov.size = 1
cur.children[l] = mov.children[l] cur.children[l] = mov.children[l]
cur.children[l].parent = cur cur.children[l].parent = cur
@ -800,7 +806,6 @@ func (tree *Tree) lrotate3(cur *Node) {
mov.key, mov.value, cur.key, cur.value = cur.key, cur.value, mov.key, mov.value //交换值达到, 相对位移 mov.key, mov.value, cur.key, cur.value = cur.key, cur.value, mov.key, mov.value //交换值达到, 相对位移
cur.children[r] = mov cur.children[r] = mov
mov.size = 1
cur.children[l] = mov.children[l] cur.children[l] = mov.children[l]
cur.children[l].parent = cur cur.children[l].parent = cur