diff --git a/priority_queue/avltree.go b/priority_queue/avltree.go index 4d80061..db9a559 100644 --- a/priority_queue/avltree.go +++ b/priority_queue/avltree.go @@ -22,17 +22,17 @@ func assertTreeImplementation() { // Tree holds elements of the AVL tree. type Tree struct { - Root *ANode // Root node + Root *AvlNode // Root node Comparator utils.Comparator // Key comparator size int // Total number of keys in the tree } -// ANode is a single element within the tree -type ANode struct { +// AvlNode is a single element within the tree +type AvlNode struct { Key interface{} Value interface{} - Parent *ANode // Parent node - Children [2]*ANode // Children nodes + Parent *AvlNode // Parent node + Children [2]*AvlNode // Children nodes b int8 } @@ -53,8 +53,9 @@ func NewWithStringComparator() *Tree { // Put inserts node into the tree. // Key should adhere to the comparator's type assertion, otherwise method panics. -func (t *Tree) Put(key interface{}, value interface{}) (bool, *ANode) { - return t.put(key, value, nil, &t.Root) +func (t *Tree) Put(key interface{}, value interface{}) (putNode *AvlNode) { + _, putNode = t.put(key, value, nil, &t.Root) + return } // Get searches the node in the tree by key and returns its value or nil if key is not found in tree. @@ -114,13 +115,13 @@ func (t *Tree) Values() []interface{} { // Left returns the minimum element of the AVL tree // or nil if the tree is empty. -func (t *Tree) Left() *ANode { +func (t *Tree) Left() *AvlNode { return t.bottom(0) } // Right returns the maximum element of the AVL tree // or nil if the tree is empty. -func (t *Tree) Right() *ANode { +func (t *Tree) Right() *AvlNode { return t.bottom(1) } @@ -132,7 +133,7 @@ func (t *Tree) Right() *ANode { // all nodes in the tree is larger than the given node. // // Key should adhere to the comparator's type assertion, otherwise method panics. -func (t *Tree) Floor(key interface{}) (floor *ANode, found bool) { +func (t *Tree) Floor(key interface{}) (floor *AvlNode, found bool) { found = false n := t.Root last := n @@ -164,7 +165,7 @@ func (t *Tree) Floor(key interface{}) (floor *ANode, found bool) { // all nodes in the tree is smaller than the given node. // // Key should adhere to the comparator's type assertion, otherwise method panics. -func (t *Tree) Ceiling(key interface{}) (floor *ANode, found bool) { +func (t *Tree) Ceiling(key interface{}) (floor *AvlNode, found bool) { found = false n := t.Root last := n @@ -202,15 +203,15 @@ func (t *Tree) String() string { return str } -func (n *ANode) String() string { +func (n *AvlNode) String() string { return fmt.Sprintf("%v", n.Key) } -func (t *Tree) put(key interface{}, value interface{}, p *ANode, qp **ANode) (bool, *ANode) { +func (t *Tree) put(key interface{}, value interface{}, p *AvlNode, qp **AvlNode) (bool, *AvlNode) { q := *qp if q == nil { t.size++ - *qp = &ANode{Key: key, Value: value, Parent: p} + *qp = &AvlNode{Key: key, Value: value, Parent: p} return true, *qp } @@ -230,12 +231,12 @@ func (t *Tree) put(key interface{}, value interface{}, p *ANode, qp **ANode) (bo var fix bool fix, node := t.put(key, value, q, &q.Children[a]) if fix { - return putFix(int8(c), qp), *qp + return putFix(int8(c), qp), node } - return false, q + return false, node } -func (t *Tree) remove(key interface{}, qp **ANode) bool { +func (t *Tree) remove(key interface{}, qp **AvlNode) bool { q := *qp if q == nil { return false @@ -271,7 +272,7 @@ func (t *Tree) remove(key interface{}, qp **ANode) bool { return false } -func removeMin(qp **ANode, minKey *interface{}, minVal *interface{}) bool { +func removeMin(qp **AvlNode, minKey *interface{}, minVal *interface{}) bool { q := *qp if q.Children[0] == nil { *minKey = q.Key @@ -289,7 +290,7 @@ func removeMin(qp **ANode, minKey *interface{}, minVal *interface{}) bool { return false } -func putFix(c int8, t **ANode) bool { +func putFix(c int8, t **AvlNode) bool { s := *t if s.b == 0 { s.b = c @@ -310,7 +311,7 @@ func putFix(c int8, t **ANode) bool { return false } -func removeFix(c int8, t **ANode) bool { +func removeFix(c int8, t **AvlNode) bool { s := *t if s.b == 0 { s.b = c @@ -339,14 +340,14 @@ func removeFix(c int8, t **ANode) bool { return true } -func singlerot(c int8, s *ANode) *ANode { +func singlerot(c int8, s *AvlNode) *AvlNode { s.b = 0 s = rotate(c, s) s.b = 0 return s } -func doublerot(c int8, s *ANode) *ANode { +func doublerot(c int8, s *AvlNode) *AvlNode { a := (c + 1) / 2 r := s.Children[a] s.Children[a] = rotate(-c, s.Children[a]) @@ -368,7 +369,7 @@ func doublerot(c int8, s *ANode) *ANode { return p } -func rotate(c int8, s *ANode) *ANode { +func rotate(c int8, s *AvlNode) *AvlNode { a := (c + 1) / 2 r := s.Children[a] s.Children[a] = r.Children[a^1] @@ -381,7 +382,7 @@ func rotate(c int8, s *ANode) *ANode { return r } -func (t *Tree) bottom(d int) *ANode { +func (t *Tree) bottom(d int) *AvlNode { n := t.Root if n == nil { return nil @@ -395,17 +396,17 @@ func (t *Tree) bottom(d int) *ANode { // Prev returns the previous element in an inorder // walk of the AVL tree. -func (n *ANode) Prev() *ANode { +func (n *AvlNode) Prev() *AvlNode { return n.walk1(0) } // Next returns the next element in an inorder // walk of the AVL tree. -func (n *ANode) Next() *ANode { +func (n *AvlNode) Next() *AvlNode { return n.walk1(1) } -func (n *ANode) walk1(a int) *ANode { +func (n *AvlNode) walk1(a int) *AvlNode { if n == nil { return nil } @@ -426,7 +427,7 @@ func (n *ANode) walk1(a int) *ANode { return p } -func output(node *ANode, prefix string, isTail bool, str *string) { +func output(node *AvlNode, prefix string, isTail bool, str *string) { if node.Children[1] != nil { newPrefix := prefix if isTail { diff --git a/priority_queue/iterator.go b/priority_queue/iterator.go index da9d9cb..42570c5 100644 --- a/priority_queue/iterator.go +++ b/priority_queue/iterator.go @@ -13,7 +13,7 @@ func assertIteratorImplementation() { // Iterator holding the iterator's state type Iterator struct { tree *Tree - node *ANode + node *AvlNode position position } diff --git a/priority_queue/priority_queue_test.go b/priority_queue/priority_queue_test.go index 2521427..2d8ea31 100644 --- a/priority_queue/priority_queue_test.go +++ b/priority_queue/priority_queue_test.go @@ -78,10 +78,13 @@ import ( func TestAVL(t *testing.T) { avl := NewWithIntComparator() - for i := 0; i < 100; i++ { + for i := 0; i < 100000; i++ { v := randomdata.Number(0, 100) - ok, n := avl.Put(v, v) - t.Error(v, ok, n) + n := avl.Put(v, v) + if v != n.Value.(int) { + t.Error(v, n) + } + } t.Error(avl.Values()) f, ok := avl.Ceiling(1000)