diff --git a/avlindex/avlindex.go b/avlindex/avlindex.go new file mode 100644 index 0000000..0d3c319 --- /dev/null +++ b/avlindex/avlindex.go @@ -0,0 +1,541 @@ +package avlindex + +import ( + "github.com/davecgh/go-spew/spew" + + "github.com/emirpasic/gods/utils" +) + +type Node struct { + children [2]*Node + parent *Node + size int + value interface{} +} + +func (n *Node) String() string { + if n == nil { + return "nil" + } + + p := "nil" + if n.parent != nil { + p = spew.Sprint(n.parent.value) + } + return spew.Sprint(n.value) + "(" + p + "|" + spew.Sprint(n.size) + ")" +} + +type Tree struct { + root *Node + comparator utils.Comparator +} + +func New(comparator utils.Comparator) *Tree { + return &Tree{comparator: comparator} +} + +func (avl *Tree) String() string { + str := "AVLTree\n" + output(avl.root, "", true, &str) + + return str +} + +func (avl *Tree) Iterator() *Iterator { + return initIterator(avl) +} + +func (avl *Tree) Size() int { + return avl.root.size +} + +func (avl *Tree) Remove(key interface{}) *Node { + return nil +} + +func (avl *Tree) Get(key interface{}) (interface{}, bool) { + n, ok := avl.GetNode(key) + if ok { + return n.value, true + } + return n, false +} + +func (avl *Tree) GetRange(min, max interface{}) []interface{} { + return nil +} + +func (avl *Tree) GetAround(key interface{}) (result [3]interface{}) { + an := avl.GetAroundNode(key) + for i, n := range an { + if n != nil { + result[i] = n.value + } + } + return +} + +func (avl *Tree) GetAroundNode(value interface{}) (result [3]*Node) { + n := avl.root + + for { + + if n == nil { + return + } + + lastc := 0 + switch c := avl.comparator(value, n.value); c { + case -1: + if c != -lastc { + result[0] = n + } + lastc = c + n = n.children[0] + case 1: + if c != -lastc { + result[2] = n + } + lastc = c + n = n.children[1] + case 0: + + switch lastc { + case -1: + if n.children[1] != nil { + result[0] = n.children[1] + } + case 1: + if n.children[0] != nil { + result[2] = n.children[0] + } + case 0: + + if n.children[1] != nil { + result[0] = n.children[1] + } + if n.children[0] != nil { + result[2] = n.children[0] + } + + result[1] = n + return + } + + default: + panic("Get comparator only is allowed in -1, 0, 1") + } + + } +} +func (avl *Tree) GetNode(value interface{}) (*Node, bool) { + + for n := avl.root; n != nil; { + switch c := avl.comparator(value, n.value); c { + case -1: + n = n.children[0] + case 1: + n = n.children[1] + case 0: + return n, true + default: + panic("Get comparator only is allowed in -1, 0, 1") + } + } + return nil, false +} + +func (avl *Tree) Put(value interface{}) { + + node := &Node{value: value, size: 1} + if avl.root == nil { + avl.root = node + return + } + + cur := avl.root + parent := cur.parent + child := -1 + + for { + + if cur == nil { + parent.children[child] = node + node.parent = parent + if getSize(node.parent.parent) == 3 { + lefts, rigths := getChildrenSize(node.parent.parent) + avl.fixPutHeight(node.parent.parent, lefts, rigths) + } + return + } + + if cur.size > 4 { + ls, rs := getChildrenSize(cur) + if rs > ls { + if rs >= ls*2 { + avl.fixPutHeight(cur, ls, rs) + } + } else { + if ls >= rs*2 { + avl.fixPutHeight(cur, ls, rs) + } + } + } + + cur.size++ + parent = cur + c := avl.comparator(value, cur.value) + child = (c + 2) / 2 + cur = cur.children[child] + } + +} + +func (avl *Tree) debugString() string { + str := "AVLTree\n" + outputfordebug(avl.root, "", true, &str) + return str +} + +func (avl *Tree) TraversalBreadth() (result []interface{}) { + result = make([]interface{}, 0, avl.root.size) + var traverasl func(cur *Node) + traverasl = func(cur *Node) { + if cur == nil { + return + } + result = append(result, cur.value) + traverasl(cur.children[0]) + traverasl(cur.children[1]) + } + traverasl(avl.root) + return +} + +func (avl *Tree) TraversalDepth(leftright int) (result []interface{}) { + result = make([]interface{}, 0, avl.root.size) + if leftright < 0 { + var traverasl func(cur *Node) + traverasl = func(cur *Node) { + if cur == nil { + return + } + traverasl(cur.children[0]) + result = append(result, cur.value) + traverasl(cur.children[1]) + } + traverasl(avl.root) + } else { + var traverasl func(cur *Node) + traverasl = func(cur *Node) { + if cur == nil { + return + } + traverasl(cur.children[1]) + result = append(result, cur.value) + traverasl(cur.children[0]) + } + traverasl(avl.root) + } + + return +} + +func (avl *Tree) lrrotate(cur *Node) { + + const l = 1 + const r = 0 + + movparent := cur.children[l] + mov := movparent.children[r] + + mov.value, cur.value = cur.value, mov.value //交换值达到, 相对位移 + + if mov.children[l] != nil { + movparent.children[r] = mov.children[l] + movparent.children[r].parent = movparent + //movparent.children[r].child = l + } else { + movparent.children[r] = nil + } + + if mov.children[r] != nil { + mov.children[l] = mov.children[r] + //mov.children[l].child = l + } else { + mov.children[l] = nil + } + + if cur.children[r] != nil { + mov.children[r] = cur.children[r] + mov.children[r].parent = mov + } else { + mov.children[r] = nil + } + + cur.children[r] = mov + mov.parent = cur + + // cur.size = 3 + // cur.children[0].size = 1 + // cur.children[1].size = 1 + + movparent.size = getChildrenSumSize(movparent) + 1 + mov.size = getChildrenSumSize(mov) + 1 + cur.size = getChildrenSumSize(cur) + 1 + + // mov.height = getMaxChildrenHeight(mov) + 1 + // movparent.height = getMaxChildrenHeight(movparent) + 1 + // cur.height = getMaxChildrenHeight(cur) + 1 +} + +func (avl *Tree) rlrotate(cur *Node) { + + const l = 0 + const r = 1 + + movparent := cur.children[l] + mov := movparent.children[r] + + mov.value, cur.value = cur.value, mov.value //交换值达到, 相对位移 + + if mov.children[l] != nil { + movparent.children[r] = mov.children[l] + movparent.children[r].parent = movparent + } else { + movparent.children[r] = nil + } + + if mov.children[r] != nil { + mov.children[l] = mov.children[r] + } else { + mov.children[l] = nil + } + + if cur.children[r] != nil { + mov.children[r] = cur.children[r] + mov.children[r].parent = mov + } else { + mov.children[r] = nil + } + + cur.children[r] = mov + mov.parent = cur + + movparent.size = getChildrenSumSize(movparent) + 1 + mov.size = getChildrenSumSize(mov) + 1 + cur.size = getChildrenSumSize(cur) + 1 + + // cur.size = 3 + // cur.children[0].size = 1 + // cur.children[1].size = 1 + + // mov.height = getMaxChildrenHeight(mov) + 1 + // movparent.height = getMaxChildrenHeight(movparent) + 1 + // cur.height = getMaxChildrenHeight(cur) + 1 +} + +func (avl *Tree) rrotate(cur *Node) { + + const l = 0 + const r = 1 + // 1 right 0 left + mov := cur.children[l] + + // lsize, rsize := getChildrenHeight(cur) + // movrsize := getSize(mov.children[r]) + + mov.value, cur.value = cur.value, mov.value //交换值达到, 相对位移 + + // mov.children[l]不可能为nil + + mov.children[l].parent = cur + cur.children[l] = mov.children[l] + + // 解决mov节点孩子转移的问题 + if mov.children[r] != nil { + mov.children[l] = mov.children[r] + } else { + mov.children[l] = nil + } + + if cur.children[r] != nil { + mov.children[r] = cur.children[r] + mov.children[r].parent = mov + } else { + mov.children[r] = nil + } + + // 连接转移后的节点 由于mov只是与cur交换值,parent不变 + cur.children[r] = mov + + // cur.size = 3 + // cur.children[0].size = 1 + // cur.children[1].size = 1 + + mov.size = getChildrenSumSize(mov) + 1 + cur.size = getChildrenSumSize(cur) + 1 + // cur.height = getMaxChildrenHeight(cur) + 1 +} + +func (avl *Tree) lrotate(cur *Node) { + + const l = 1 + const r = 0 + // 1 right 0 left + mov := cur.children[l] + + // lsize, rsize := getChildrenHeight(cur) + // movrsize := getSize(mov.children[r]) + + mov.value, cur.value = cur.value, mov.value //交换值达到, 相对位移 + + // mov.children[l]不可能为nil + + mov.children[l].parent = cur + cur.children[l] = mov.children[l] + + // 解决mov节点孩子转移的问题 + if mov.children[r] != nil { + mov.children[l] = mov.children[r] + } else { + mov.children[l] = nil + } + + if cur.children[r] != nil { + mov.children[r] = cur.children[r] + mov.children[r].parent = mov + } else { + mov.children[r] = nil + } + + // 连接转移后的节点 由于mov只是与cur交换值,parent不变 + cur.children[r] = mov + + // cur.size = 3 + // cur.children[0].size = 1 + // cur.children[1].size = 1 + + mov.size = getChildrenSumSize(mov) + 1 + cur.size = getChildrenSumSize(cur) + 1 + +} + +func getChildrenSumSize(cur *Node) int { + return getSize(cur.children[0]) + getSize(cur.children[1]) +} + +func getChildrenSize(cur *Node) (int, int) { + return getSize(cur.children[0]), getSize(cur.children[1]) +} + +func getSize(cur *Node) int { + if cur == nil { + return 0 + } + return cur.size +} + +func (avl *Tree) fixRemoveHeight(cur *Node) { + +} + +func abs(n int) int { + y := n >> 31 + return (n ^ y) - y +} + +func (avl *Tree) fixPutHeight(cur *Node, lefts, rigths int) { + + // lefts, rigths := getChildrenSize(cur) + if lefts < rigths { + r := cur.children[1] + rlsize, rrsize := getChildrenSize(r) + if rlsize > rrsize { + avl.lrrotate(cur) + } else { + avl.lrotate(cur) + } + + } else { + l := cur.children[0] + llsize, lrsize := getChildrenSize(l) + if lrsize > llsize { + avl.rlrotate(cur) + } else { + avl.rrotate(cur) + } + } + +} + +func output(node *Node, prefix string, isTail bool, str *string) { + + if node.children[1] != nil { + newPrefix := prefix + if isTail { + newPrefix += "│ " + } else { + newPrefix += " " + } + output(node.children[1], newPrefix, false, str) + } + *str += prefix + if isTail { + *str += "└── " + } else { + *str += "┌── " + } + + *str += spew.Sprint(node.value) + "\n" + + if node.children[0] != nil { + newPrefix := prefix + if isTail { + newPrefix += " " + } else { + newPrefix += "│ " + } + output(node.children[0], newPrefix, true, str) + } + +} + +func outputfordebug(node *Node, prefix string, isTail bool, str *string) { + + if node.children[1] != nil { + newPrefix := prefix + if isTail { + newPrefix += "│ " + } else { + newPrefix += " " + } + outputfordebug(node.children[1], newPrefix, false, str) + } + *str += prefix + if isTail { + *str += "└── " + } else { + *str += "┌── " + } + + suffix := "(" + parentv := "" + if node.parent == nil { + parentv = "nil" + } else { + parentv = spew.Sprint(node.parent.value) + } + suffix += parentv + "|" + spew.Sprint(node.size) + ")" + *str += spew.Sprint(node.value) + suffix + "\n" + + if node.children[0] != nil { + newPrefix := prefix + if isTail { + newPrefix += " " + } else { + newPrefix += "│ " + } + outputfordebug(node.children[0], newPrefix, true, str) + } +} diff --git a/avlindex/avlindex_test.go b/avlindex/avlindex_test.go new file mode 100644 index 0000000..58a1f92 --- /dev/null +++ b/avlindex/avlindex_test.go @@ -0,0 +1,453 @@ +package avlindex + +import ( + "bytes" + "encoding/gob" + "io/ioutil" + "log" + "os" + "testing" + + "github.com/Pallinder/go-randomdata" + "github.com/emirpasic/gods/trees/avltree" + "github.com/emirpasic/gods/trees/redblacktree" + "github.com/emirpasic/gods/utils" +) + +const CompartorSize = 10000000 +const NumberMax = 60000000 + +func TestSave(t *testing.T) { + + f, err := os.OpenFile("../l.log", os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0666) + if err != nil { + log.Println(err) + } + + //fmt.Println(userBytes) + + var l []int + m := make(map[int]int) + for i := 0; len(l) < CompartorSize; i++ { + v := randomdata.Number(0, NumberMax) + if _, ok := m[v]; !ok { + m[v] = v + l = append(l, v) + } + } + + var result bytes.Buffer + encoder := gob.NewEncoder(&result) + encoder.Encode(l) + lbytes := result.Bytes() + f.Write(lbytes) + +} + +func loadTestData() []int { + data, err := ioutil.ReadFile("../l.log") + if err != nil { + log.Println(err) + } + var l []int + decoder := gob.NewDecoder(bytes.NewReader(data)) + decoder.Decode(&l) + return l +} + +// func TestIterator(t *testing.T) { +// avl := New(utils.IntComparator) +// for _, v := range []int{1, 2, 7, 4, 5, 6, 7, 14, 15, 20, 30, 21, 3} { +// // t.Error(v) +// avl.Put(v) + +// } +// // ` AVLTree +// // │ ┌── 30 +// // │ │ └── 21 +// // │ ┌── 20 +// // │ │ └── 15 +// // └── 14 +// // │ ┌── 7 +// // │ ┌── 7 +// // │ │ └── 6 +// // └── 5 +// // │ ┌── 4 +// // │ │ └── 3 +// // └── 2 +// // └── 1` + +// iter := avl.Iterator() // root start point + +// l := []int{14, 15, 20, 21, 30} + +// for i := 0; iter.Prev(); i++ { +// if iter.Value().(int) != l[i] { +// t.Error("iter prev error", iter.Value(), l[i]) +// } +// } + +// iter.Prev() +// if iter.Value().(int) != 30 { +// t.Error("prev == false", iter.Value(), iter.Prev(), iter.Value()) +// } + +// l = []int{21, 20, 15, 14, 7, 7, 6, 5, 4, 3, 2, 1} +// for i := 0; iter.Next(); i++ { // cur is 30 next is 21 +// if iter.Value().(int) != l[i] { +// t.Error(iter.Value()) +// } +// } + +// if iter.Next() != false { +// t.Error("Next is error, cur is tail, val = 1 Next return false") +// } +// if iter.Value().(int) != 1 { // cur is 1 +// t.Error("next == false", iter.Value(), iter.Next(), iter.Value()) +// } + +// if iter.Prev() != true && iter.Value().(int) != 2 { +// t.Error("next to prev is error") +// } +// } + +// func TestGetAround(t *testing.T) { +// avl := New(utils.IntComparator) +// for _, v := range []int{7, 14, 15, 20, 30, 21, 40, 40, 50, 3, 40, 40, 40} { +// avl.Put(v) +// } + +// if spew.Sprint(avl.GetAround(30)) != "[40 30 21]" { +// t.Error("avl.GetAround(40)) is error", spew.Sprint(avl.GetAround(30))) +// } + +// if spew.Sprint(avl.GetAround(40)) != "[40 40 30]" { +// t.Error("avl.GetAround(40)) is error", spew.Sprint(avl.GetAround(50))) +// } + +// if spew.Sprint(avl.GetAround(50)) != "[ 50 40]" { +// t.Error("avl.GetAround(40)) is error", spew.Sprint(avl.GetAround(50))) +// } +// } + +// // for test error case + +// func TestPutComparatorRandom(t *testing.T) { + +// for n := 0; n < 300000; n++ { +// avl := New(utils.IntComparator) +// godsavl := avltree.NewWithIntComparator() + +// content := "" +// m := make(map[int]int) +// for i := 0; len(m) < 10; i++ { +// v := randomdata.Number(0, 65535) +// if _, ok := m[v]; !ok { +// m[v] = v +// content += spew.Sprint(v) + "," +// avl.Put(v) +// godsavl.Put(v, v) +// } +// } + +// if avl.String() != godsavl.String() { +// t.Error(godsavl.String()) +// t.Error(avl.debugString()) +// t.Error(content, n) +// break +// } +// } +// } + +func TestGet(t *testing.T) { + avl := New(utils.IntComparator) + for _, v := range []int{2383, 7666, 3055, 39016, 57092, 27897, 36513, 1562, 22574, 23202} { + avl.Put(v) + } + + for _, v := range []int{2383, 7666, 3055, 39016, 57092, 27897, 36513, 1562, 22574, 23202} { + v, ok := avl.Get(v) + if !ok { + t.Error("the val not found ", v) + } + } + + if v, ok := avl.Get(10000); ok { + t.Error("the val(1000) is not in tree, but is found", v) + } + +} + +// func TestRemoveAll(t *testing.T) { + +// ALL: +// for c := 0; c < 5000; c++ { +// avl := New(utils.IntComparator) +// gods := avltree.NewWithIntComparator() +// var l []int +// m := make(map[int]int) + +// for i := 0; len(l) < 100; i++ { +// v := randomdata.Number(0, 100000) +// if _, ok := m[v]; !ok { +// m[v] = v +// l = append(l, v) +// avl.Put(v) +// gods.Put(v, v) +// } +// } + +// for i := 0; i < 100; i++ { +// avl.Remove(l[i]) +// gods.Remove(l[i]) +// s1 := spew.Sprint(avl.TraversalDepth(-1)) +// s2 := spew.Sprint(gods.Values()) +// if s1 != s2 { +// t.Error("avl remove error", "avlsize = ", avl.Size()) +// t.Error(s1) +// t.Error(s2) +// break ALL +// } +// } +// } +// } + +// func TestRemove(t *testing.T) { + +// ALL: +// for N := 0; N < 500000; N++ { +// avl := New(utils.IntComparator) +// gods := avltree.NewWithIntComparator() + +// var l []int +// m := make(map[int]int) + +// for i := 0; len(l) < 10; i++ { +// v := randomdata.Number(0, 100) +// if _, ok := m[v]; !ok { +// l = append(l, v) +// m[v] = v +// avl.Put(v) +// gods.Put(v, v) +// } +// } + +// src1 := avl.String() +// src2 := gods.String() + +// for i := 0; i < 10; i++ { +// avl.Remove(l[i]) +// gods.Remove(l[i]) +// if spew.Sprint(gods.Values()) != spew.Sprint(avl.TraversalDepth(-1)) && avl.size != 0 { +// // if gods.String() != avl.String() && gods.Size() != 0 && avl.size != 0 { +// t.Error(src1) +// t.Error(src2) +// t.Error(avl.debugString()) +// t.Error(gods.String()) +// t.Error(l[i]) +// // t.Error(avl.TraversalDepth(-1)) +// // t.Error(gods.Values()) +// break ALL +// } +// } +// } +// } + +// func BenchmarkIterator(b *testing.B) { +// tree := New(utils.IntComparator) + +// l := loadTestData() +// b.N = len(l) + +// for _, v := range l { +// tree.Put(v) +// } + +// b.ResetTimer() +// b.StartTimer() +// iter := tree.Iterator() +// for iter.Next() { +// } +// for iter.Prev() { +// } +// for iter.Next() { +// } + +// } + +// func BenchmarkGodsIterator(b *testing.B) { +// tree := avltree.NewWithIntComparator() + +// l := loadTestData() +// b.N = len(l) + +// for _, v := range l { +// tree.Put(v, v) +// } + +// b.ResetTimer() +// b.StartTimer() +// iter := tree.Iterator() +// for iter.Next() { +// } +// for iter.Prev() { +// } +// for iter.Next() { +// } +// } + +// func BenchmarkRemove(b *testing.B) { +// tree := New(utils.IntComparator) + +// l := loadTestData() + +// b.N = len(l) +// for _, v := range l { +// tree.Put(v) +// } + +// b.ResetTimer() +// b.StartTimer() + +// for i := 0; i < len(l); i++ { +// tree.Remove(l[i]) +// } +// } + +// func BenchmarkGodsRemove(b *testing.B) { +// tree := avltree.NewWithIntComparator() + +// l := loadTestData() + +// b.N = len(l) +// for _, v := range l { +// tree.Put(v, v) +// } + +// b.ResetTimer() +// b.StartTimer() + +// for i := 0; i < len(l); i++ { +// tree.Remove(l[i]) +// } +// } + +// func BenchmarkGodsRBRemove(b *testing.B) { +// tree := redblacktree.NewWithIntComparator() + +// l := loadTestData() + +// b.N = len(l) +// for _, v := range l { +// tree.Put(v, v) +// } + +// b.ResetTimer() +// b.StartTimer() + +// for i := 0; i < len(l); i++ { +// tree.Remove(l[i]) +// } +// } + +func BenchmarkGet(b *testing.B) { + + avl := New(utils.IntComparator) + + l := loadTestData() + b.N = len(l) + + b.ResetTimer() + b.StartTimer() + for i := 0; i < b.N; i++ { + avl.Get(l[i]) + } +} + +// func BenchmarkGodsRBGet(b *testing.B) { +// tree := redblacktree.NewWithIntComparator() + +// l := loadTestData() +// b.N = len(l) + +// b.ResetTimer() +// b.StartTimer() +// for i := 0; i < b.N; i++ { +// tree.Get(l[i]) +// } +// } + +// func BenchmarkGodsAvlGet(b *testing.B) { +// tree := avltree.NewWithIntComparator() + +// l := loadTestData() +// b.N = len(l) + +// b.ResetTimer() +// b.StartTimer() +// for i := 0; i < b.N; i++ { +// tree.Get(l[i]) +// } +// } + +func BenchmarkPut(b *testing.B) { + avl := New(utils.IntComparator) + + l := loadTestData() + + b.ResetTimer() + b.StartTimer() + + b.N = len(l) + for _, v := range l { + avl.Put(v) + } +} + +func TestPutStable(t *testing.T) { + + // l := []int{14, 18, 19, 20, 5, 6, 7, 8, 9, 21, 22, 30, 41, 41, 41, 0, 1, 2, 3, 4, 10, 11, 12, 13} + var l []int + for i := 0; len(l) < 100; i++ { + l = append(l, randomdata.Number(0, 65535)) + } + + for i := 0; len(l) < 1000; i++ { + l = append(l, randomdata.Number(70, 100)) + } + + avl := New(utils.IntComparator) + for _, v := range l { + avl.Put(v) + } + + t.Error(len(l), avl.debugString(), avl.TraversalBreadth(), "\n", "-----------") + +} +func BenchmarkGodsRBPut(b *testing.B) { + tree := redblacktree.NewWithIntComparator() + + l := loadTestData() + + b.ResetTimer() + b.StartTimer() + + b.N = len(l) + for _, v := range l { + tree.Put(v, v) + } +} + +func BenchmarkGodsPut(b *testing.B) { + tree := avltree.NewWithIntComparator() + + l := loadTestData() + + b.ResetTimer() + b.StartTimer() + + b.N = len(l) + for _, v := range l { + tree.Put(v, v) + } +} diff --git a/avlindex/iterator.go b/avlindex/iterator.go new file mode 100644 index 0000000..d15658a --- /dev/null +++ b/avlindex/iterator.go @@ -0,0 +1,156 @@ +package avlindex + +import ( + "474420502.top/eson/structure/lastack" +) + +type Iterator struct { + op *Tree + + dir int + up *Node + cur *Node + tstack *lastack.Stack + // curnext *Node +} + +func initIterator(avltree *Tree) *Iterator { + iter := &Iterator{op: avltree, tstack: lastack.New()} + iter.up = avltree.root + return iter +} + +func NewIterator(tree *Tree) *Iterator { + return initIterator(tree) +} + +func (iter *Iterator) Value() interface{} { + return iter.cur.value +} + +func (iter *Iterator) Left() bool { + if iter.cur.children[0] != nil { + iter.dir = 0 + iter.cur = iter.cur.children[0] + return true + } + return false +} + +func (iter *Iterator) Right() bool { + if iter.cur.children[1] != nil { + iter.dir = 0 + iter.cur = iter.cur.children[1] + return true + } + return false +} + +func (iter *Iterator) Prev() (result bool) { + + if iter.dir > -1 { + if iter.dir == 1 && iter.cur != nil { + iter.tstack.Clear() + iter.curPushPrevStack(iter.cur) + iter.up = iter.getPrevUp(iter.cur) + } + iter.dir = -1 + } + + if iter.tstack.Size() == 0 { + if iter.up == nil { + return false + } + iter.tstack.Push(iter.up) + iter.up = iter.getPrevUp(iter.up) + } + + if v, ok := iter.tstack.Pop(); ok { + iter.cur = v.(*Node) + iter.curPushPrevStack(iter.cur) + return true + } + + return false +} + +func (iter *Iterator) Next() (result bool) { + + if iter.dir < 1 { // 非 1(next 方向定义 -1 为 prev) + if iter.dir == -1 && iter.cur != nil { // 如果上次为prev方向, 则清空辅助计算的栈 + iter.tstack.Clear() + iter.curPushNextStack(iter.cur) // 把当前cur计算的逆向回朔 + iter.up = iter.getNextUp(iter.cur) // cur 寻找下个要计算up + } + iter.dir = 1 + } + + // 如果栈空了, 把up的递归计算入栈, 重新计算 下次的up值 + if iter.tstack.Size() == 0 { + if iter.up == nil { + return false + } + iter.tstack.Push(iter.up) + iter.up = iter.getNextUp(iter.up) + } + + if v, ok := iter.tstack.Pop(); ok { + iter.cur = v.(*Node) + iter.curPushNextStack(iter.cur) + return true + } + + // 如果再次计算的栈为空, 则只能返回false + return false +} + +func getRelationship(cur *Node) int { + if cur.parent.children[1] == cur { + return 1 + } + return 0 +} + +func (iter *Iterator) getNextUp(cur *Node) *Node { + for cur.parent != nil { + if getRelationship(cur) == 1 { // next 在 降序 小值. 如果child在右边, parent 比 child 小, parent才有效, 符合降序 + return cur.parent + } + cur = cur.parent + } + return nil +} + +func (iter *Iterator) curPushNextStack(cur *Node) { + next := cur.children[0] // 当前的左然后向右找, 找到最大, 就是最接近cur 并且小于cur的值 + + if next != nil { + iter.tstack.Push(next) + for next.children[1] != nil { + next = next.children[1] + iter.tstack.Push(next) // 入栈 用于回溯 + } + } +} + +func (iter *Iterator) getPrevUp(cur *Node) *Node { + for cur.parent != nil { + if getRelationship(cur) == 0 { // Prev 在 降序 大值. 如果child在左边, parent 比 child 大, parent才有效 , 符合降序 + return cur.parent + } + cur = cur.parent + } + return nil +} + +func (iter *Iterator) curPushPrevStack(cur *Node) { + prev := cur.children[1] + + if prev != nil { + iter.tstack.Push(prev) + for prev.children[0] != nil { + prev = prev.children[0] + iter.tstack.Push(prev) + } + } +}